diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 51714a2b06..24fa8dbb45 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -18,17 +18,14 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import UserID
-from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
-class RoomComplexityTests(unittest.HomeserverTestCase):
+class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -41,25 +38,6 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
return config
- def prepare(self, reactor, clock, homeserver):
- class Authenticator(object):
- def authenticate_request(self, request, content):
- return defer.succeed("otherserver.nottld")
-
- ratelimiter = FederationRateLimiter(
- clock,
- FederationRateLimitConfig(
- window_size=1,
- sleep_limit=1,
- sleep_msec=1,
- reject_limit=1000,
- concurrent_requests=1000,
- ),
- )
- server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
-
def test_complexity_simple(self):
u1 = self.register_user("u1", "pass")
@@ -105,7 +83,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
d = handler._remote_join(
None,
- ["otherserver.example"],
+ ["other.example.com"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
@@ -146,7 +124,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase):
d = handler._remote_join(
None,
- ["otherserver.example"],
+ ["other.example.com"],
room_1,
UserID.from_string(u1),
{"membership": "join"},
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index cce8d8c6de..d456267b87 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.types import ReadReceipt
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class FederationSenderTestCases(HomeserverTestCase):
@@ -29,6 +29,7 @@ class FederationSenderTestCases(HomeserverTestCase):
federation_transport_client=Mock(spec=["send_transaction"]),
)
+ @override_config({"send_federation": True})
def test_send_receipts(self):
mock_state_handler = self.hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
@@ -69,6 +70,7 @@ class FederationSenderTestCases(HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index b08be451aa..1ec8c40901 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,8 @@ import logging
from synapse.events import FrozenEvent
from synapse.federation.federation_server import server_matches_acl_event
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
from tests import unittest
@@ -41,6 +44,66 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertTrue(server_matches_acl_event("1:2:3:4", e))
+class StateQueryTests(unittest.FederatingHomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def test_without_event_id(self):
+ """
+ Querying v1/state/<room_id> without an event ID will return the current
+ known state.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.inject_room_member(room_1, "@user:other.example.com", "join")
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ self.assertEqual(
+ channel.json_body["room_version"],
+ self.hs.config.default_room_version.identifier,
+ )
+
+ members = set(
+ map(
+ lambda x: x["state_key"],
+ filter(
+ lambda x: x["type"] == "m.room.member", channel.json_body["pdus"]
+ ),
+ )
+ )
+
+ self.assertEqual(members, set(["@user:other.example.com", u1]))
+ self.assertEqual(len(channel.json_body["pdus"]), 6)
+
+ def test_needs_to_be_in_room(self):
+ """
+ Querying v1/state/<room_id> requires the server
+ be in the room to provide data.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(403, channel.code, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+
+
def _create_acl_event(content):
return FrozenEvent(
{
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 0bb96674a2..70f172eb02 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ version_etag = res["etag"]
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "1",
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
+ "count": 0,
},
)
# check we can retrieve it as a specific version
res = yield self.handler.get_version_info(self.local_user, "1")
+ self.assertEqual(res["etag"], version_etag)
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "1",
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
+ "count": 0,
},
)
@@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "2",
"algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data",
+ "count": 0,
},
)
@@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"]
self.assertDictEqual(
res,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": version,
+ "count": 0,
},
)
@@ -207,12 +218,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual(
res,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": version,
+ "count": 0,
},
)
@@ -409,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ # get the etag to compare to future versions
+ res = yield self.handler.get_version_info(self.local_user)
+ backup_etag = res["etag"]
+ self.assertEqual(res["count"], 1)
+
new_room_keys = copy.deepcopy(room_keys)
new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]
@@ -423,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"SSBBTSBBIEZJU0gK",
)
+ # the etag should be the same since the session did not change
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertEqual(res["etag"], backup_etag)
+
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
@@ -432,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
+ # the etag should NOT be equal now, since the key changed
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertNotEqual(res["etag"], backup_etag)
+ backup_etag = res["etag"]
+
# test that a session with a higher forwarded_count doesn't replace one
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
@@ -443,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
+ # the etag should be the same since the session did not change
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertEqual(res["etag"], backup_etag)
+
# TODO: check edge cases as well as the common variations here
@defer.inlineCallbacks
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d56220f403..b4d92cf732 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,13 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
+logger = logging.getLogger(__name__)
+
class FederationTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -79,3 +85,123 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.code, 403, failure)
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")
+
+ def test_rejected_message_event_state(self):
+ """
+ Check that we store the state group correctly for rejected non-state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ join_event.format_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def test_rejected_state_event_state(self):
+ """
+ Check that we store the state group correctly for rejected state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": "org.matrix.test",
+ "state_key": "test_key",
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ join_event.format_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def _build_and_send_join_event(self, other_server, other_user, room_id):
+ join_event = self.get_success(
+ self.handler.on_make_join_request(other_server, room_id, other_user)
+ )
+ # the auth code requires that a signature exists, but doesn't check that
+ # signature... go figure.
+ join_event.signatures[other_server] = {"x": "y"}
+ with LoggingContext(request="send_join"):
+ d = run_in_background(
+ self.handler.on_send_join_request, other_server, join_event
+ )
+ self.get_success(d)
+
+ # sanity-check: the room should show that the new user is a member
+ r = self.get_success(self.store.get_current_state_ids(room_id))
+ self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
+
+ return join_event
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5ec568f4e6..f6d8660285 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
from tests import unittest
+from tests.unittest import override_config
from tests.utils import register_federation_servlets
# Some local users to test with
@@ -174,6 +175,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_started_typing_remote_send(self):
self.room_members = [U_APPLE, U_ONION]
@@ -237,6 +239,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_stopped_typing(self):
self.room_members = [U_APPLE, U_BANANA, U_ONION]
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 4f924ce451..e7472e3a93 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -48,7 +48,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = server_factory.streamer
+ handler_factory = Mock()
self.replication_handler = ReplicationClientHandler(self.slaved_store)
+ self.replication_handler.factory = handler_factory
+
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index ce3835ae6a..1d14e77255 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
+
from synapse.replication.tcp.commands import ReplicateCommand
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -30,7 +32,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server = server_factory.buildProtocol(None)
# build a replication client, with a dummy handler
+ handler_factory = Mock()
self.test_handler = TestReplicationClientHandler()
+ self.test_handler.factory = handler_factory
self.client = ClientReplicationStreamProtocol(
"client", "test", clock, self.test_handler
)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 8e1ca8b738..9575058252 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -628,10 +628,12 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
"local_invites",
"room_account_data",
"room_tags",
+ "state_groups",
+ "state_groups_state",
):
count = self.get_success(
self.store._simple_select_one_onecol(
- table="events",
+ table=table,
keyvalues={"room_id": room_id},
retcol="COUNT(*)",
desc="test_purge_room",
@@ -639,3 +641,5 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+ test_purge_room.skip = "Disabled because it's currently broken"
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
new file mode 100644
index 0000000000..95475bb651
--- /dev/null
+++ b/tests/rest/client/test_retention.py
@@ -0,0 +1,293 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.api.constants import EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.visibility import filter_events_for_client
+
+from tests import unittest
+
+one_hour_ms = 3600000
+one_day_ms = one_hour_ms * 24
+
+
+class RetentionTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["retention"] = {
+ "enabled": True,
+ "default_policy": {
+ "min_lifetime": one_day_ms,
+ "max_lifetime": one_day_ms * 3,
+ },
+ "allowed_lifetime_min": one_day_ms,
+ "allowed_lifetime_max": one_day_ms * 3,
+ }
+
+ self.hs = self.setup_test_homeserver(config=config)
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_retention_state_event(self):
+ """Tests that the server configuration can limit the values a user can set to the
+ room's retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 4},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_hour_ms},
+ tok=self.token,
+ expect_code=400,
+ )
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": lifetime},
+ tok=self.token,
+ )
+
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_without_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by the server's configuration's default retention policy.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self._test_retention_event_purged(room_id, one_day_ms * 2)
+
+ def test_visibility(self):
+ """Tests that synapse.visibility.filter_events_for_client correctly filters out
+ outdated events
+ """
+ store = self.hs.get_datastore()
+ storage = self.hs.get_storage()
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ events = []
+
+ # Send a first event, which should be filtered out at the end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ # Get the event from the store so that we end up with a FrozenEvent that we can
+ # give to filter_events_for_client. We need to do this now because the event won't
+ # be in the database anymore after it has expired.
+ events.append(self.get_success(store.get_event(resp.get("event_id"))))
+
+ # Advance the time by 2 days. We're using the default retention policy, therefore
+ # after this the first event will still be valid.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Send another event, which shouldn't get filtered out.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ events.append(self.get_success(store.get_event(valid_event_id)))
+
+ # Advance the time by anothe 2 days. After this, the first event should be
+ # outdated but not the second one.
+ self.reactor.advance(one_day_ms * 2 / 1000)
+
+ # Run filter_events_for_client with our list of FrozenEvents.
+ filtered_events = self.get_success(
+ filter_events_for_client(storage, self.user_id, events)
+ )
+
+ # We should only get one event back.
+ self.assertEqual(len(filtered_events), 1, filtered_events)
+ # That event should be the second, not outdated event.
+ self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
+
+ def _test_retention_event_purged(self, room_id, increment):
+ # Get the create event to, later, check that we can still access it.
+ message_handler = self.hs.get_message_handler()
+ create_event = self.get_success(
+ message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+ )
+
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ expired_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, expired_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time.
+ self.reactor.advance(increment / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ valid_event_id = resp.get("event_id")
+
+ # Advance the time again. Now our first event should have expired but our second
+ # one should still be kept.
+ self.reactor.advance(increment / 1000)
+
+ # Check that the event has been purged from the database.
+ self.get_event(room_id, expired_event_id, expected_code=404)
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ valid_event = self.get_event(room_id, valid_event_id)
+ self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
+
+ # Check that we can still access state events that were sent before the event that
+ # has been purged.
+ self.get_event(room_id, create_event.event_id)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
+
+
+class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["retention"] = {
+ "enabled": True,
+ }
+
+ mock_federation_client = Mock(spec=["backfill"])
+
+ self.hs = self.setup_test_homeserver(
+ config=config, federation_client=mock_federation_client,
+ )
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def test_no_default_policy(self):
+ """Tests that an event doesn't get expired if there is neither a default retention
+ policy nor a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ self._test_retention(room_id)
+
+ def test_state_policy(self):
+ """Tests that an event gets correctly expired if there is no default retention
+ policy but there's a policy specific to the room.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set the maximum lifetime to 35 days so that the first event gets expired but not
+ # the second one.
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": one_day_ms * 35},
+ tok=self.token,
+ )
+
+ self._test_retention(room_id, expected_code_for_first_event=404)
+
+ def _test_retention(self, room_id, expected_code_for_first_event=200):
+ # Send a first event to the room. This is the event we'll want to be purged at the
+ # end of the test.
+ resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+ first_event_id = resp.get("event_id")
+
+ # Check that we can retrieve the event.
+ expired_event = self.get_event(room_id, first_event_id)
+ self.assertEqual(
+ expired_event.get("content", {}).get("body"), "1", expired_event
+ )
+
+ # Advance the time by a month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+ # Send another event. We need this because the purge job won't purge the most
+ # recent event in the room.
+ resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+ second_event_id = resp.get("event_id")
+
+ # Advance the time by another month.
+ self.reactor.advance(one_day_ms * 30 / 1000)
+
+ # Check if the event has been purged from the database.
+ first_event = self.get_event(
+ room_id, first_event_id, expected_code=expected_code_for_first_event
+ )
+
+ if expected_code_for_first_event == 200:
+ self.assertEqual(
+ first_event.get("content", {}).get("body"), "1", first_event
+ )
+
+ # Check that the event that hasn't been purged can still be retrieved.
+ second_event = self.get_event(room_id, second_event_id)
+ self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event)
+
+ def get_event(self, room_id, event_id, expected_code=200):
+ url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+ request, channel = self.make_request("GET", url, access_token=self.token)
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index c5d67fc1cd..ea4b58d37a 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -27,7 +27,9 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import login, profile, room
+from synapse.util.stringutils import random_string
from tests import unittest
@@ -1011,6 +1013,146 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.assertEqual(res_displayname, self.displayname, channel.result)
+class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
+ """Tests that clients can add a "reason" field to membership events and
+ that they get correctly added to the generated events and propagated.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
+
+ def test_join_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_leave_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_kick_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_ban_reason(self):
+ self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_unban_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_invite_reason(self):
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
+ content={"reason": reason, "user_id": self.second_user_id},
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def test_reject_invite_reason(self):
+ self.helper.invite(
+ self.room_id,
+ src=self.creator,
+ targ=self.second_user_id,
+ tok=self.creator_tok,
+ )
+
+ reason = "hello"
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ content={"reason": reason},
+ access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self._check_for_reason(reason)
+
+ def _check_for_reason(self, reason):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
+ self.room_id, self.second_user_id
+ ),
+ access_token=self.creator_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ event_content = channel.json_body
+
+ self.assertEqual(event_content.get("reason"), reason, channel.result)
+
+
class LabelsTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -1319,7 +1461,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
def _send_labelled_messages_in_room(self):
"""Sends several messages to a room with different labels (or without any) to test
filtering by label.
-
Returns:
The ID of the event to use if we're testing filtering on /context.
"""
@@ -1383,4 +1524,4 @@ class LabelsTestCase(unittest.HomeserverTestCase):
tok=self.tok,
)
- return event_id
+ return event_id
\ No newline at end of file
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index dab87e5edf..c0d0d2b44e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -203,6 +203,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{
+ "public_baseurl": "https://test_server",
"enable_registration_captcha": True,
"user_consent": {
"version": "1",
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 976652aee8..852b8ab11c 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -247,6 +247,41 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
+ def test_overlong_title(self):
+ self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+
+ end_content = (
+ b"<html><head>"
+ b"<title>" + b"x" * 2000 + b"</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ res = channel.json_body
+ # We should only see the `og:description` field, as `title` is too long and should be stripped out
+ self.assertCountEqual(["og:description"], res.keys())
+
def test_ipaddr(self):
"""
IP addresses can be previewed directly.
diff --git a/tests/server.py b/tests/server.py
index f878aeaada..2b7cf4242e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -379,6 +379,7 @@ class FakeTransport(object):
disconnecting = False
disconnected = False
+ connected = True
buffer = attr.ib(default=b"")
producer = attr.ib(default=None)
autoflush = attr.ib(default=True)
@@ -402,6 +403,7 @@ class FakeTransport(object):
"FakeTransport: Delaying disconnect until buffer is flushed"
)
else:
+ self.connected = False
self.disconnected = True
def abortConnection(self):
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index d128fde441..35dafbb904 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -39,8 +39,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.store.set_e2e_room_key(
- "user_id", version1, "room", "session", room_key
+ self.store.add_e2e_room_keys(
+ "user_id", version1, [("room", "session", room_key)]
)
)
@@ -51,8 +51,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.store.set_e2e_room_key(
- "user_id", version2, "room", "session", room_key
+ self.store.add_e2e_room_keys(
+ "user_id", version2, [("room", "session", room_key)]
)
)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index f671599cb8..b9fafaa1a6 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -40,23 +40,24 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
- storage = self.hs.get_datastore()
+ store = self.hs.get_datastore()
+ storage = self.hs.get_storage()
# Get the topological token
- event = storage.get_topological_token_for_event(last["event_id"])
+ event = store.get_topological_token_for_event(last["event_id"])
self.pump()
event = self.successResultOf(event)
# Purge everything before this topological token
- purge = storage.purge_history(self.room_id, event, True)
+ purge = storage.purge_events.purge_history(self.room_id, event, True)
self.pump()
self.assertEqual(self.successResultOf(purge), None)
# Try and get the events
- get_first = storage.get_event(first["event_id"])
- get_second = storage.get_event(second["event_id"])
- get_third = storage.get_event(third["event_id"])
- get_last = storage.get_event(last["event_id"])
+ get_first = store.get_event(first["event_id"])
+ get_second = store.get_event(second["event_id"])
+ get_third = store.get_event(third["event_id"])
+ get_last = store.get_event(last["event_id"])
self.pump()
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 9ddd17f73d..105a0c2b02 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -16,8 +16,7 @@
from unittest.mock import Mock
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.room_versions import RoomVersions
+from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
from synapse.types import Requester, UserID
@@ -44,9 +43,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
- self.storage = hs.get_storage()
- self.event_builder_factory = hs.get_event_builder_factory()
- self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = self.register_user("alice", "pass")
self.t_alice = self.login("alice", "pass")
@@ -55,26 +51,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# User elsewhere on another host
self.u_charlie = UserID.from_string("@charlie:elsewhere")
- def inject_room_member(self, room, user, membership, replaces_state=None):
- builder = self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Member,
- "sender": user,
- "state_key": user,
- "room_id": room,
- "content": {"membership": membership},
- },
- )
-
- event, context = self.get_success(
- self.event_creation_handler.create_new_client_event(builder)
- )
-
- self.get_success(self.storage.persistence.persist_event(event, context))
-
- return event
-
def test_one_member(self):
# Alice creates the room, and is automatically joined
diff --git a/tests/test_state.py b/tests/test_state.py
index 38246555bd..176535947a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -21,6 +21,7 @@ from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler
from tests import unittest
@@ -198,16 +199,22 @@ class StateTestCase(unittest.TestCase):
self.store.register_events(graph.walk())
- context_store = {}
+ context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
context = yield self.state.compute_event_context(event)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertEqual(2, len(prev_state_ids))
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
graph = Graph(
@@ -241,12 +248,19 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between B and C
+
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+ prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
)
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
@defer.inlineCallbacks
def test_branch_have_banned_conflict(self):
graph = Graph(
@@ -292,11 +306,18 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between C and D because bans win over other
+ # changes
+
+ ctx_c = context_store["C"]
+ ctx_e = context_store["E"]
+ prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
)
+ self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
+ self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@defer.inlineCallbacks
def test_branch_have_perms_conflict(self):
@@ -360,12 +381,20 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # B ends up winning the resolution between B and C because power levels
+ # win over other changes.
+ ctx_b = context_store["B"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
)
+ self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
def _add_depths(self, nodes, edges):
def _get_depth(ev):
node = nodes[ev]
@@ -390,13 +419,16 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(current_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ self.assertCountEqual(
+ (e.event_id for e in old_state), current_state_ids.values()
)
- self.assertIsNotNone(context.state_group)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertEqual(context.state_group_before_event, context.state_group)
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
@@ -411,11 +443,18 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
prev_state_ids = yield context.get_prev_state_ids(self.store)
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(prev_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ self.assertCountEqual(
+ (e.event_id for e in old_state + [event]), current_state_ids.values()
)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertNotEqual(context.state_group_before_event, context.state_group)
+ self.assertEqual(context.state_group_before_event, context.prev_group)
+ self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
+
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
prev_event_id = "prev_event_id"
diff --git a/tests/unittest.py b/tests/unittest.py
index 561cebc223..31997a0f31 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import gc
import hashlib
import hmac
@@ -27,13 +29,17 @@ from twisted.internet.defer import Deferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester
+from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import get_clock, make_request, render, setup_test_homeserver
from tests.test_utils.logging_setup import setup_logging
@@ -559,6 +565,66 @@ class HomeserverTestCase(TestCase):
self.render(request)
self.assertEqual(channel.code, 403, channel.result)
+ def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+ """
+ Inject a membership event into a room.
+
+ Args:
+ room: Room ID to inject the event into.
+ user: MXID of the user to inject the membership for.
+ membership: The membership type.
+ """
+ event_builder_factory = self.hs.get_event_builder_factory()
+ event_creation_handler = self.hs.get_event_creation_handler()
+
+ room_version = self.get_success(self.hs.get_datastore().get_room_version(room))
+
+ builder = event_builder_factory.for_room_version(
+ KNOWN_ROOM_VERSIONS[room_version],
+ {
+ "type": EventTypes.Member,
+ "sender": user,
+ "state_key": user,
+ "room_id": room,
+ "content": {"membership": membership},
+ },
+ )
+
+ event, context = self.get_success(
+ event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(
+ self.hs.get_storage().persistence.persist_event(event, context)
+ )
+
+
+class FederatingHomeserverTestCase(HomeserverTestCase):
+ """
+ A federating homeserver that authenticates incoming requests as `other.example.com`.
+ """
+
+ def prepare(self, reactor, clock, homeserver):
+ class Authenticator(object):
+ def authenticate_request(self, request, content):
+ return succeed("other.example.com")
+
+ ratelimiter = FederationRateLimiter(
+ clock,
+ FederationRateLimitConfig(
+ window_size=1,
+ sleep_limit=1,
+ sleep_msec=1,
+ reject_limit=1000,
+ concurrent_requests=1000,
+ ),
+ )
+ federation_server.register_servlets(
+ homeserver, self.resource, Authenticator(), ratelimiter
+ )
+
+ return super().prepare(reactor, clock, homeserver)
+
def override_config(extra_config):
"""A decorator which can be applied to test functions to give additional HS config
diff --git a/tests/utils.py b/tests/utils.py
index 7dc9bdc505..de2ac1ed33 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -109,6 +109,7 @@ def default_config(name, parse=False):
"""
config_dict = {
"server_name": name,
+ "send_federation": False,
"media_store_path": "media",
"uploads_path": "uploads",
# the test signing key is just an arbitrary ed25519 key to keep the config
|