diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 6381583c24..391ae51707 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,4 +1,5 @@
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Collection, List, Optional, Tuple
+from unittest import mock
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -500,3 +501,87 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertEqual(len(sent_pdus), 1)
self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
self.assertFalse(per_dest_queue._catching_up)
+
+ def test_catch_up_is_not_blocked_by_remote_event_in_partial_state_room(
+ self,
+ ) -> None:
+ """Detects (part of?) https://github.com/matrix-org/synapse/issues/15220."""
+ # ARRANGE:
+ # - a local user (u1)
+ # - a room which contains u1 and two remote users, @u2:host2 and @u3:other
+ # - events in that room such that
+ # - history visibility is restricted
+ # - u1 sent message events e1 and e2
+ # - afterwards, u3 sent a remote event e3
+ # - catchup to begin for host2; last successfully sent event was e1
+ per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room = self.helper.create_room_as("u1", tok=u1_token)
+ self.helper.send_state(
+ room_id=room,
+ event_type="m.room.history_visibility",
+ body={"history_visibility": "joined"},
+ tok=u1_token,
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@u2:host2", "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@u3:other", "join")
+ )
+
+ # create some events
+ event_id_1 = self.helper.send(room, "hello", tok=u1_token)["event_id"]
+ event_id_2 = self.helper.send(room, "world", tok=u1_token)["event_id"]
+ # pretend that u3 changes their displayname
+ event_id_3 = self.get_success(
+ event_injection.inject_member_event(self.hs, room, "@u3:other", "join")
+ ).event_id
+
+ # destination_rooms should already be populated, but let us pretend that we already
+ # sent (successfully) up to and including event id 1
+ event_1 = self.get_success(self.hs.get_datastores().main.get_event(event_id_1))
+ assert event_1.internal_metadata.stream_ordering is not None
+ self.get_success(
+ self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
+ "host2", event_1.internal_metadata.stream_ordering
+ )
+ )
+
+ # also fetch event 2 so we can compare its stream ordering to the sender's
+ # last_successful_stream_ordering later
+ event_2 = self.get_success(self.hs.get_datastores().main.get_event(event_id_2))
+
+ # Mock event 3 as having partial state
+ self.get_success(
+ event_injection.mark_event_as_partial_state(self.hs, event_id_3, room)
+ )
+
+ # Fail the test if we block on full state for event 3.
+ async def mock_await_full_state(event_ids: Collection[str]) -> None:
+ if event_id_3 in event_ids:
+ raise AssertionError("Tried to await full state for event_id_3")
+
+ # ACT
+ with mock.patch.object(
+ self.hs.get_storage_controllers().state._partial_state_events_tracker,
+ "await_full_state",
+ mock_await_full_state,
+ ):
+ self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+ # ASSERT
+ # We should have:
+ # - not sent event 3: it's not ours, and the room is partial stated
+ # - fallen back to sending event 2: it's the most recent event in the room
+ # we tried to send to host2
+ # - completed catch-up
+ self.assertEqual(len(sent_pdus), 1)
+ self.assertEqual(sent_pdus[0].event_id, event_id_2)
+ self.assertFalse(per_dest_queue._catching_up)
+ self.assertEqual(
+ per_dest_queue._last_successful_stream_ordering,
+ event_2.internal_metadata.stream_ordering,
+ )
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c5746005b5..ce095eb68a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -19,17 +19,18 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import UserTypes
+from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.appservice import ApplicationService
from synapse.rest.client import login, register, room, user_directory
from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo
-from synapse.types import UserProfile, create_requester
+from synapse.types import JsonDict, UserProfile, create_requester
from synapse.util import Clock
from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables
-from tests.test_utils import make_awaitable
+from tests.test_utils import event_injection, make_awaitable
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
@@ -1103,3 +1104,185 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
+
+
+class UserDirectoryRemoteProfileTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets,
+ register.register_servlets,
+ room.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ # Re-enables updating the user directory, as that functionality is needed below.
+ config["update_user_directory_from_worker"] = None
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.alice = self.register_user("alice", "alice123")
+ self.alice_tok = self.login("alice", "alice123")
+ self.user_dir_helper = GetUserDirectoryTables(self.store)
+ self.user_dir_handler = hs.get_user_directory_handler()
+ self.profile_handler = hs.get_profile_handler()
+
+ # Cancel the startup call: in the steady-state case we can't rely on it anyway.
+ assert self.user_dir_handler._refresh_remote_profiles_call_later is not None
+ self.user_dir_handler._refresh_remote_profiles_call_later.cancel()
+
+ def test_public_rooms_have_profiles_collected(self) -> None:
+ """
+ In a public room, member state events are treated as reflecting the user's
+ real profile and they are accepted.
+ (The main motivation for accepting this is to prevent having to query
+ *every* single profile change over federation.)
+ """
+ room_id = self.helper.create_room_as(
+ self.alice, is_public=True, tok=self.alice_tok
+ )
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs,
+ room_id,
+ "@bruce:remote",
+ "join",
+ "@bruce:remote",
+ extra_content={
+ "displayname": "Bruce!",
+ "avatar_url": "mxc://remote/123",
+ },
+ )
+ )
+ # Sending this event makes the streams move forward after the injection...
+ self.helper.send(room_id, "Test", tok=self.alice_tok)
+ self.pump(0.1)
+
+ profiles = self.get_success(
+ self.user_dir_helper.get_profiles_in_user_directory()
+ )
+ self.assertEqual(
+ profiles.get("@bruce:remote"),
+ ProfileInfo(display_name="Bruce!", avatar_url="mxc://remote/123"),
+ )
+
+ def test_private_rooms_do_not_have_profiles_collected(self) -> None:
+ """
+ In a private room, member state events are not pulled out and used to populate
+ the user directory.
+ """
+ room_id = self.helper.create_room_as(
+ self.alice, is_public=False, tok=self.alice_tok
+ )
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs,
+ room_id,
+ "@bruce:remote",
+ "join",
+ "@bruce:remote",
+ extra_content={
+ "displayname": "super-duper bruce",
+ "avatar_url": "mxc://remote/456",
+ },
+ )
+ )
+ # Sending this event makes the streams move forward after the injection...
+ self.helper.send(room_id, "Test", tok=self.alice_tok)
+ self.pump(0.1)
+
+ profiles = self.get_success(
+ self.user_dir_helper.get_profiles_in_user_directory()
+ )
+ self.assertNotIn("@bruce:remote", profiles)
+
+ def test_private_rooms_have_profiles_requested(self) -> None:
+ """
+ When a name changes in a private room, the homeserver instead requests
+ the user's global profile over federation.
+ """
+
+ async def get_remote_profile(
+ user_id: str, ignore_backoff: bool = True
+ ) -> JsonDict:
+ if user_id == "@bruce:remote":
+ return {
+ "displayname": "Sir Bruce Bruceson",
+ "avatar_url": "mxc://remote/789",
+ }
+ else:
+ raise ValueError(f"unable to fetch {user_id}")
+
+ with patch.object(self.profile_handler, "get_profile", get_remote_profile):
+ # Continue from the earlier test...
+ self.test_private_rooms_do_not_have_profiles_collected()
+
+ # Advance by a minute
+ self.reactor.advance(61.0)
+
+ profiles = self.get_success(
+ self.user_dir_helper.get_profiles_in_user_directory()
+ )
+ self.assertEqual(
+ profiles.get("@bruce:remote"),
+ ProfileInfo(
+ display_name="Sir Bruce Bruceson", avatar_url="mxc://remote/789"
+ ),
+ )
+
+ def test_profile_requests_are_retried(self) -> None:
+ """
+ When we fail to fetch the user's profile over federation,
+ we try again later.
+ """
+ has_failed_once = False
+
+ async def get_remote_profile(
+ user_id: str, ignore_backoff: bool = True
+ ) -> JsonDict:
+ nonlocal has_failed_once
+ if user_id == "@bruce:remote":
+ if not has_failed_once:
+ has_failed_once = True
+ raise SynapseError(502, "temporary network problem")
+
+ return {
+ "displayname": "Sir Bruce Bruceson",
+ "avatar_url": "mxc://remote/789",
+ }
+ else:
+ raise ValueError(f"unable to fetch {user_id}")
+
+ with patch.object(self.profile_handler, "get_profile", get_remote_profile):
+ # Continue from the earlier test...
+ self.test_private_rooms_do_not_have_profiles_collected()
+
+ # Advance by a minute
+ self.reactor.advance(61.0)
+
+ # The request has already failed once
+ self.assertTrue(has_failed_once)
+
+ # The profile has yet to be updated.
+ profiles = self.get_success(
+ self.user_dir_helper.get_profiles_in_user_directory()
+ )
+ self.assertNotIn(
+ "@bruce:remote",
+ profiles,
+ )
+
+ # Advance by five minutes, after the backoff has finished
+ self.reactor.advance(301.0)
+
+ # The profile should have been updated now
+ profiles = self.get_success(
+ self.user_dir_helper.get_profiles_in_user_directory()
+ )
+ self.assertEqual(
+ profiles.get("@bruce:remote"),
+ ProfileInfo(
+ display_name="Sir Bruce Bruceson", avatar_url="mxc://remote/789"
+ ),
+ )
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index f6d6684985..57b6a84e23 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -210,8 +210,8 @@ class BlacklistingAgentTest(TestCase):
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
Agent(self.reactor),
- ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
+ ip_whitelist=self.ip_whitelist,
)
# The unsafe IPs should be rejected.
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index af0341808d..978c2d5a34 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -228,14 +228,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
return len(result) > 0
- @override_config(
- {
- "experimental_features": {
- "msc3952_intentional_mentions": True,
- "msc3966_exact_event_property_contains": True,
- }
- }
- )
+ @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
def test_user_mentions(self) -> None:
"""Test the behavior of an event which includes invalid user mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -331,14 +324,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
)
- @override_config(
- {
- "experimental_features": {
- "msc3952_intentional_mentions": True,
- "msc3966_exact_event_property_contains": True,
- }
- }
- )
+ @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
def test_room_mentions(self) -> None:
"""Test the behavior of an event which includes invalid room mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index ff5a9a66f5..52c4aafea6 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -51,11 +51,7 @@ class FlattenDictTestCase(unittest.TestCase):
# If a field has a dot in it, escape it.
input = {"m.foo": {"b\\ar": "abc"}}
- self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input))
- self.assertEqual(
- {"m\\.foo.b\\\\ar": "abc"},
- _flatten_dict(input, msc3873_escape_event_match_key=True),
- )
+ self.assertEqual({"m\\.foo.b\\\\ar": "abc"}, _flatten_dict(input))
def test_non_string(self) -> None:
"""String, booleans, ints, nulls and list of those should be kept while other items are dropped."""
@@ -125,7 +121,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
- "content.org.matrix.msc1767.markup": [],
+ "content.org\\.matrix\\.msc1767\\.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
@@ -137,7 +133,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
- "content.org.matrix.msc1767.markup": [],
+ "content.org\\.matrix\\.msc1767\\.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
@@ -173,7 +169,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
- msc3966_exact_event_property_contains=True,
)
def test_display_name(self) -> None:
@@ -526,7 +521,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"""Check that exact_event_property_contains conditions work as expected."""
condition = {
- "kind": "org.matrix.msc3966.exact_event_property_contains",
+ "kind": "event_property_contains",
"key": "content.value",
"value": "foobaz",
}
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
deleted file mode 100644
index b75fc05fd5..0000000000
--- a/tests/replication/tcp/test_remote_server_up.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Tuple
-
-from twisted.internet.address import IPv4Address
-from twisted.internet.interfaces import IProtocol
-from twisted.test.proto_helpers import MemoryReactor, StringTransport
-
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.server import HomeServer
-from synapse.util import Clock
-
-from tests.unittest import HomeserverTestCase
-
-
-class RemoteServerUpTestCase(HomeserverTestCase):
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.factory = ReplicationStreamProtocolFactory(hs)
-
- def _make_client(self) -> Tuple[IProtocol, StringTransport]:
- """Create a new direct TCP replication connection"""
-
- proto = self.factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 0))
- transport = StringTransport()
- proto.makeConnection(transport)
-
- # We can safely ignore the commands received during connection.
- self.pump()
- transport.clear()
-
- return proto, transport
-
- def test_relay(self) -> None:
- """Test that Synapse will relay REMOTE_SERVER_UP commands to all
- other connections, but not the one that sent it.
- """
-
- proto1, transport1 = self._make_client()
-
- # We shouldn't receive an echo.
- proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
- self.pump()
- self.assertEqual(transport1.value(), b"")
-
- # But we should see an echo if we connect another client
- proto2, transport2 = self._make_client()
- proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
-
- self.pump()
- self.assertEqual(transport1.value(), b"")
- self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n")
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 2b05dffc7d..7f675c44a2 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -1249,9 +1249,8 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
# account status will fail.
return UserID.from_string(user_id).localpart == "someuser"
- self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
- is_expired
- )
+ account_validity_callbacks = self.hs.get_module_api_callbacks().account_validity
+ account_validity_callbacks.is_user_expired_callbacks.append(is_expired)
self._test_status(
users=[user],
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 7245830b01..1bdb6bb6a5 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -981,18 +981,16 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
just before associating and removing a 3PID to/from an account.
"""
# Pretend to be a Synapse module and register both callbacks as mocks.
- third_party_rules = self.hs.get_third_party_event_rules()
on_add_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
)
on_remove_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
)
- third_party_rules._on_threepid_bind_callbacks.append(
- on_add_user_third_party_identifier_callback_mock
- )
- third_party_rules._on_threepid_bind_callbacks.append(
- on_remove_user_third_party_identifier_callback_mock
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules.register_third_party_rules_callbacks(
+ on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
+ on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
)
# Register an admin user.
@@ -1048,12 +1046,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
when a user is deactivated and their third-party ID associations are deleted.
"""
# Pretend to be a Synapse module and register both callbacks as mocks.
- third_party_rules = self.hs.get_third_party_event_rules()
on_remove_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
)
- third_party_rules._on_threepid_bind_callbacks.append(
- on_remove_user_third_party_identifier_callback_mock
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules.register_third_party_rules_callbacks(
+ on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
)
# Register an admin user.
@@ -1079,6 +1077,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.json_body)
+ # Check that the mock was not called on the act of adding a third-party ID.
+ on_remove_user_third_party_identifier_callback_mock.assert_not_called()
+
# Now deactivate the user.
channel = self.make_request(
"PUT",
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3086e1b565..d8dc56261a 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
- self.mock_key = "foo"
+
+ # Here we make sure that we're setting all the fields that HttpTransactionCache
+ # uses to build the transaction key.
+ self.mock_request = Mock()
+ self.mock_request.path = b"/foo/bar"
+ self.mock_requester = Mock()
+ self.mock_requester.app_service = None
+ self.mock_requester.is_guest = False
+ self.mock_requester.access_token_id = 1234
@defer.inlineCallbacks
def test_executes_given_function(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
- res = yield self.cache.fetch_or_execute(
- self.mock_key, cb, "some_arg", keyword="arg"
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
)
cb.assert_called_once_with("some_arg", keyword="arg")
self.assertEqual(res, self.mock_http_response)
@@ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
- res = yield self.cache.fetch_or_execute(
- self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request,
+ self.mock_requester,
+ cb,
+ "some_arg",
+ keyword="arg",
+ changing_args=i,
)
self.assertEqual(res, self.mock_http_response)
# expect only a single call to do the work
@@ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test() -> Generator["defer.Deferred[Any]", object, None]:
with LoggingContext("c") as c1:
- res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
self.assertIs(current_context(), c1)
self.assertEqual(res, (1, {}))
@@ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context:
try:
- yield self.cache.fetch_or_execute(self.mock_key, cb)
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context)
- res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context)
@@ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context:
try:
- yield self.cache.fetch_or_execute(self.mock_key, cb)
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context)
- res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
- yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "an arg"
+ )
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
- yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "an arg"
+ )
# still using cache
cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
- yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "an arg"
+ )
# no longer using cache
self.assertEqual(cb.call_count, 2)
self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index de797992c3..c619ef7f38 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -103,3 +103,34 @@ async def create_event(
context = await unpersisted_context.persist(event)
return event, context
+
+
+async def mark_event_as_partial_state(
+ hs: synapse.server.HomeServer,
+ event_id: str,
+ room_id: str,
+) -> None:
+ """
+ (Falsely) mark an event as having partial state.
+
+ Naughty, but occasionally useful when checking that partial state doesn't
+ block something from happening.
+
+ If the event already has partial state, this insert will fail (event_id is unique
+ in this table).
+ """
+ store = hs.get_datastores().main
+ await store.db_pool.simple_upsert(
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ values={},
+ insertion_values={"room_id": room_id},
+ )
+
+ await store.db_pool.simple_insert(
+ table="partial_state_events",
+ values={
+ "room_id": room_id,
+ "event_id": event_id,
+ },
+ )
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 2b6d7048d1..6004490b8c 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -63,7 +63,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success(
filter_events_for_server(
- self._storage_controllers, "test_server", "hs", events_to_filter
+ self._storage_controllers,
+ "test_server",
+ "hs",
+ events_to_filter,
+ redact=True,
+ filter_out_erased_senders=True,
+ filter_out_remote_partial_state_events=True,
)
)
@@ -85,7 +91,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_server(
- self._storage_controllers, "remote_hs", "hs", [outlier]
+ self._storage_controllers,
+ "remote_hs",
+ "hs",
+ [outlier],
+ redact=True,
+ filter_out_erased_senders=True,
+ filter_out_remote_partial_state_events=True,
)
),
[outlier],
@@ -96,7 +108,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success(
filter_events_for_server(
- self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
+ self._storage_controllers,
+ "remote_hs",
+ "local_hs",
+ [outlier, evt],
+ redact=True,
+ filter_out_erased_senders=True,
+ filter_out_remote_partial_state_events=True,
)
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
@@ -108,7 +126,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# be redacted)
filtered = self.get_success(
filter_events_for_server(
- self._storage_controllers, "other_server", "local_hs", [outlier, evt]
+ self._storage_controllers,
+ "other_server",
+ "local_hs",
+ [outlier, evt],
+ redact=True,
+ filter_out_erased_senders=True,
+ filter_out_remote_partial_state_events=True,
)
)
self.assertEqual(filtered[0], outlier)
@@ -143,7 +167,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
filter_events_for_server(
- self._storage_controllers, "test_server", "local_hs", events_to_filter
+ self._storage_controllers,
+ "test_server",
+ "local_hs",
+ events_to_filter,
+ redact=True,
+ filter_out_erased_senders=True,
+ filter_out_remote_partial_state_events=True,
)
)
|