diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 222449baac..aa6af5ad7b 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -48,8 +48,6 @@ class FilteringTestCase(unittest.HomeserverTestCase):
invalid_filters: List[JsonDict] = [
# `account_data` must be a dictionary
{"account_data": "Hello World"},
- # `event_fields` entries must not contain backslashes
- {"event_fields": [r"\\foo"]},
# `event_format` must be "client" or "federation"
{"event_format": "other"},
# `not_rooms` must contain valid room IDs
@@ -114,10 +112,6 @@ class FilteringTestCase(unittest.HomeserverTestCase):
"event_format": "client",
"event_fields": ["type", "content", "sender"],
},
- # a single backslash should be permitted (though it is debatable whether
- # it should be permitted before anything other than `.`, and what that
- # actually means)
- #
# (note that event_fields is implemented in
# synapse.events.utils.serialize_event, and so whether this actually works
# is tested elsewhere. We just want to check that it is allowed through the
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 056d9402a4..5a965f233b 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -38,7 +38,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def default_config(self) -> JsonDict:
conf = super().default_config()
- # we're using FederationReaderServer, which uses a SlavedStore, so we
+ # we're using GenericWorkerServer, which uses a GenericWorkerStore, so we
# have to tell the FederationHandler not to try to access stuff that is only
# in the primary store.
conf["worker_app"] = "yes"
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index a860eedbcf..9305b758d7 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -4,7 +4,6 @@ from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
-from tests import unittest
from tests.server import ThreadedMemoryReactorClock
from tests.unittest import HomeserverTestCase
@@ -12,154 +11,6 @@ FIVE_MINUTES_IN_SECONDS = 300
ONE_DAY_IN_SECONDS = 86400
-class PhoneHomeTestCase(HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
-
- # Override the retention time for the user_ips table because otherwise it
- # gets pruned too aggressively for our R30 test.
- @unittest.override_config({"user_ips_max_age": "365d"})
- def test_r30_minimum_usage(self) -> None:
- """
- Tests the minimum amount of interaction necessary for the R30 metric
- to consider a user 'retained'.
- """
-
- # Register a user, log it in, create a room and send a message
- user_id = self.register_user("u1", "secret!")
- access_token = self.login("u1", "secret!")
- room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
- self.helper.send(room_id, "message", tok=access_token)
-
- # Check the R30 results do not count that user.
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 0})
-
- # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
- # bang on 30 days later).
- self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
-
- # (Make sure the user isn't somehow counted by this point.)
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 0})
-
- # Send a message (this counts as activity)
- self.helper.send(room_id, "message2", tok=access_token)
-
- # We have to wait some time for _update_client_ips_batch to get
- # called and update the user_ips table.
- self.reactor.advance(2 * 60 * 60)
-
- # *Now* the user is counted.
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
- # Advance 29 days. The user has now not posted for 29 days.
- self.reactor.advance(29 * ONE_DAY_IN_SECONDS)
-
- # The user is still counted.
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
- # Advance another day. The user has now not posted for 30 days.
- self.reactor.advance(ONE_DAY_IN_SECONDS)
-
- # The user is now no longer counted in R30.
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 0})
-
- def test_r30_minimum_usage_using_default_config(self) -> None:
- """
- Tests the minimum amount of interaction necessary for the R30 metric
- to consider a user 'retained'.
-
- N.B. This test does not override the `user_ips_max_age` config setting,
- which defaults to 28 days.
- """
-
- # Register a user, log it in, create a room and send a message
- user_id = self.register_user("u1", "secret!")
- access_token = self.login("u1", "secret!")
- room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
- self.helper.send(room_id, "message", tok=access_token)
-
- # Check the R30 results do not count that user.
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 0})
-
- # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
- # bang on 30 days later).
- self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
-
- # (Make sure the user isn't somehow counted by this point.)
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 0})
-
- # Send a message (this counts as activity)
- self.helper.send(room_id, "message2", tok=access_token)
-
- # We have to wait some time for _update_client_ips_batch to get
- # called and update the user_ips table.
- self.reactor.advance(2 * 60 * 60)
-
- # *Now* the user is counted.
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
- # Advance 27 days. The user has now not posted for 27 days.
- self.reactor.advance(27 * ONE_DAY_IN_SECONDS)
-
- # The user is still counted.
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
- # Advance another day. The user has now not posted for 28 days.
- self.reactor.advance(ONE_DAY_IN_SECONDS)
-
- # The user is now no longer counted in R30.
- # (This is because the user_ips table has been pruned, which by default
- # only preserves the last 28 days of entries.)
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 0})
-
- def test_r30_user_must_be_retained_for_at_least_a_month(self) -> None:
- """
- Tests that a newly-registered user must be retained for a whole month
- before appearing in the R30 statistic, even if they post every day
- during that time!
- """
- # Register a user and send a message
- user_id = self.register_user("u1", "secret!")
- access_token = self.login("u1", "secret!")
- room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
- self.helper.send(room_id, "message", tok=access_token)
-
- # Check the user does not contribute to R30 yet.
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 0})
-
- for _ in range(30):
- # This loop posts a message every day for 30 days
- self.reactor.advance(ONE_DAY_IN_SECONDS)
- self.helper.send(room_id, "I'm still here", tok=access_token)
-
- # Notice that the user *still* does not contribute to R30!
- r30_results = self.get_success(
- self.hs.get_datastores().main.count_r30_users()
- )
- self.assertEqual(r30_results, {"all": 0})
-
- self.reactor.advance(ONE_DAY_IN_SECONDS)
- self.helper.send(room_id, "Still here!", tok=access_token)
-
- # *Now* the user appears in R30.
- r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
- self.assertEqual(r30_results, {"all": 1, "unknown": 1})
-
-
class PhoneHomeR30V2TestCase(HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -363,11 +214,6 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
)
- # Check that this is a situation where old R30 differs:
- # old R30 DOES count this as 'retained'.
- r30_results = self.get_success(store.count_r30_users())
- self.assertEqual(r30_results, {"all": 1, "ios": 1})
-
# Now we want to check that the user will still be able to appear in
# R30v2 as long as the user performs some other activity between
# 30 and 60 days later.
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index dee976356f..66753c60c4 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
-from typing import Generator
+from typing import Any, Generator
from unittest.mock import Mock
from twisted.internet import defer
@@ -49,15 +49,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_regex_user_id_prefix_match(
self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(
(
- yield defer.ensureDeferred(
- self.service.is_interested_in_event(
- self.event.event_id, self.event, self.store
- )
+ yield self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
)
)
)
@@ -65,15 +63,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(
self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
self.assertFalse(
(
- yield defer.ensureDeferred(
- self.service.is_interested_in_event(
- self.event.event_id, self.event, self.store
- )
+ yield self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
)
)
)
@@ -81,17 +77,15 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_regex_room_member_is_checked(
self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue(
(
- yield defer.ensureDeferred(
- self.service.is_interested_in_event(
- self.event.event_id, self.event, self.store
- )
+ yield self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
)
)
)
@@ -99,17 +93,15 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_regex_room_id_match(
self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(
(
- yield defer.ensureDeferred(
- self.service.is_interested_in_event(
- self.event.event_id, self.event, self.store
- )
+ yield self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
)
)
)
@@ -117,25 +109,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_regex_room_id_no_match(
self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(
(
- yield defer.ensureDeferred(
- self.service.is_interested_in_event(
- self.event.event_id, self.event, self.store
- )
+ yield self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
)
)
)
@defer.inlineCallbacks
- def test_regex_alias_match(
- self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
@@ -145,10 +133,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue(
(
- yield defer.ensureDeferred(
- self.service.is_interested_in_event(
- self.event.event_id, self.event, self.store
- )
+ yield self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
)
)
)
@@ -192,7 +178,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_regex_alias_no_match(
self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
@@ -213,7 +199,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_regex_multiple_matches(
self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ ) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
@@ -223,18 +209,14 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store.get_local_users_in_room = simple_async_mock([])
self.assertTrue(
(
- yield defer.ensureDeferred(
- self.service.is_interested_in_event(
- self.event.event_id, self.event, self.store
- )
+ yield self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
)
)
)
@defer.inlineCallbacks
- def test_interested_in_self(
- self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]:
# make sure invites get through
self.service.sender = "@appservice:name"
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
@@ -243,18 +225,14 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.state_key = self.service.sender
self.assertTrue(
(
- yield defer.ensureDeferred(
- self.service.is_interested_in_event(
- self.event.event_id, self.event, self.store
- )
+ yield self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
)
)
)
@defer.inlineCallbacks
- def test_member_list_match(
- self,
- ) -> Generator["defer.Deferred[object]", object, None]:
+ def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user.
self.store.get_local_users_in_room = simple_async_mock(
@@ -265,10 +243,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(
(
- yield defer.ensureDeferred(
- self.service.is_interested_in_event(
- self.event.event_id, self.event, self.store
- )
+ yield self.service.is_interested_in_event(
+ self.event.event_id, self.event, self.store
)
)
)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index e40eac2eb0..c9a610db9a 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -16,6 +16,7 @@ import unittest as stdlib_unittest
from typing import Any, List, Mapping, Optional
import attr
+from parameterized import parameterized
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
@@ -23,6 +24,7 @@ from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import (
PowerLevelsContent,
SerializeEventConfig,
+ _split_field,
copy_and_fixup_power_levels_contents,
maybe_upsert_event_field,
prune_event,
@@ -794,3 +796,40 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
def test_invalid_nesting_raises_type_error(self) -> None:
with self.assertRaises(TypeError):
copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) # type: ignore[dict-item]
+
+
+class SplitFieldTestCase(stdlib_unittest.TestCase):
+ @parameterized.expand(
+ [
+ # A field with no dots.
+ ["m", ["m"]],
+ # Simple dotted fields.
+ ["m.foo", ["m", "foo"]],
+ ["m.foo.bar", ["m", "foo", "bar"]],
+ # Backslash is used as an escape character.
+ [r"m\.foo", ["m.foo"]],
+ [r"m\\.foo", ["m\\", "foo"]],
+ [r"m\\\.foo", [r"m\.foo"]],
+ [r"m\\\\.foo", ["m\\\\", "foo"]],
+ [r"m\foo", [r"m\foo"]],
+ [r"m\\foo", [r"m\foo"]],
+ [r"m\\\foo", [r"m\\foo"]],
+ [r"m\\\\foo", [r"m\\foo"]],
+ # Ensure that escapes at the end don't cause issues.
+ ["m.foo\\", ["m", "foo\\"]],
+        [r"m.foo\\\\", ["m", "foo\\\\"]],
+ [r"m.foo\.", ["m", "foo."]],
+ [r"m.foo\\.", ["m", "foo\\", ""]],
+ [r"m.foo\\\.", ["m", r"foo\."]],
+ # Empty parts (corresponding to properties which are an empty string) are allowed.
+ [".m", ["", "m"]],
+ ["..m", ["", "", "m"]],
+ ["m.", ["m", ""]],
+ ["m..", ["m", "", ""]],
+ ["m..foo", ["m", "", "foo"]],
+ # Invalid escape sequences.
+ [r"\m", [r"\m"]],
+ ]
+ )
+    def test_split_field(self, input: str, expected: List[str]) -> None:
+ self.assertEqual(_split_field(input), expected)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 6c7738d810..5c850d1843 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -63,7 +63,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
class ServerACLsTestCase(unittest.TestCase):
- def test_blacklisted_server(self) -> None:
+ def test_blocked_server(self) -> None:
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
logging.info("ACL event: %s", e.content)
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index 3d61b1e8a9..93e5c85a27 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -86,18 +86,7 @@ class SendJoinParserTestCase(TestCase):
return parsed_response.members_omitted
self.assertTrue(parse({"members_omitted": True}))
- self.assertTrue(parse({"org.matrix.msc3706.partial_state": True}))
-
self.assertFalse(parse({"members_omitted": False}))
- self.assertFalse(parse({"org.matrix.msc3706.partial_state": False}))
-
- # If there's a conflict, the stable field wins.
- self.assertTrue(
- parse({"members_omitted": True, "org.matrix.msc3706.partial_state": False})
- )
- self.assertFalse(
- parse({"members_omitted": False, "org.matrix.msc3706.partial_state": True})
- )
def test_servers_in_room(self) -> None:
"""Check that the servers_in_room field is correctly parsed"""
@@ -113,28 +102,10 @@ class SendJoinParserTestCase(TestCase):
parsed_response = parser.finish()
return parsed_response.servers_in_room
- self.assertEqual(
- parse({"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}),
- ["hs1", "hs2"],
- )
self.assertEqual(parse({"servers_in_room": ["example.com"]}), ["example.com"])
- # If both are provided, the stable identifier should win
- self.assertEqual(
- parse(
- {
- "org.matrix.msc3706.servers_in_room": ["old"],
- "servers_in_room": ["new"],
- }
- ),
- ["new"],
- )
-
- # And lastly, we should be able to tell if neither field was present.
- self.assertEqual(
- parse({}),
- None,
- )
+ # We should be able to tell the field is not present.
+ self.assertEqual(parse({}), None)
def test_errors_closing_coroutines(self) -> None:
"""Check we close all coroutines, even if closing the first raises an Exception.
@@ -143,7 +114,7 @@ class SendJoinParserTestCase(TestCase):
assertions about its attributes or type.
"""
parser = SendJoinParser(RoomVersions.V1, False)
- response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
+ response = {"servers_in_room": ["hs1", "hs2"]}
serialisation = json.dumps(response).encode()
# Mock the coroutines managed by this parser.
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 72d0584061..2eaffe511e 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -27,7 +27,7 @@ from synapse.appservice import ApplicationService
from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
@@ -45,6 +45,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main
+ self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}")
def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list."""
@@ -161,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res2 = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=False,
)
@@ -206,6 +208,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=False,
)
@@ -225,6 +228,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=False,
)
@@ -274,6 +278,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=False,
)
@@ -286,6 +291,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=False,
)
@@ -307,6 +313,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=False,
)
@@ -348,6 +355,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=True,
)
@@ -370,6 +378,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=True,
)
@@ -1080,6 +1089,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=False,
)
@@ -1125,6 +1135,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=True,
)
@@ -1169,6 +1180,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=True,
)
@@ -1202,6 +1214,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=True,
)
@@ -1229,6 +1242,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id_1: {"alg1": 1}}},
+ self.requester,
timeout=None,
always_include_fallback_keys=True,
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index c067e5bfe3..23f1b33b2f 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -664,6 +664,101 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
StoreError,
)
+ def test_backfill_process_previously_failed_pull_attempt_event_in_the_background(
+ self,
+ ) -> None:
+ """
+ Sanity check that events are still processed even if it is in the background
+ for events that already have failed pull attempts.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # Allow the remote user to send state events
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"events_default": 0, "state_default": 0},
+ tok=tok,
+ )
+
+ # Add the remote user to the room
+ member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+
+ auth_event_ids = [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ member_event.event_id,
+ ]
+
+ # Create a regular event that should process
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [
+ member_event.event_id,
+ ],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled_event"},
+ }
+ ),
+ room_version,
+ )
+
+ # Record a failed pull attempt for this event which will cause us to backfill it
+ # in the background from here on out.
+ self.get_success(
+ main_store.record_event_failed_pull_attempt(
+ room_id, pulled_event.event_id, "fake cause"
+ )
+ )
+
+ # We expect an outbound request to /backfill, so stub that out
+ self.mock_federation_transport_client.backfill.return_value = make_awaitable(
+ {
+ "origin": self.OTHER_SERVER_NAME,
+ "origin_server_ts": 123,
+ "pdus": [
+ pulled_event.get_pdu_json(),
+ ],
+ }
+ )
+
+ # The function under test: try to backfill and process the pulled event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ self.OTHER_SERVER_NAME,
+ room_id,
+ limit=1,
+ extremities=["$some_extremity"],
+ )
+ )
+
+ # Ensure `run_as_background_process(...)` has a chance to run (essentially
+ # `wait_for_background_processes()`)
+ self.reactor.pump((0.1,))
+
+ # Make sure we processed and persisted the pulled event
+ self.get_success(main_store.get_event(pulled_event.event_id, allow_none=False))
+
def test_process_pulled_event_with_rejected_missing_state(self) -> None:
"""Ensure that we correctly handle pulled events with missing state containing a
rejected state event
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index aa91bc0a3d..394006f5f3 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -18,13 +18,17 @@ from http import HTTPStatus
from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.handlers.account import AccountHandler
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
+from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
@@ -162,10 +166,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
- def setUp(self) -> None:
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
- super().setUp()
+
+ # The mock password provider doesn't register the users, so ensure they
+ # are registered first.
+ self.register_user("u", "not-the-tested-password")
+ self.register_user("user", "not-the-tested-password")
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self) -> None:
@@ -185,22 +195,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# login with mxid should work too
- channel = self._send_password_login("@u:bz", "p")
+ channel = self._send_password_login("@u:test", "p")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- self.assertEqual("@u:bz", channel.json_body["user_id"])
- mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
+ self.assertEqual("@u:test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
- # try a weird username / pass. Honestly it's unclear what we *expect* to happen
- # in these cases, but at least we can guard against the API changing
- # unexpectedly
- channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
- mock_password_provider.check_password.assert_called_once_with(
- "@ USER🙂NAME :test", " pASS😢word "
- )
-
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
self.password_only_auth_provider_ui_auth_test_body()
@@ -208,10 +208,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def password_only_auth_provider_ui_auth_test_body(self) -> None:
"""UI Auth should delegate correctly to the password provider"""
- # create the user, otherwise access doesn't work
- module_api = self.hs.get_module_api()
- self.get_success(module_api.register_user("u"))
-
# log in twice, to get two devices
mock_password_provider.check_password.return_value = make_awaitable(True)
tok1 = self.login("u", "p")
@@ -401,29 +397,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = make_awaitable(
- ("@user:bz", None)
+ ("@user:test", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- self.assertEqual("@user:bz", channel.json_body["user_id"])
+ self.assertEqual("@user:test", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
)
mock_password_provider.reset_mock()
- # try a weird username. Again, it's unclear what we *expect* to happen
- # in these cases, but at least we can guard against the API changing
- # unexpectedly
- mock_password_provider.check_auth.return_value = make_awaitable(
- ("@ MALFORMED! :bz", None)
- )
- channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
- mock_password_provider.check_auth.assert_called_once_with(
- " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
- )
-
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_ui_auth_legacy(self) -> None:
self.custom_auth_provider_ui_auth_test_body()
@@ -465,7 +448,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = make_awaitable(
- ("@user:bz", None)
+ ("@user:test", None)
)
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
@@ -498,11 +481,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
callback = Mock(return_value=make_awaitable(None))
mock_password_provider.check_auth.return_value = make_awaitable(
- ("@user:bz", callback)
+ ("@user:test", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- self.assertEqual("@user:bz", channel.json_body["user_id"])
+ self.assertEqual("@user:test", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
)
@@ -512,7 +495,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
call_args, call_kwargs = callback.call_args
# should be one positional arg
self.assertEqual(len(call_args), 1)
- self.assertEqual(call_args[0]["user_id"], "@user:bz")
+ self.assertEqual(call_args[0]["user_id"], "@user:test")
for p in ["user_id", "access_token", "device_id", "home_server"]:
self.assertIn(p, call_args[0])
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
index 620ae3a4ba..b9ffdb4ced 100644
--- a/tests/handlers/test_sso.py
+++ b/tests/handlers/test_sso.py
@@ -31,7 +31,7 @@ class TestSSOHandler(unittest.HomeserverTestCase):
self.http_client.get_file.side_effect = mock_get_file
self.http_client.user_agent = b"Synapse Test"
hs = self.setup_test_homeserver(
- proxied_blacklisted_http_client=self.http_client
+ proxied_blocklisted_http_client=self.http_client
)
return hs
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index eb7f53fee5..105b4caefa 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -269,8 +269,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=cast(ISynapseReactor, self.reactor),
tls_client_options_factory=self.tls_factory,
user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided.
- ip_whitelist=IPSet(),
- ip_blacklist=IPSet(),
+ ip_allowlist=IPSet(),
+ ip_blocklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
@@ -997,8 +997,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor,
tls_client_options_factory=tls_factory,
user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
- ip_whitelist=IPSet(),
- ip_blacklist=IPSet(),
+ ip_allowlist=IPSet(),
+ ip_blocklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
cast(ISynapseReactor, self.reactor),
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 57b6a84e23..a05b9f17a6 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -27,8 +27,8 @@ from twisted.web.iweb import UNKNOWN_LENGTH
from synapse.api.errors import SynapseError
from synapse.http.client import (
- BlacklistingAgentWrapper,
- BlacklistingReactorWrapper,
+ BlocklistingAgentWrapper,
+ BlocklistingReactorWrapper,
BodyExceededMaxSize,
_DiscardBodyWithMaxSizeProtocol,
read_body_with_max_size,
@@ -140,7 +140,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self.assertEqual(result.getvalue(), b"")
-class BlacklistingAgentTest(TestCase):
+class BlocklistingAgentTest(TestCase):
def setUp(self) -> None:
self.reactor, self.clock = get_clock()
@@ -157,16 +157,16 @@ class BlacklistingAgentTest(TestCase):
self.reactor.lookups[domain.decode()] = ip.decode()
self.reactor.lookups[ip.decode()] = ip.decode()
- self.ip_whitelist = IPSet([self.allowed_ip.decode()])
- self.ip_blacklist = IPSet(["5.0.0.0/8"])
+ self.ip_allowlist = IPSet([self.allowed_ip.decode()])
+ self.ip_blocklist = IPSet(["5.0.0.0/8"])
def test_reactor(self) -> None:
- """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
+ """Apply the blocklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
- BlacklistingReactorWrapper(
+ BlocklistingReactorWrapper(
self.reactor,
- ip_whitelist=self.ip_whitelist,
- ip_blacklist=self.ip_blacklist,
+ ip_allowlist=self.ip_allowlist,
+ ip_blocklist=self.ip_blocklist,
),
)
@@ -207,11 +207,11 @@ class BlacklistingAgentTest(TestCase):
self.assertEqual(response.code, 200)
def test_agent(self) -> None:
- """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
- agent = BlacklistingAgentWrapper(
+ """Apply the blocklisting agent and ensure it properly blocks connections to particular IPs."""
+ agent = BlocklistingAgentWrapper(
Agent(self.reactor),
- ip_blacklist=self.ip_blacklist,
- ip_whitelist=self.ip_whitelist,
+ ip_blocklist=self.ip_blocklist,
+ ip_allowlist=self.ip_allowlist,
)
# The unsafe IPs should be rejected.
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index d89a91c59d..0dfc03ce50 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -231,11 +231,11 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
- def test_client_ip_range_blacklist(self) -> None:
- """Ensure that Synapse does not try to connect to blacklisted IPs"""
+ def test_client_ip_range_blocklist(self) -> None:
+ """Ensure that Synapse does not try to connect to blocked IPs"""
- # Set up the ip_range blacklist
- self.hs.config.server.federation_ip_range_blacklist = IPSet(
+ # Set up the ip_range blocklist
+ self.hs.config.server.federation_ip_range_blocklist = IPSet(
["127.0.0.0/8", "fe80::/64"]
)
self.reactor.lookups["internal"] = "127.0.0.1"
@@ -243,7 +243,7 @@ class FederationClientTests(HomeserverTestCase):
self.reactor.lookups["fine"] = "10.20.30.40"
cl = MatrixFederationHttpClient(self.hs, None)
- # Try making a GET request to a blacklisted IPv4 address
+ # Try making a GET request to a blocked IPv4 address
# ------------------------------------------------------
# Make the request
d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
@@ -261,7 +261,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
- # Try making a POST request to a blacklisted IPv6 address
+ # Try making a POST request to a blocked IPv6 address
# -------------------------------------------------------
# Make the request
d = defer.ensureDeferred(
@@ -278,11 +278,11 @@ class FederationClientTests(HomeserverTestCase):
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 0)
- # Check that it was due to a blacklisted DNS lookup
+ # Check that it was due to a blocked DNS lookup
f = self.failureResultOf(d, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
- # Try making a GET request to a non-blacklisted IPv4 address
+ # Try making a GET request to an allowed IPv4 address
# ----------------------------------------------------------
# Make the request
d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index cc175052ac..e0ae5a88ff 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -32,7 +32,7 @@ from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
-from synapse.http.client import BlacklistingReactorWrapper
+from synapse.http.client import BlocklistingReactorWrapper
from synapse.http.connectproxyclient import ProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy
@@ -684,11 +684,11 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
- def test_http_request_via_proxy_with_blacklist(self) -> None:
- # The blacklist includes the configured proxy IP.
+ def test_http_request_via_proxy_with_blocklist(self) -> None:
+ # The blocklist includes the configured proxy IP.
agent = ProxyAgent(
- BlacklistingReactorWrapper(
- self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+ BlocklistingReactorWrapper(
+ self.reactor, ip_allowlist=None, ip_blocklist=IPSet(["1.0.0.0/8"])
),
self.reactor,
use_proxy=True,
@@ -730,11 +730,11 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
- def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None:
- # The blacklist includes the configured proxy IP.
+ def test_https_request_via_uppercase_proxy_with_blocklist(self) -> None:
+ # The blocklist includes the configured proxy IP.
agent = ProxyAgent(
- BlacklistingReactorWrapper(
- self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+ BlocklistingReactorWrapper(
+ self.reactor, ip_allowlist=None, ip_blocklist=IPSet(["1.0.0.0/8"])
),
self.reactor,
contextFactory=get_test_https_policy(),
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
index 010601da4b..be731645bf 100644
--- a/tests/http/test_simple_client.py
+++ b/tests/http/test_simple_client.py
@@ -123,17 +123,17 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestTimedOutError)
- def test_client_ip_range_blacklist(self) -> None:
- """Ensure that Synapse does not try to connect to blacklisted IPs"""
+ def test_client_ip_range_blocklist(self) -> None:
+ """Ensure that Synapse does not try to connect to blocked IPs"""
- # Add some DNS entries we'll blacklist
+ # Add some DNS entries we'll block
self.reactor.lookups["internal"] = "127.0.0.1"
self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
- ip_blacklist = IPSet(["127.0.0.0/8", "fe80::/64"])
+ ip_blocklist = IPSet(["127.0.0.0/8", "fe80::/64"])
- cl = SimpleHttpClient(self.hs, ip_blacklist=ip_blacklist)
+ cl = SimpleHttpClient(self.hs, ip_blocklist=ip_blocklist)
- # Try making a GET request to a blacklisted IPv4 address
+ # Try making a GET request to a blocked IPv4 address
# ------------------------------------------------------
# Make the request
d = defer.ensureDeferred(cl.get_json("http://internal:8008/foo/bar"))
@@ -145,7 +145,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.failureResultOf(d, DNSLookupError)
- # Try making a POST request to a blacklisted IPv6 address
+ # Try making a POST request to a blocked IPv6 address
# -------------------------------------------------------
# Make the request
d = defer.ensureDeferred(
@@ -159,10 +159,10 @@ class SimpleHttpClientTests(HomeserverTestCase):
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 0)
- # Check that it was due to a blacklisted DNS lookup
+ # Check that it was due to a blocked DNS lookup
self.failureResultOf(d, DNSLookupError)
- # Try making a GET request to a non-blacklisted IPv4 address
+ # Try making a GET request to a non-blocked IPv4 address
# ----------------------------------------------------------
# Make the request
d = defer.ensureDeferred(cl.get_json("http://testserv:8008/foo/bar"))
diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py
new file mode 100644
index 0000000000..3c4c7d6765
--- /dev/null
+++ b/tests/media/test_url_previewer.py
@@ -0,0 +1,113 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+
+try:
+ import lxml
+except ImportError:
+ lxml = None
+
+
+class URLPreviewTests(unittest.HomeserverTestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+ config["url_preview_enabled"] = True
+ config["max_spider_size"] = 9999999
+ config["url_preview_ip_range_blacklist"] = (
+ "192.168.1.1",
+ "1.0.0.0/8",
+ "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
+ "2001:800::/21",
+ )
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
+
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ config["media_storage_providers"] = [provider_config]
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ media_repo_resource = hs.get_media_repository_resource()
+ preview_url = media_repo_resource.children[b"preview_url"]
+ self.url_previewer = preview_url._url_previewer
+
+ def test_all_urls_allowed(self) -> None:
+ self.assertFalse(self.url_previewer._is_url_blocked("http://matrix.org"))
+ self.assertFalse(self.url_previewer._is_url_blocked("https://matrix.org"))
+ self.assertFalse(self.url_previewer._is_url_blocked("http://localhost:8000"))
+ self.assertFalse(
+ self.url_previewer._is_url_blocked("http://user:pass@matrix.org")
+ )
+
+ @override_config(
+ {
+ "url_preview_url_blacklist": [
+ {"username": "user"},
+ {"scheme": "http", "netloc": "matrix.org"},
+ ]
+ }
+ )
+ def test_blocked_url(self) -> None:
+ # Blocked via scheme and URL.
+ self.assertTrue(self.url_previewer._is_url_blocked("http://matrix.org"))
+ # Not blocked because all components must match.
+ self.assertFalse(self.url_previewer._is_url_blocked("https://matrix.org"))
+
+ # Blocked due to the user.
+ self.assertTrue(
+ self.url_previewer._is_url_blocked("http://user:pass@example.com")
+ )
+ self.assertTrue(self.url_previewer._is_url_blocked("http://user@example.com"))
+
+ @override_config({"url_preview_url_blacklist": [{"netloc": "*.example.com"}]})
+ def test_glob_blocked_url(self) -> None:
+ # All subdomains are blocked.
+ self.assertTrue(self.url_previewer._is_url_blocked("http://foo.example.com"))
+ self.assertTrue(self.url_previewer._is_url_blocked("http://.example.com"))
+
+ # The bare (apex) domain is not blocked.
+ self.assertFalse(self.url_previewer._is_url_blocked("https://example.com"))
+
+ @override_config({"url_preview_url_blacklist": [{"netloc": "^.+\\.example\\.com"}]})
+ def test_regex_blocked_url(self) -> None:
+ # All subdomains are blocked.
+ self.assertTrue(self.url_previewer._is_url_blocked("http://foo.example.com"))
+ # Requires a non-empty subdomain.
+ self.assertFalse(self.url_previewer._is_url_blocked("http://.example.com"))
+
+ # The bare (apex) domain is not blocked.
+ self.assertFalse(self.url_previewer._is_url_blocked("https://example.com"))
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 54f558742d..e68a979ee0 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -52,7 +52,7 @@ class HTTPPusherTests(HomeserverTestCase):
m.post_json_get_json = post_json_get_json
- hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
+ hs = self.setup_test_homeserver(proxied_blocklisted_http_client=m)
return hs
diff --git a/tests/replication/slave/storage/__init__.py b/tests/replication/slave/storage/__init__.py
deleted file mode 100644
index f43a360a80..0000000000
--- a/tests/replication/slave/storage/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/tests/replication/slave/__init__.py b/tests/replication/storage/__init__.py
index f43a360a80..f43a360a80 100644
--- a/tests/replication/slave/__init__.py
+++ b/tests/replication/storage/__init__.py
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/storage/_base.py
index 4c9b494344..de26a62ae1 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/storage/_base.py
@@ -24,7 +24,7 @@ from synapse.util import Clock
from tests.replication._base import BaseStreamTestCase
-class BaseSlavedStoreTestCase(BaseStreamTestCase):
+class BaseWorkerStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())
@@ -34,7 +34,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self.reconnect()
self.master_store = hs.get_datastores().main
- self.slaved_store = self.worker_hs.get_datastores().main
+ self.worker_store = self.worker_hs.get_datastores().main
persistence = hs.get_storage_controllers().persistence
assert persistence is not None
self.persistance = persistence
@@ -50,7 +50,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
) -> None:
master_result = self.get_success(getattr(self.master_store, method)(*args))
- slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
+ worker_result = self.get_success(getattr(self.worker_store, method)(*args))
if expected_result is not None:
self.assertEqual(
master_result,
@@ -59,14 +59,14 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
% (expected_result, master_result),
)
self.assertEqual(
- slaved_result,
+ worker_result,
expected_result,
- "Expected slave result to be %r but was %r"
- % (expected_result, slaved_result),
+ "Expected worker result to be %r but was %r"
+ % (expected_result, worker_result),
)
self.assertEqual(
master_result,
- slaved_result,
- "Slave result %r does not match master result %r"
- % (slaved_result, master_result),
+ worker_result,
+ "Worker result %r does not match master result %r"
+ % (worker_result, master_result),
)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/storage/test_events.py
index b2125b1fea..f7c6417a09 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/storage/test_events.py
@@ -36,7 +36,7 @@ from synapse.util import Clock
from tests.server import FakeTransport
-from ._base import BaseSlavedStoreTestCase
+from ._base import BaseWorkerStoreTestCase
USER_ID = "@feeling:test"
USER_ID_2 = "@bright:test"
@@ -63,7 +63,7 @@ def patch__eq__(cls: object) -> Callable[[], None]:
return unpatch
-class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
+class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
STORE_TYPE = EventsWorkerStore
def setUp(self) -> None:
@@ -294,7 +294,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
assert j2.internal_metadata.stream_ordering is not None
event_source = RoomEventSource(self.hs)
- event_source.store = self.slaved_store
+ event_source.store = self.worker_store
current_token = event_source.get_current_key()
# gradually stream out the replication
@@ -310,12 +310,12 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
#
# First, we get a list of the rooms we are joined to
joined_rooms = self.get_success(
- self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
+ self.worker_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
)
# Then, we get a list of the events since the last sync
membership_changes = self.get_success(
- self.slaved_store.get_membership_changes_for_user(
+ self.worker_store.get_membership_changes_for_user(
USER_ID_2, prev_token, current_token
)
)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index dcb3e6669b..875811669c 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -93,7 +93,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "pusher1", "pusher_instances": ["pusher1"]},
- proxied_blacklisted_http_client=http_client_mock,
+ proxied_blocklisted_http_client=http_client_mock,
)
event_id = self._create_pusher_and_send_msg("user")
@@ -126,7 +126,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_blacklisted_http_client=http_client_mock1,
+ proxied_blocklisted_http_client=http_client_mock1,
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -140,7 +140,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_blacklisted_http_client=http_client_mock2,
+ proxied_blocklisted_http_client=http_client_mock2,
)
# We choose a user name that we know should go to pusher1.
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 62acf4f44e..dc32982e22 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -42,7 +42,7 @@ from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
- from authlib.jose import jwk, jwt
+ from authlib.jose import JsonWebKey, jwt
HAS_JWT = True
except ImportError:
@@ -1054,6 +1054,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
+ def test_deactivated_user(self) -> None:
+ """Logging in as a deactivated account should error."""
+ user_id = self.register_user("kermit", "monkey")
+ self.get_success(
+ self.hs.get_deactivate_account_handler().deactivate_account(
+ user_id, erase_data=False, requester=create_requester(user_id)
+ )
+ )
+
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.code, 403, msg=channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_USER_DEACTIVATED")
+ self.assertEqual(
+ channel.json_body["error"], "This account has been deactivated"
+ )
+
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
@@ -1121,7 +1137,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
header = {"alg": "RS256"}
if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
- secret = jwk.dumps(secret, kty="RSA")
+ secret = JsonWebKey.import_key(secret, {"kty": "RSA"})
result: bytes = jwt.encode(header, payload, secret)
return result.decode("ascii")
diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py
index a4327f7ace..22fddbd6d6 100644
--- a/tests/rest/client/test_mutual_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from urllib.parse import quote
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -44,8 +46,8 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
"GET",
- "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
- % other_user,
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms"
+ f"?user_id={quote(other_user)}",
access_token=token,
)
diff --git a/tests/rest/client/test_read_marker.py b/tests/rest/client/test_read_marker.py
new file mode 100644
index 0000000000..0eedcdb476
--- /dev/null
+++ b/tests/rest/client/test_read_marker.py
@@ -0,0 +1,147 @@
+# Copyright 2023 Beeper
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.rest import admin
+from synapse.rest.client import login, read_marker, register, room
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+
+ONE_HOUR_MS = 3600000
+ONE_DAY_MS = ONE_HOUR_MS * 24
+
+
+class ReadMarkerTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ register.register_servlets,
+ read_marker.register_servlets,
+ room.register_servlets,
+ synapse.rest.admin.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ # merge this default retention config with anything that was specified in
+ # @override_config
+ retention_config = {
+ "enabled": True,
+ "allowed_lifetime_min": ONE_DAY_MS,
+ "allowed_lifetime_max": ONE_DAY_MS * 3,
+ }
+ retention_config.update(config.get("retention", {}))
+ config["retention"] = retention_config
+
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.owner = self.register_user("owner", "pass")
+ self.owner_tok = self.login("owner", "pass")
+ self.store = self.hs.get_datastores().main
+ self.clock = self.hs.get_clock()
+
+ def test_send_read_marker(self) -> None:
+ room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
+
+ def send_message() -> str:
+ res = self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
+ return res["event_id"]
+
+ # Test setting the read marker on the room
+ event_id_1 = send_message()
+
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/read_markers",
+ content={
+ "m.fully_read": event_id_1,
+ },
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Test moving the read marker to a newer event
+ event_id_2 = send_message()
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/read_markers",
+ content={
+ "m.fully_read": event_id_2,
+ },
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def test_send_read_marker_missing_previous_event(self) -> None:
+ """
+ Test moving a read marker from an event that previously existed but was
+ later removed due to retention rules.
+ """
+
+ room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
+
+ # Set retention rule on the room so we remove old events to test this case
+ self.helper.send_state(
+ room_id=room_id,
+ event_type=EventTypes.Retention,
+ body={"max_lifetime": ONE_DAY_MS},
+ tok=self.owner_tok,
+ )
+
+ def send_message() -> str:
+ res = self.helper.send(room_id=room_id, body="1", tok=self.owner_tok)
+ return res["event_id"]
+
+ # Test setting the read marker on the room
+ event_id_1 = send_message()
+
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/read_markers",
+ content={
+ "m.fully_read": event_id_1,
+ },
+ access_token=self.owner_tok,
+ )
+
+ # Send a second message (retention never removes the most recent event in a room)
+ send_message()
+ # And then advance so retention rules remove the first event (where the marker is)
+ self.reactor.advance(ONE_DAY_MS * 2 / 1000)
+
+ event = self.get_success(self.store.get_event(event_id_1, allow_none=True))
+ assert event is None
+
+ # TODO See https://github.com/matrix-org/synapse/issues/13476
+ self.store.get_event_ordering.invalidate_all()
+
+ # Test moving the read marker to a newer event
+ event_id_2 = send_message()
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/read_markers",
+ content={
+ "m.fully_read": event_id_2,
+ },
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index e44beae8c1..170fb0534a 100644
--- a/tests/rest/media/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -418,9 +418,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
- def test_blacklisted_ip_specific(self) -> None:
+ def test_blocked_ip_specific(self) -> None:
"""
- Blacklisted IP addresses, found via DNS, are not spidered.
+ Blocked IP addresses, found via DNS, are not spidered.
"""
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
@@ -439,9 +439,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ip_range(self) -> None:
+ def test_blocked_ip_range(self) -> None:
"""
- Blacklisted IP ranges, IPs found over DNS, are not spidered.
+ Blocked IP ranges, IPs found over DNS, are not spidered.
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
@@ -458,9 +458,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ip_specific_direct(self) -> None:
+ def test_blocked_ip_specific_direct(self) -> None:
"""
- Blacklisted IP addresses, accessed directly, are not spidered.
+ Blocked IP addresses, accessed directly, are not spidered.
"""
channel = self.make_request(
"GET", "preview_url?url=http://192.168.1.1", shorthand=False
@@ -470,16 +470,13 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpClients), 0)
self.assertEqual(
channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "IP address blocked by IP blacklist entry",
- },
+ {"errcode": "M_UNKNOWN", "error": "IP address blocked"},
)
self.assertEqual(channel.code, 403)
- def test_blacklisted_ip_range_direct(self) -> None:
+ def test_blocked_ip_range_direct(self) -> None:
"""
- Blacklisted IP ranges, accessed directly, are not spidered.
+ Blocked IP ranges, accessed directly, are not spidered.
"""
channel = self.make_request(
"GET", "preview_url?url=http://1.1.1.2", shorthand=False
@@ -488,15 +485,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 403)
self.assertEqual(
channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "IP address blocked by IP blacklist entry",
- },
+ {"errcode": "M_UNKNOWN", "error": "IP address blocked"},
)
- def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
+ def test_blocked_ip_range_whitelisted_ip(self) -> None:
"""
- Blacklisted but then subsequently whitelisted IP addresses can be
+ Blocked but then subsequently whitelisted IP addresses can be
spidered.
"""
self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
@@ -527,10 +521,10 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
- def test_blacklisted_ip_with_external_ip(self) -> None:
+ def test_blocked_ip_with_external_ip(self) -> None:
"""
- If a hostname resolves a blacklisted IP, even if there's a
- non-blacklisted one, it will be rejected.
+ If a hostname resolves a blocked IP, even if there's a non-blocked one,
+ it will be rejected.
"""
# Hardcode the URL resolving to the IP we want.
self.lookups["example.com"] = [
@@ -550,9 +544,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ipv6_specific(self) -> None:
+ def test_blocked_ipv6_specific(self) -> None:
"""
- Blacklisted IP addresses, found via DNS, are not spidered.
+ Blocked IP addresses, found via DNS, are not spidered.
"""
self.lookups["example.com"] = [
(IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
@@ -573,9 +567,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ipv6_range(self) -> None:
+ def test_blocked_ipv6_range(self) -> None:
"""
- Blacklisted IP ranges, IPs found over DNS, are not spidered.
+ Blocked IP ranges, IPs found over DNS, are not spidered.
"""
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
@@ -653,6 +647,57 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data,
)
+ def test_image(self) -> None:
+ """An image should be precached if mentioned in the HTML."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")]
+
+ result = (
+ b"""<html><body><img src="http://cdn.matrix.org/foo.png"></body></html>"""
+ )
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ # Respond with the HTML.
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+ self.pump()
+
+ # Respond with the photo.
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: image/png\r\n\r\n"
+ )
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
+ )
+ self.pump()
+
+ # The image should be in the result.
+ self.assertEqual(channel.code, 200)
+ self._assert_small_png(channel.json_body)
+
def test_nonexistent_image(self) -> None:
"""If the preview image doesn't exist, ensure some data is returned."""
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
@@ -683,9 +728,53 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self.pump()
+
+ # There should not be a second connection.
+ self.assertEqual(len(self.reactor.tcpClients), 1)
+
+ # The image should not be in the result.
self.assertEqual(channel.code, 200)
+ self.assertNotIn("og:image", channel.json_body)
+
+ @unittest.override_config(
+ {"url_preview_url_blacklist": [{"netloc": "cdn.matrix.org"}]}
+ )
+ def test_image_blocked(self) -> None:
+        """If the preview image's domain is blocked, it should not be downloaded."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")]
+
+ result = (
+ b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
+ )
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+ self.pump()
+
+ # There should not be a second connection.
+ self.assertEqual(len(self.reactor.tcpClients), 1)
# The image should not be in the result.
+ self.assertEqual(channel.code, 200)
self.assertNotIn("og:image", channel.json_body)
def test_oembed_failure(self) -> None:
@@ -880,6 +969,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self.pump()
+
+ # Double check that the proper host is being connected to. (Note that
+ # twitter.com can't be resolved so this is already implicitly checked.)
+ self.assertIn(b"\r\nHost: publish.twitter.com\r\n", server.data)
+
self.assertEqual(channel.code, 200)
body = channel.json_body
self.assertEqual(
@@ -940,6 +1034,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
+ @unittest.override_config(
+ {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]}
+ )
+ def test_oembed_blocked(self) -> None:
+ """The oEmbed URL should not be downloaded if the oEmbed URL is blocked."""
+ self.lookups["twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 403, channel.result)
+
def test_oembed_autodiscovery(self) -> None:
"""
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
@@ -980,7 +1090,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
% (len(result),)
+ result
)
-
self.pump()
# The oEmbed response.
@@ -1004,7 +1113,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
% (len(oembed_content),)
+ oembed_content
)
-
self.pump()
# Ensure the URL is what was requested.
@@ -1023,7 +1131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
% (len(SMALL_PNG),)
+ SMALL_PNG
)
-
self.pump()
# Ensure the URL is what was requested.
@@ -1036,6 +1143,59 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self._assert_small_png(body)
+ @unittest.override_config(
+ {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]}
+ )
+ def test_oembed_autodiscovery_blocked(self) -> None:
+ """
+ If the discovered oEmbed URL is blocked, it should be discarded.
+ """
+ # This is a little cheesy in that we use the www subdomain (which isn't the
+        # in the list of oEmbed patterns) to get a "raw" HTML response.
+ self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.4")]
+
+ result = b"""
+ <title>Test</title>
+ <link rel="alternate" type="application/json+oembed"
+ href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json"
+ title="matrixdotorg" />
+ """
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+
+ self.pump()
+
+ # Ensure there's no additional connections.
+ self.assertEqual(len(self.reactor.tcpClients), 1)
+
+ # Ensure the URL is what was requested.
+ self.assertIn(b"\r\nHost: www.twitter.com\r\n", server.data)
+
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ self.assertEqual(body["og:title"], "Test")
+ self.assertNotIn("og:image", body)
+
def _download_image(self) -> Tuple[str, str]:
"""Downloads an image into the URL cache.
Returns:
@@ -1192,8 +1352,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
@unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]})
- def test_blacklist_port(self) -> None:
- """Tests that blacklisting URLs with a port makes previewing such URLs
+ def test_blocked_port(self) -> None:
+ """Tests that blocking URLs with a port makes previewing such URLs
fail with a 403 error and doesn't impact other previews.
"""
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
@@ -1230,3 +1390,23 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(channel.code, 200)
+
+ @unittest.override_config(
+ {"url_preview_url_blacklist": [{"netloc": "example.com"}]}
+ )
+ def test_blocked_url(self) -> None:
+ """Tests that blocking URLs with a host makes previewing such URLs
+ fail with a 403 error.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
+
+ bad_url = quote("http://example.com/foo")
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=" + bad_url,
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 403, channel.result)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 81e50bdd55..4b8d8328d7 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -1134,6 +1134,43 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertEqual(backfill_event_ids, ["insertion_eventA"])
+ def test_get_event_ids_with_failed_pull_attempts(self) -> None:
+ """
+ Test to make sure we properly get event_ids based on whether they have any
+ failed pull attempts.
+ """
+ # Create the room
+ user_id = self.register_user("alice", "test")
+ tok = self.login("alice", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "$failed_event_id1", "fake cause"
+ )
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "$failed_event_id2", "fake cause"
+ )
+ )
+
+ event_ids_with_failed_pull_attempts = self.get_success(
+ self.store.get_event_ids_with_failed_pull_attempts(
+ event_ids=[
+ "$failed_event_id1",
+ "$fresh_event_id1",
+ "$failed_event_id2",
+ "$fresh_event_id2",
+ ]
+ )
+ )
+
+ self.assertEqual(
+ event_ids_with_failed_pull_attempts,
+ {"$failed_event_id1", "$failed_event_id2"},
+ )
+
def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
"""
Test to make sure only event IDs we should backoff from are returned.
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 6ec34997ea..f9cf0fcb82 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -14,6 +14,8 @@
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import PostgresEngine
from synapse.types import UserID
from synapse.util import Clock
@@ -69,3 +71,64 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.assertIsNone(
self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
)
+
+ def test_profiles_bg_migration(self) -> None:
+ """
+ Test background job that copies entries from column user_id to full_user_id, adding
+ the hostname in the process.
+ """
+ updater = self.hs.get_datastores().main.db_pool.updates
+
+ # drop the constraint so we can insert nulls in full_user_id to populate the test
+ if isinstance(self.store.database_engine, PostgresEngine):
+
+ def f(txn: LoggingTransaction) -> None:
+ txn.execute(
+ "ALTER TABLE profiles DROP CONSTRAINT full_user_id_not_null"
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("", f))
+
+ for i in range(0, 70):
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "profiles",
+ {"user_id": f"hello{i:02}"},
+ )
+ )
+
+ # re-add the constraint so that when it's validated it actually exists
+ if isinstance(self.store.database_engine, PostgresEngine):
+
+ def f(txn: LoggingTransaction) -> None:
+ txn.execute(
+ "ALTER TABLE profiles ADD CONSTRAINT full_user_id_not_null CHECK (full_user_id IS NOT NULL) NOT VALID"
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("", f))
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ values={
+ "update_name": "populate_full_user_id_profiles",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ self.get_success(
+ updater.run_background_updates(False),
+ )
+
+ expected_values = []
+ for i in range(0, 70):
+ expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
+
+ res = self.get_success(
+ self.store.db_pool.execute(
+ "", None, "SELECT full_user_id from profiles ORDER BY full_user_id"
+ )
+ )
+ self.assertEqual(len(res), len(expected_values))
+ self.assertEqual(res, expected_values)
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index db9ee9955e..2fab84a529 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -33,15 +33,14 @@ class TransactionStoreTestCase(HomeserverTestCase):
destination retries, as well as testing tht we can set and get
correctly.
"""
- d = self.store.get_destination_retry_timings("example.com")
- r = self.get_success(d)
+ r = self.get_success(self.store.get_destination_retry_timings("example.com"))
self.assertIsNone(r)
- d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
- self.get_success(d)
+ self.get_success(
+ self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
+ )
- d = self.store.get_destination_retry_timings("example.com")
- r = self.get_success(d)
+ r = self.get_success(self.store.get_destination_retry_timings("example.com"))
self.assertEqual(
DestinationRetryTimings(
diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py
new file mode 100644
index 0000000000..bab802f56e
--- /dev/null
+++ b/tests/storage/test_user_filters.py
@@ -0,0 +1,94 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.engines import PostgresEngine
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
+ """
+ Test background migration that copies entries from column user_id to full_user_id, adding
+ the hostname in the process.
+ """
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+
+ def test_bg_migration(self) -> None:
+ updater = self.hs.get_datastores().main.db_pool.updates
+
+ # drop the constraint so we can insert nulls in full_user_id to populate the test
+ if isinstance(self.store.database_engine, PostgresEngine):
+
+ def f(txn: LoggingTransaction) -> None:
+ txn.execute(
+ "ALTER TABLE user_filters DROP CONSTRAINT full_user_id_not_null"
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("", f))
+
+ for i in range(0, 70):
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "user_filters",
+ {
+ "user_id": f"hello{i:02}",
+ "filter_id": i,
+ "filter_json": bytearray(i),
+ },
+ )
+ )
+
+ # re-add the constraint so that when it's validated it actually exists
+ if isinstance(self.store.database_engine, PostgresEngine):
+
+ def f(txn: LoggingTransaction) -> None:
+ txn.execute(
+ "ALTER TABLE user_filters ADD CONSTRAINT full_user_id_not_null CHECK (full_user_id IS NOT NULL) NOT VALID"
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("", f))
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ values={
+ "update_name": "populate_full_user_id_user_filters",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ self.get_success(
+ updater.run_background_updates(False),
+ )
+
+ expected_values = []
+ for i in range(0, 70):
+ expected_values.append((f"@hello{i:02}:{self.hs.hostname}",))
+
+ res = self.get_success(
+ self.store.db_pool.execute(
+ "", None, "SELECT full_user_id from user_filters ORDER BY full_user_id"
+ )
+ )
+ self.assertEqual(len(res), len(expected_values))
+ self.assertEqual(res, expected_values)
diff --git a/tests/test_state.py b/tests/test_state.py
index 2029d3d60a..ddf59916b1 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -264,7 +264,7 @@ class StateTestCase(unittest.TestCase):
self.dummy_store.register_events(graph.walk())
- context_store: dict[str, EventContext] = {}
+ context_store: Dict[str, EventContext] = {}
for event in graph.walk():
context = yield defer.ensureDeferred(
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index b522163a34..c37f205ed0 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -40,10 +40,9 @@ def setup_logging() -> None:
"""
root_logger = logging.getLogger()
- log_format = (
- "%(asctime)s - %(name)s - %(lineno)d - "
- "%(levelname)s - %(request)s - %(message)s"
- )
+ # We exclude `%(asctime)s` from this format because the Twisted logger adds its own
+ # timestamp
+ log_format = "%(name)s - %(lineno)d - " "%(levelname)s - %(request)s - %(message)s"
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
diff --git a/tests/unittest.py b/tests/unittest.py
index b6fdf69635..c73195b32b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
import gc
import hashlib
import hmac
@@ -150,7 +151,11 @@ def deepcopy_config(config: _TConfig) -> _TConfig:
return new_config
-_make_homeserver_config_obj_cache: Dict[str, Union[RootConfig, Config]] = {}
+@functools.lru_cache(maxsize=8)
+def _parse_config_dict(config: str) -> RootConfig:
+ config_obj = HomeServerConfig()
+ config_obj.parse_config_dict(json.loads(config), "", "")
+ return config_obj
def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig:
@@ -164,21 +169,7 @@ def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig:
but it keeps a cache of `HomeServerConfig` instances and deepcopies them as needed,
to avoid validating the whole configuration every time.
"""
- cache_key = json.dumps(config)
-
- if cache_key in _make_homeserver_config_obj_cache:
- # Cache hit: reuse the existing instance
- config_obj = _make_homeserver_config_obj_cache[cache_key]
- else:
- # Cache miss; create the actual instance
- config_obj = HomeServerConfig()
- config_obj.parse_config_dict(config, "", "")
-
- # Add to the cache
- _make_homeserver_config_obj_cache[cache_key] = config_obj
-
- assert isinstance(config_obj, RootConfig)
-
+ config_obj = _parse_config_dict(json.dumps(config, sort_keys=True))
return deepcopy_config(config_obj)
@@ -229,13 +220,20 @@ class TestCase(unittest.TestCase):
#
# The easiest way to do this would be to do a full GC after each test
# run, but that is very expensive. Instead, we disable GC (above) for
- # the duration of the test so that we only need to run a gen-0 GC, which
- # is a lot quicker.
+ # the duration of the test and only run a gen-0 GC, which is a lot
+ # quicker. This doesn't clean up everything, since the TestCase
+ # instance still holds references to objects created during the test,
+ # such as HomeServers, so we do a full GC every so often.
@around(self)
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
gc.collect(0)
+ # Run a full GC every 50 gen-0 GCs.
+ gen0_stats = gc.get_stats()[0]
+ gen0_collections = gen0_stats["collections"]
+ if gen0_collections % 50 == 0:
+ gc.collect()
gc.enable()
set_current_context(SENTINEL_CONTEXT)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 13f1edd533..064f4987df 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,7 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Iterable, Set, Tuple, cast
+from typing import (
+ Any,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ NoReturn,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
from unittest import mock
from twisted.internet import defer, reactor
@@ -29,7 +40,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from tests import unittest
from tests.test_utils import get_awaitable_result
@@ -37,21 +48,21 @@ from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__)
-def run_on_reactor():
- d: "Deferred[int]" = defer.Deferred()
+def run_on_reactor() -> "Deferred[int]":
+ d: "Deferred[int]" = Deferred()
cast(IReactorTime, reactor).callLater(0, d.callback, 0)
return make_deferred_yieldable(d)
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
- def test_cache(self):
+ def test_cache(self) -> Generator["Deferred[Any]", object, None]:
class Cls:
- def __init__(self):
+ def __init__(self) -> None:
self.mock = mock.Mock()
@descriptors.cached()
- def fn(self, arg1, arg2):
+ def fn(self, arg1: int, arg2: int) -> str:
return self.mock(arg1, arg2)
obj = Cls()
@@ -77,15 +88,15 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
@defer.inlineCallbacks
- def test_cache_num_args(self):
+ def test_cache_num_args(self) -> Generator["Deferred[Any]", object, None]:
"""Only the first num_args arguments should matter to the cache"""
class Cls:
- def __init__(self):
+ def __init__(self) -> None:
self.mock = mock.Mock()
@descriptors.cached(num_args=1)
- def fn(self, arg1, arg2):
+ def fn(self, arg1: int, arg2: int) -> mock.Mock:
return self.mock(arg1, arg2)
obj = Cls()
@@ -111,7 +122,7 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
@defer.inlineCallbacks
- def test_cache_uncached_args(self):
+ def test_cache_uncached_args(self) -> Generator["Deferred[Any]", object, None]:
"""
Only the arguments not named in uncached_args should matter to the cache
@@ -123,10 +134,10 @@ class DescriptorTestCase(unittest.TestCase):
# Note that it is important that this is not the last argument to
# test behaviour of skipping arguments properly.
@descriptors.cached(uncached_args=("arg2",))
- def fn(self, arg1, arg2, arg3):
+ def fn(self, arg1: int, arg2: int, arg3: int) -> str:
return self.mock(arg1, arg2, arg3)
- def __init__(self):
+ def __init__(self) -> None:
self.mock = mock.Mock()
obj = Cls()
@@ -152,15 +163,15 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.assert_not_called()
@defer.inlineCallbacks
- def test_cache_kwargs(self):
+ def test_cache_kwargs(self) -> Generator["Deferred[Any]", object, None]:
"""Test that keyword arguments are treated properly"""
class Cls:
- def __init__(self):
+ def __init__(self) -> None:
self.mock = mock.Mock()
@descriptors.cached()
- def fn(self, arg1, kwarg1=2):
+ def fn(self, arg1: int, kwarg1: int = 2) -> str:
return self.mock(arg1, kwarg1=kwarg1)
obj = Cls()
@@ -188,12 +199,12 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "fish")
obj.mock.assert_not_called()
- def test_cache_with_sync_exception(self):
+ def test_cache_with_sync_exception(self) -> None:
"""If the wrapped function throws synchronously, things should continue to work"""
class Cls:
@cached()
- def fn(self, arg1):
+ def fn(self, arg1: int) -> NoReturn:
raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls()
@@ -209,15 +220,15 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
- def test_cache_with_async_exception(self):
+ def test_cache_with_async_exception(self) -> None:
"""The wrapped function returns a failure"""
class Cls:
- result = None
+ result: Optional[Deferred] = None
call_count = 0
@cached()
- def fn(self, arg1):
+ def fn(self, arg1: int) -> Optional[Deferred]:
self.call_count += 1
return self.result
@@ -225,7 +236,7 @@ class DescriptorTestCase(unittest.TestCase):
callbacks: Set[str] = set()
# set off an asynchronous request
- origin_d: Deferred = defer.Deferred()
+ origin_d: Deferred = Deferred()
obj.result = origin_d
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
@@ -260,17 +271,17 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(self.successResultOf(d3), 100)
self.assertEqual(obj.call_count, 2)
- def test_cache_logcontexts(self):
+ def test_cache_logcontexts(self) -> Deferred:
"""Check that logcontexts are set and restored correctly when
using the cache."""
- complete_lookup: Deferred = defer.Deferred()
+ complete_lookup: Deferred = Deferred()
class Cls:
@descriptors.cached()
- def fn(self, arg1):
+ def fn(self, arg1: int) -> "Deferred[int]":
@defer.inlineCallbacks
- def inner_fn():
+ def inner_fn() -> Generator["Deferred[object]", object, int]:
with PreserveLoggingContext():
yield complete_lookup
return 1
@@ -278,13 +289,13 @@ class DescriptorTestCase(unittest.TestCase):
return inner_fn()
@defer.inlineCallbacks
- def do_lookup():
+ def do_lookup() -> Generator["Deferred[Any]", object, int]:
with LoggingContext("c1") as c1:
r = yield obj.fn(1)
self.assertEqual(current_context(), c1)
- return r
+ return cast(int, r)
- def check_result(r):
+ def check_result(r: int) -> None:
self.assertEqual(r, 1)
obj = Cls()
@@ -304,15 +315,15 @@ class DescriptorTestCase(unittest.TestCase):
return defer.gatherResults([d1, d2])
- def test_cache_logcontexts_with_exception(self):
+ def test_cache_logcontexts_with_exception(self) -> "Deferred[None]":
"""Check that the cache sets and restores logcontexts correctly when
the lookup function throws an exception"""
class Cls:
@descriptors.cached()
- def fn(self, arg1):
+ def fn(self, arg1: int) -> Deferred:
@defer.inlineCallbacks
- def inner_fn():
+ def inner_fn() -> Generator["Deferred[Any]", object, NoReturn]:
# we want this to behave like an asynchronous function
yield run_on_reactor()
raise SynapseError(400, "blah")
@@ -320,7 +331,7 @@ class DescriptorTestCase(unittest.TestCase):
return inner_fn()
@defer.inlineCallbacks
- def do_lookup():
+ def do_lookup() -> Generator["Deferred[object]", object, None]:
with LoggingContext("c1") as c1:
try:
d = obj.fn(1)
@@ -347,13 +358,13 @@ class DescriptorTestCase(unittest.TestCase):
return d1
@defer.inlineCallbacks
- def test_cache_default_args(self):
+ def test_cache_default_args(self) -> Generator["Deferred[Any]", object, None]:
class Cls:
- def __init__(self):
+ def __init__(self) -> None:
self.mock = mock.Mock()
@descriptors.cached()
- def fn(self, arg1, arg2=2, arg3=3):
+ def fn(self, arg1: int, arg2: int = 2, arg3: int = 3) -> str:
return self.mock(arg1, arg2, arg3)
obj = Cls()
@@ -384,13 +395,13 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
- def test_cache_iterable(self):
+ def test_cache_iterable(self) -> None:
class Cls:
- def __init__(self):
+ def __init__(self) -> None:
self.mock = mock.Mock()
@descriptors.cached(iterable=True)
- def fn(self, arg1, arg2):
+ def fn(self, arg1: int, arg2: int) -> List[str]:
return self.mock(arg1, arg2)
obj = Cls()
@@ -417,12 +428,12 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r.result, ["chips"])
obj.mock.assert_not_called()
- def test_cache_iterable_with_sync_exception(self):
+ def test_cache_iterable_with_sync_exception(self) -> None:
"""If the wrapped function throws synchronously, things should continue to work"""
class Cls:
@descriptors.cached(iterable=True)
- def fn(self, arg1):
+ def fn(self, arg1: int) -> NoReturn:
raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls()
@@ -438,20 +449,20 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
- def test_invalidate_cascade(self):
+ def test_invalidate_cascade(self) -> None:
"""Invalidations should cascade up through cache contexts"""
class Cls:
@cached(cache_context=True)
- async def func1(self, key, cache_context):
+ async def func1(self, key: str, cache_context: _CacheContext) -> int:
return await self.func2(key, on_invalidate=cache_context.invalidate)
@cached(cache_context=True)
- async def func2(self, key, cache_context):
+ async def func2(self, key: str, cache_context: _CacheContext) -> int:
return await self.func3(key, on_invalidate=cache_context.invalidate)
@cached(cache_context=True)
- async def func3(self, key, cache_context):
+ async def func3(self, key: str, cache_context: _CacheContext) -> int:
self.invalidate = cache_context.invalidate
return 42
@@ -463,13 +474,13 @@ class DescriptorTestCase(unittest.TestCase):
obj.invalidate()
top_invalidate.assert_called_once()
- def test_cancel(self):
+ def test_cancel(self) -> None:
"""Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
@cached()
- async def fn(self, arg1):
+ async def fn(self, arg1: int) -> str:
await complete_lookup
return str(arg1)
@@ -488,7 +499,7 @@ class DescriptorTestCase(unittest.TestCase):
self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, "123")
- def test_cancel_logcontexts(self):
+ def test_cancel_logcontexts(self) -> None:
"""Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext.
@@ -501,14 +512,14 @@ class DescriptorTestCase(unittest.TestCase):
inner_context_was_finished = False
@cached()
- async def fn(self, arg1):
+ async def fn(self, arg1: int) -> str:
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return str(arg1)
obj = Cls()
- async def do_lookup():
+ async def do_lookup() -> None:
with LoggingContext("c1") as c1:
try:
await obj.fn(123)
@@ -542,10 +553,10 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""
@defer.inlineCallbacks
- def test_passthrough(self):
+ def test_passthrough(self) -> Generator["Deferred[Any]", object, None]:
class A:
@cached()
- def func(self, key):
+ def func(self, key: str) -> str:
return key
a = A()
@@ -554,12 +565,12 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual((yield a.func("bar")), "bar")
@defer.inlineCallbacks
- def test_hit(self):
+ def test_hit(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0]
class A:
@cached()
- def func(self, key):
+ def func(self, key: str) -> str:
callcount[0] += 1
return key
@@ -572,12 +583,12 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount[0], 1)
@defer.inlineCallbacks
- def test_invalidate(self):
+ def test_invalidate(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0]
class A:
@cached()
- def func(self, key):
+ def func(self, key: str) -> str:
callcount[0] += 1
return key
@@ -592,21 +603,21 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount[0], 2)
- def test_invalidate_missing(self):
+ def test_invalidate_missing(self) -> None:
class A:
@cached()
- def func(self, key):
+ def func(self, key: str) -> str:
return key
A().func.invalidate(("what",))
@defer.inlineCallbacks
- def test_max_entries(self):
+ def test_max_entries(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0]
class A:
@cached(max_entries=10)
- def func(self, key):
+ def func(self, key: int) -> int:
callcount[0] += 1
return key
@@ -626,14 +637,14 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
)
- def test_prefill(self):
+ def test_prefill(self) -> None:
callcount = [0]
d = defer.succeed(123)
class A:
@cached()
- def func(self, key):
+ def func(self, key: str) -> "Deferred[int]":
callcount[0] += 1
return d
@@ -645,18 +656,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount[0], 0)
@defer.inlineCallbacks
- def test_invalidate_context(self):
+ def test_invalidate_context(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0]
callcount2 = [0]
class A:
@cached()
- def func(self, key):
+ def func(self, key: str) -> str:
callcount[0] += 1
return key
@cached(cache_context=True)
- def func2(self, key, cache_context):
+ def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]":
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
@@ -678,18 +689,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount2[0], 2)
@defer.inlineCallbacks
- def test_eviction_context(self):
+ def test_eviction_context(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0]
callcount2 = [0]
class A:
@cached(max_entries=2)
- def func(self, key):
+ def func(self, key: str) -> str:
callcount[0] += 1
return key
@cached(cache_context=True)
- def func2(self, key, cache_context):
+ def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]":
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
@@ -715,18 +726,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
self.assertEqual(callcount2[0], 3)
@defer.inlineCallbacks
- def test_double_get(self):
+ def test_double_get(self) -> Generator["Deferred[Any]", object, None]:
callcount = [0]
callcount2 = [0]
class A:
@cached()
- def func(self, key):
+ def func(self, key: str) -> str:
callcount[0] += 1
return key
@cached(cache_context=True)
- def func2(self, key, cache_context):
+ def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]":
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
@@ -763,17 +774,17 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
- def test_cache(self):
+ def test_cache(self) -> Generator["Deferred[Any]", object, None]:
class Cls:
- def __init__(self):
+ def __init__(self) -> None:
self.mock = mock.Mock()
@descriptors.cached()
- def fn(self, arg1, arg2):
+ def fn(self, arg1: int, arg2: int) -> None:
pass
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
- async def list_fn(self, args1, arg2):
+ async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]:
context = current_context()
assert isinstance(context, LoggingContext)
assert context.name == "c1"
@@ -824,19 +835,19 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.mock.assert_called_once_with({40}, 2)
self.assertEqual(r, {10: "fish", 40: "gravy"})
- def test_concurrent_lookups(self):
+ def test_concurrent_lookups(self) -> None:
"""All concurrent lookups should get the same result"""
class Cls:
- def __init__(self):
+ def __init__(self) -> None:
self.mock = mock.Mock()
@descriptors.cached()
- def fn(self, arg1):
+ def fn(self, arg1: int) -> None:
pass
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
- def list_fn(self, args1) -> "Deferred[dict]":
+ def list_fn(self, args1: List[int]) -> "Deferred[dict]":
return self.mock(args1)
obj = Cls()
@@ -867,19 +878,19 @@ class CachedListDescriptorTestCase(unittest.TestCase):
self.assertEqual(self.successResultOf(d3), {10: "peas"})
@defer.inlineCallbacks
- def test_invalidate(self):
+ def test_invalidate(self) -> Generator["Deferred[Any]", object, None]:
"""Make sure that invalidation callbacks are called."""
class Cls:
- def __init__(self):
+ def __init__(self) -> None:
self.mock = mock.Mock()
@descriptors.cached()
- def fn(self, arg1, arg2):
+ def fn(self, arg1: int, arg2: int) -> None:
pass
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
- async def list_fn(self, args1, arg2):
+ async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]:
# we want this to behave like an asynchronous function
await run_on_reactor()
return self.mock(args1, arg2)
@@ -908,17 +919,17 @@ class CachedListDescriptorTestCase(unittest.TestCase):
invalidate0.assert_called_once()
invalidate1.assert_called_once()
- def test_cancel(self):
+ def test_cancel(self) -> None:
"""Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
@cached()
- def fn(self, arg1):
+ def fn(self, arg1: int) -> None:
pass
@cachedList(cached_method_name="fn", list_name="args")
- async def list_fn(self, args):
+ async def list_fn(self, args: List[int]) -> Dict[int, str]:
await complete_lookup
return {arg: str(arg) for arg in args}
@@ -936,7 +947,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
- def test_cancel_logcontexts(self):
+ def test_cancel_logcontexts(self) -> None:
"""Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext.
@@ -949,18 +960,18 @@ class CachedListDescriptorTestCase(unittest.TestCase):
inner_context_was_finished = False
@cached()
- def fn(self, arg1):
+ def fn(self, arg1: int) -> None:
pass
@cachedList(cached_method_name="fn", list_name="args")
- async def list_fn(self, args):
+ async def list_fn(self, args: List[int]) -> Dict[int, str]:
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return {arg: str(arg) for arg in args}
obj = Cls()
- async def do_lookup():
+ async def do_lookup() -> None:
with LoggingContext("c1") as c1:
try:
await obj.list_fn([123])
@@ -983,7 +994,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
- def test_num_args_mismatch(self):
+ def test_num_args_mismatch(self) -> None:
"""
Make sure someone does not accidentally use @cachedList on a method with
a mismatch in the number args to the underlying single cache method.
@@ -991,14 +1002,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
class Cls:
@descriptors.cached(tree=True)
- def fn(self, room_id, event_id):
+ def fn(self, room_id: str, event_id: str) -> None:
pass
# This is wrong ❌. `@cachedList` expects to be given the same number
# of arguments as the underlying cached function, just with one of
# the arguments being an iterable
@descriptors.cachedList(cached_method_name="fn", list_name="keys")
- def list_fn(self, keys: Iterable[Tuple[str, str]]):
+ def list_fn(self, keys: Iterable[Tuple[str, str]]) -> None:
pass
# Corrected syntax ✅
|