diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
new file mode 100644
index 0000000000..5527e278db
--- /dev/null
+++ b/tests/app/test_phone_stats_home.py
@@ -0,0 +1,395 @@
+import synapse
+from synapse.app.phone_stats_home import start_phone_stats_home
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+from tests.unittest import HomeserverTestCase
+
+FIVE_MINUTES_IN_SECONDS = 300
+ONE_DAY_IN_SECONDS = 86400
+
+
+class PhoneHomeTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ # Override the retention time for the user_ips table because otherwise it
+ # gets pruned too aggressively for our R30 test.
+ @unittest.override_config({"user_ips_max_age": "365d"})
+ def test_r30_minimum_usage(self):
+ """
+ Tests the minimum amount of interaction necessary for the R30 metric
+ to consider a user 'retained'.
+ """
+
+ # Register a user, log it in, create a room and send a message
+ user_id = self.register_user("u1", "secret!")
+ access_token = self.login("u1", "secret!")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
+ self.helper.send(room_id, "message", tok=access_token)
+
+ # Check the R30 results do not count that user.
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 0})
+
+ # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
+ # bang on 30 days later).
+ self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
+
+ # (Make sure the user isn't somehow counted by this point.)
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 0})
+
+ # Send a message (this counts as activity)
+ self.helper.send(room_id, "message2", tok=access_token)
+
+ # We have to wait some time for _update_client_ips_batch to get
+ # called and update the user_ips table.
+ self.reactor.advance(2 * 60 * 60)
+
+ # *Now* the user is counted.
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+ # Advance 29 days. The user has now not posted for 29 days.
+ self.reactor.advance(29 * ONE_DAY_IN_SECONDS)
+
+ # The user is still counted.
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+ # Advance another day. The user has now not posted for 30 days.
+ self.reactor.advance(ONE_DAY_IN_SECONDS)
+
+ # The user is now no longer counted in R30.
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 0})
+
+ def test_r30_minimum_usage_using_default_config(self):
+ """
+ Tests the minimum amount of interaction necessary for the R30 metric
+ to consider a user 'retained'.
+
+ N.B. This test does not override the `user_ips_max_age` config setting,
+ which defaults to 28 days.
+ """
+
+ # Register a user, log it in, create a room and send a message
+ user_id = self.register_user("u1", "secret!")
+ access_token = self.login("u1", "secret!")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
+ self.helper.send(room_id, "message", tok=access_token)
+
+ # Check the R30 results do not count that user.
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 0})
+
+ # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
+ # bang on 30 days later).
+ self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
+
+ # (Make sure the user isn't somehow counted by this point.)
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 0})
+
+ # Send a message (this counts as activity)
+ self.helper.send(room_id, "message2", tok=access_token)
+
+ # We have to wait some time for _update_client_ips_batch to get
+ # called and update the user_ips table.
+ self.reactor.advance(2 * 60 * 60)
+
+ # *Now* the user is counted.
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+ # Advance 27 days. The user has now not posted for 27 days.
+ self.reactor.advance(27 * ONE_DAY_IN_SECONDS)
+
+ # The user is still counted.
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+ # Advance another day. The user has now not posted for 28 days.
+ self.reactor.advance(ONE_DAY_IN_SECONDS)
+
+ # The user is now no longer counted in R30.
+ # (This is because the user_ips table has been pruned, which by default
+ # only preserves the last 28 days of entries.)
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 0})
+
+ def test_r30_user_must_be_retained_for_at_least_a_month(self):
+ """
+ Tests that a newly-registered user must be retained for a whole month
+ before appearing in the R30 statistic, even if they post every day
+ during that time!
+ """
+ # Register a user and send a message
+ user_id = self.register_user("u1", "secret!")
+ access_token = self.login("u1", "secret!")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
+ self.helper.send(room_id, "message", tok=access_token)
+
+ # Check the user does not contribute to R30 yet.
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 0})
+
+ for _ in range(30):
+ # This loop posts a message every day for 30 days
+ self.reactor.advance(ONE_DAY_IN_SECONDS)
+ self.helper.send(room_id, "I'm still here", tok=access_token)
+
+ # Notice that the user *still* does not contribute to R30!
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 0})
+
+ self.reactor.advance(ONE_DAY_IN_SECONDS)
+ self.helper.send(room_id, "Still here!", tok=access_token)
+
+ # *Now* the user appears in R30.
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+ self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+
+class PhoneHomeR30V2TestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def _advance_to(self, desired_time_secs: float):
+ now = self.hs.get_clock().time()
+ assert now < desired_time_secs
+ self.reactor.advance(desired_time_secs - now)
+
+ def make_homeserver(self, reactor, clock):
+ hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
+
+ # We don't want our tests to actually report statistics, so check
+ # that it's not enabled
+ assert not hs.config.report_stats
+
+ # This starts the needed data collection that we rely on to calculate
+ # R30v2 metrics.
+ start_phone_stats_home(hs)
+ return hs
+
+ def test_r30v2_minimum_usage(self):
+ """
+ Tests the minimum amount of interaction necessary for the R30v2 metric
+ to consider a user 'retained'.
+ """
+
+ # Register a user, log it in, create a room and send a message
+ user_id = self.register_user("u1", "secret!")
+ access_token = self.login("u1", "secret!")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
+ self.helper.send(room_id, "message", tok=access_token)
+ first_post_at = self.hs.get_clock().time()
+
+ # Give time for user_daily_visits table to be updated.
+ # (user_daily_visits is updated every 5 minutes using a looping call.)
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ store = self.hs.get_datastore()
+
+ # Check the R30 results do not count that user.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Advance 31 days.
+ # (R30v2 includes users with **more** than 30 days between the two visits,
+ # and user_daily_visits records the timestamp as the start of the day.)
+ self.reactor.advance(31 * ONE_DAY_IN_SECONDS)
+ # Also advance 5 minutes to let another user_daily_visits update occur
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # (Make sure the user isn't somehow counted by this point.)
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Send a message (this counts as activity)
+ self.helper.send(room_id, "message2", tok=access_token)
+
+ # We have to wait a few minutes for the user_daily_visits table to
+ # be updated by a background process.
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # *Now* the user is counted.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Advance to JUST under 60 days after the user's first post
+ self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS - 5)
+
+ # Check the user is still counted.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Advance into the next day. The user's first activity is now more than 60 days old.
+ self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS + 5)
+
+ # Check the user is now no longer counted in R30.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ def test_r30v2_user_must_be_retained_for_at_least_a_month(self):
+ """
+ Tests that a newly-registered user must be retained for a whole month
+ before appearing in the R30v2 statistic, even if they post every day
+ during that time!
+ """
+
+ # set a custom user-agent to impersonate Element/Android.
+ headers = (
+ (
+ "User-Agent",
+ "Element/1.1 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
+ ),
+ )
+
+ # Register a user and send a message
+ user_id = self.register_user("u1", "secret!")
+ access_token = self.login("u1", "secret!", custom_headers=headers)
+ room_id = self.helper.create_room_as(
+ room_creator=user_id, tok=access_token, custom_headers=headers
+ )
+ self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+ # Give time for user_daily_visits table to be updated.
+ # (user_daily_visits is updated every 5 minutes using a looping call.)
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ store = self.hs.get_datastore()
+
+ # Check the user does not contribute to R30 yet.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ for _ in range(30):
+ # This loop posts a message every day for 30 days
+ self.reactor.advance(ONE_DAY_IN_SECONDS - FIVE_MINUTES_IN_SECONDS)
+ self.helper.send(
+ room_id, "I'm still here", tok=access_token, custom_headers=headers
+ )
+
+ # give time for user_daily_visits to update
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # Notice that the user *still* does not contribute to R30!
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # advance yet another day with more activity
+ self.reactor.advance(ONE_DAY_IN_SECONDS)
+ self.helper.send(
+ room_id, "Still here!", tok=access_token, custom_headers=headers
+ )
+
+ # give time for user_daily_visits to update
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # *Now* the user appears in R30.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ def test_r30v2_returning_dormant_users_not_counted(self):
+ """
+ Tests that dormant users (users inactive for a long time) do not
+ contribute to R30v2 when they return for just a single day.
+ This is a key difference between R30 and R30v2.
+ """
+
+ # set a custom user-agent to impersonate Element/iOS.
+ headers = (
+ (
+ "User-Agent",
+ "Riot/1.4 (iPhone; iOS 13; Scale/4.00)",
+ ),
+ )
+
+ # Register a user and send a message
+ user_id = self.register_user("u1", "secret!")
+ access_token = self.login("u1", "secret!", custom_headers=headers)
+ room_id = self.helper.create_room_as(
+ room_creator=user_id, tok=access_token, custom_headers=headers
+ )
+ self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+ # the user goes inactive for 2 months
+ self.reactor.advance(60 * ONE_DAY_IN_SECONDS)
+
+ # the user returns for one day, perhaps just to check out a new feature
+ self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+ # Give time for user_daily_visits table to be updated.
+ # (user_daily_visits is updated every 5 minutes using a looping call.)
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ store = self.hs.get_datastore()
+
+ # Check that the user does not contribute to R30v2, even though it's been
+ # more than 30 days since registration.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Check that this is a situation where old R30 differs:
+ # old R30 DOES count this as 'retained'.
+ r30_results = self.get_success(store.count_r30_users())
+ self.assertEqual(r30_results, {"all": 1, "ios": 1})
+
+ # Now we want to check that the user will still be able to appear in
+ # R30v2 as long as the user performs some other activity between
+ # 30 and 60 days later.
+ self.reactor.advance(32 * ONE_DAY_IN_SECONDS)
+ self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+ # (give time for tables to update)
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # Check the user now satisfies the requirements to appear in R30v2.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
+ )
+
+ # Advance to 59.5 days after the user's first R30v2-eligible activity.
+ self.reactor.advance(27.5 * ONE_DAY_IN_SECONDS)
+
+ # Check the user still appears in R30v2.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
+ )
+
+ # Advance to 60.5 days after the user's first R30v2-eligible activity.
+ self.reactor.advance(ONE_DAY_IN_SECONDS)
+
+ # Check the user no longer appears in R30v2.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index ebe2c05165..903c69127d 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -43,7 +43,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_generates_and_loads_macaroon_secret_key(self):
self.generate_config()
- with open(self.file, "r") as f:
+ with open(self.file) as f:
raw = yaml.safe_load(f)
self.assertIn("macaroon_secret_key", raw)
@@ -120,7 +120,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
def generate_config_and_remove_lines_containing(self, needle):
self.generate_config()
- with open(self.file, "r") as f:
+ with open(self.file) as f:
contents = f.readlines()
contents = [line for line in contents if needle not in line]
with open(self.file, "w") as f:
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 875b0d0a11..3f41e99950 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -152,7 +152,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
)
self.assertEqual(len(presence_updates), 1)
- presence_update = presence_updates[0] # type: UserPresenceState
+ presence_update: UserPresenceState = presence_updates[0]
self.assertEqual(presence_update.user_id, self.other_user_one_id)
self.assertEqual(presence_update.state, "online")
self.assertEqual(presence_update.status_msg, "boop")
@@ -274,7 +274,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
presence_updates, _ = sync_presence(self, self.other_user_id)
self.assertEqual(len(presence_updates), 1)
- presence_update = presence_updates[0] # type: UserPresenceState
+ presence_update: UserPresenceState = presence_updates[0]
self.assertEqual(presence_update.user_id, self.other_user_id)
self.assertEqual(presence_update.state, "online")
self.assertEqual(presence_update.status_msg, "I'm online!")
@@ -285,6 +285,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id)
self.assertEqual(len(presence_updates), 3)
+ # We stagger sending of presence, so we need to wait a bit for them to
+ # get sent out.
+ self.reactor.advance(60)
+
# Test that sending to a remote user works
remote_user_id = "@far_away_person:island"
@@ -301,6 +305,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
self.module_api.send_local_online_presence_to([remote_user_id])
)
+ # We stagger sending of presence, so we need to wait a bit for them to
+ # get sent out.
+ self.reactor.advance(60)
+
# Check that the expected presence updates were sent
# We explicitly compare using sets as we expect that calling
# module_api.send_local_online_presence_to will create a presence
@@ -320,7 +328,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
)
for call in calls:
call_args = call[0]
- federation_transaction = call_args[0] # type: Transaction
+ federation_transaction: Transaction = call_args[0]
# Get the sent EDUs in this transaction
edus = federation_transaction.get_dict()["edus"]
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cdb41101b3..2928c4f48c 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -103,7 +103,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
- (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
+ self.get_success(self.store.get_profile_displayname(self.frank.localpart))
)
def test_set_my_name_if_disabled(self):
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index 9771d3fb3b..3f73ad7f94 100644
--- a/tests/handlers/test_space_summary.py
+++ b/tests/handlers/test_space_summary.py
@@ -14,8 +14,18 @@
from typing import Any, Iterable, Optional, Tuple
from unittest import mock
-from synapse.api.constants import EventContentFields, RoomTypes
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ HistoryVisibility,
+ JoinRules,
+ Membership,
+ RestrictedJoinRuleTypes,
+ RoomTypes,
+)
from synapse.api.errors import AuthError
+from synapse.api.room_versions import RoomVersions
+from synapse.events import make_event_from_dict
from synapse.handlers.space_summary import _child_events_comparison_key
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
@@ -117,7 +127,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
"""Add a child room to a space."""
self.helper.send_state(
space_id,
- event_type="m.space.child",
+ event_type=EventTypes.SpaceChild,
body={"via": [self.hs.hostname]},
tok=token,
state_key=room_id,
@@ -155,26 +165,379 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
# The user cannot see the space.
self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
- # Joining the room causes it to be visible.
- self.helper.join(self.space, user2, tok=token2)
+ # If the space is made world-readable it should return a result.
+ self.helper.send_state(
+ self.space,
+ event_type=EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": HistoryVisibility.WORLD_READABLE},
+ tok=self.token,
+ )
result = self.get_success(self.handler.get_space_summary(user2, self.space))
-
- # The result should only have the space, but includes the link to the room.
- self._assert_rooms(result, [self.space])
+ self._assert_rooms(result, [self.space, self.room])
self._assert_events(result, [(self.space, self.room)])
- def test_world_readable(self):
- """A world-readable room is visible to everyone."""
+ # Make it not world-readable again and confirm it results in an error.
self.helper.send_state(
self.space,
- event_type="m.room.history_visibility",
- body={"history_visibility": "world_readable"},
+ event_type=EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": HistoryVisibility.JOINED},
+ tok=self.token,
+ )
+ self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
+
+ # Join the space and results should be returned.
+ self.helper.join(self.space, user2, tok=token2)
+ result = self.get_success(self.handler.get_space_summary(user2, self.space))
+ self._assert_rooms(result, [self.space, self.room])
+ self._assert_events(result, [(self.space, self.room)])
+
+ def _create_room_with_join_rule(
+ self, join_rule: str, room_version: Optional[str] = None, **extra_content
+ ) -> str:
+ """Create a room with the given join rule and add it to the space."""
+ room_id = self.helper.create_room_as(
+ self.user,
+ room_version=room_version,
tok=self.token,
+ extra_content={
+ "initial_state": [
+ {
+ "type": EventTypes.JoinRules,
+ "state_key": "",
+ "content": {
+ "join_rule": join_rule,
+ **extra_content,
+ },
+ }
+ ]
+ },
)
+ self._add_child(self.space, room_id, self.token)
+ return room_id
+ def test_filtering(self):
+ """
+ Rooms should be properly filtered to only include rooms the user has access to.
+ """
user2 = self.register_user("user2", "pass")
+ token2 = self.login("user2", "pass")
- # The space should be visible, as well as the link to the room.
+ # Create a few rooms which will have different properties.
+ public_room = self._create_room_with_join_rule(JoinRules.PUBLIC)
+ knock_room = self._create_room_with_join_rule(
+ JoinRules.KNOCK, room_version=RoomVersions.V7.identifier
+ )
+ not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE)
+ invited_room = self._create_room_with_join_rule(JoinRules.INVITE)
+ self.helper.invite(invited_room, targ=user2, tok=self.token)
+ restricted_room = self._create_room_with_join_rule(
+ JoinRules.MSC3083_RESTRICTED,
+ room_version=RoomVersions.MSC3083.identifier,
+ allow=[],
+ )
+ restricted_accessible_room = self._create_room_with_join_rule(
+ JoinRules.MSC3083_RESTRICTED,
+ room_version=RoomVersions.MSC3083.identifier,
+ allow=[
+ {
+ "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP,
+ "room_id": self.space,
+ "via": [self.hs.hostname],
+ }
+ ],
+ )
+ world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE)
+ self.helper.send_state(
+ world_readable_room,
+ event_type=EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": HistoryVisibility.WORLD_READABLE},
+ tok=self.token,
+ )
+ joined_room = self._create_room_with_join_rule(JoinRules.INVITE)
+ self.helper.invite(joined_room, targ=user2, tok=self.token)
+ self.helper.join(joined_room, user2, tok=token2)
+
+ # Join the space.
+ self.helper.join(self.space, user2, tok=token2)
result = self.get_success(self.handler.get_space_summary(user2, self.space))
- self._assert_rooms(result, [self.space])
- self._assert_events(result, [(self.space, self.room)])
+
+ self._assert_rooms(
+ result,
+ [
+ self.space,
+ self.room,
+ public_room,
+ knock_room,
+ invited_room,
+ restricted_accessible_room,
+ world_readable_room,
+ joined_room,
+ ],
+ )
+ self._assert_events(
+ result,
+ [
+ (self.space, self.room),
+ (self.space, public_room),
+ (self.space, knock_room),
+ (self.space, not_invited_room),
+ (self.space, invited_room),
+ (self.space, restricted_room),
+ (self.space, restricted_accessible_room),
+ (self.space, world_readable_room),
+ (self.space, joined_room),
+ ],
+ )
+
+ def test_complex_space(self):
+ """
+ Create a "complex" space to see how it handles things like loops and subspaces.
+ """
+ # Create an inaccessible room.
+ user2 = self.register_user("user2", "pass")
+ token2 = self.login("user2", "pass")
+ room2 = self.helper.create_room_as(user2, is_public=False, tok=token2)
+ # This is a bit odd as "user" is adding a room they don't know about, but
+ # it works for the tests.
+ self._add_child(self.space, room2, self.token)
+
+ # Create a subspace under the space with an additional room in it.
+ subspace = self.helper.create_room_as(
+ self.user,
+ tok=self.token,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ subroom = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(self.space, subspace, token=self.token)
+ self._add_child(subspace, subroom, token=self.token)
+ # Also add the two rooms from the space into this subspace (causing loops).
+ self._add_child(subspace, self.room, token=self.token)
+ self._add_child(subspace, room2, self.token)
+
+ result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+
+ # The result should include each room a single time and each link.
+ self._assert_rooms(result, [self.space, self.room, subspace, subroom])
+ self._assert_events(
+ result,
+ [
+ (self.space, self.room),
+ (self.space, room2),
+ (self.space, subspace),
+ (subspace, subroom),
+ (subspace, self.room),
+ (subspace, room2),
+ ],
+ )
+
+ def test_fed_complex(self):
+ """
+ Return data over federation and ensure that it is handled properly.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ subspace = "#subspace:" + fed_hostname
+ subroom = "#subroom:" + fed_hostname
+
+ async def summarize_remote_room(
+ _self, room, suggested_only, max_children, exclude_rooms
+ ):
+ # Return some good data, and some bad data:
+ #
+ # * Event *back* to the root room.
+ # * Unrelated events / rooms
+ # * Multiple levels of events (in a not-useful order, e.g. grandchild
+ # events before child events).
+
+ # Note that these entries are brief, but should contain enough info.
+ rooms = [
+ {
+ "room_id": subspace,
+ "world_readable": True,
+ "room_type": RoomTypes.SPACE,
+ },
+ {
+ "room_id": subroom,
+ "world_readable": True,
+ },
+ ]
+ event_content = {"via": [fed_hostname]}
+ events = [
+ {
+ "room_id": subspace,
+ "state_key": subroom,
+ "content": event_content,
+ },
+ ]
+ return rooms, events
+
+ # Add a room to the space which is on another server.
+ self._add_child(self.space, subspace, self.token)
+
+ with mock.patch(
+ "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room",
+ new=summarize_remote_room,
+ ):
+ result = self.get_success(
+ self.handler.get_space_summary(self.user, self.space)
+ )
+
+ self._assert_rooms(result, [self.space, self.room, subspace, subroom])
+ self._assert_events(
+ result,
+ [
+ (self.space, self.room),
+ (self.space, subspace),
+ (subspace, subroom),
+ ],
+ )
+
+ def test_fed_filtering(self):
+ """
+ Rooms returned over federation should be properly filtered to only include
+ rooms the user has access to.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ subspace = "#subspace:" + fed_hostname
+
+ # Create a few rooms which will have different properties.
+ public_room = "#public:" + fed_hostname
+ knock_room = "#knock:" + fed_hostname
+ not_invited_room = "#not_invited:" + fed_hostname
+ invited_room = "#invited:" + fed_hostname
+ restricted_room = "#restricted:" + fed_hostname
+ restricted_accessible_room = "#restricted_accessible:" + fed_hostname
+ world_readable_room = "#world_readable:" + fed_hostname
+ joined_room = self.helper.create_room_as(self.user, tok=self.token)
+
+ # Poke an invite over federation into the database.
+ fed_handler = self.hs.get_federation_handler()
+ event = make_event_from_dict(
+ {
+ "room_id": invited_room,
+ "event_id": "!abcd:" + fed_hostname,
+ "type": EventTypes.Member,
+ "sender": "@remote:" + fed_hostname,
+ "state_key": self.user,
+ "content": {"membership": Membership.INVITE},
+ "prev_events": [],
+ "auth_events": [],
+ "depth": 1,
+ "origin_server_ts": 1234,
+ }
+ )
+ self.get_success(
+ fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
+ )
+
+ async def summarize_remote_room(
+ _self, room, suggested_only, max_children, exclude_rooms
+ ):
+ # Note that these entries are brief, but should contain enough info.
+ rooms = [
+ {
+ "room_id": public_room,
+ "world_readable": False,
+ "join_rules": JoinRules.PUBLIC,
+ },
+ {
+ "room_id": knock_room,
+ "world_readable": False,
+ "join_rules": JoinRules.KNOCK,
+ },
+ {
+ "room_id": not_invited_room,
+ "world_readable": False,
+ "join_rules": JoinRules.INVITE,
+ },
+ {
+ "room_id": invited_room,
+ "world_readable": False,
+ "join_rules": JoinRules.INVITE,
+ },
+ {
+ "room_id": restricted_room,
+ "world_readable": False,
+ "join_rules": JoinRules.MSC3083_RESTRICTED,
+ "allowed_spaces": [],
+ },
+ {
+ "room_id": restricted_accessible_room,
+ "world_readable": False,
+ "join_rules": JoinRules.MSC3083_RESTRICTED,
+ "allowed_spaces": [self.room],
+ },
+ {
+ "room_id": world_readable_room,
+ "world_readable": True,
+ "join_rules": JoinRules.INVITE,
+ },
+ {
+ "room_id": joined_room,
+ "world_readable": False,
+ "join_rules": JoinRules.INVITE,
+ },
+ ]
+
+ # Place each room in the sub-space.
+ event_content = {"via": [fed_hostname]}
+ events = [
+ {
+ "room_id": subspace,
+ "state_key": room["room_id"],
+ "content": event_content,
+ }
+ for room in rooms
+ ]
+
+ # Also include the subspace.
+ rooms.insert(
+ 0,
+ {
+ "room_id": subspace,
+ "world_readable": True,
+ },
+ )
+ return rooms, events
+
+ # Add a room to the space which is on another server.
+ self._add_child(self.space, subspace, self.token)
+
+ with mock.patch(
+ "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room",
+ new=summarize_remote_room,
+ ):
+ result = self.get_success(
+ self.handler.get_space_summary(self.user, self.space)
+ )
+
+ self._assert_rooms(
+ result,
+ [
+ self.space,
+ self.room,
+ subspace,
+ public_room,
+ knock_room,
+ invited_room,
+ restricted_accessible_room,
+ world_readable_room,
+ joined_room,
+ ],
+ )
+ self._assert_events(
+ result,
+ [
+ (self.space, self.room),
+ (self.space, subspace),
+ (subspace, public_room),
+ (subspace, knock_room),
+ (subspace, not_invited_room),
+ (subspace, invited_room),
+ (subspace, restricted_room),
+ (subspace, restricted_accessible_room),
+ (subspace, world_readable_room),
+ (subspace, joined_room),
+ ],
+ )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index c9d4fd9336..e4059acda3 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -88,16 +88,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
def _get_current_stats(self, stats_type, stat_id):
table, id_col = stats.TYPE_TO_TABLE[stats_type]
- cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list(
- stats.PER_SLICE_FIELDS[stats_type]
- )
-
- end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
+ cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
return self.get_success(
self.store.db_pool.simple_select_one(
- table + "_historical",
- {id_col: stat_id, end_ts: end_ts},
+ table + "_current",
+ {id_col: stat_id},
cols,
allow_none=True,
)
@@ -156,115 +152,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo")
- def test_initial_earliest_token(self):
- """
- Ingestion via notify_new_event will ignore tokens that the background
- update have already processed.
- """
-
- self.reactor.advance(86401)
-
- self.hs.config.stats_enabled = False
- self.handler.stats_enabled = False
-
- u1 = self.register_user("u1", "pass")
- u1_token = self.login("u1", "pass")
-
- u2 = self.register_user("u2", "pass")
- u2_token = self.login("u2", "pass")
-
- u3 = self.register_user("u3", "pass")
- u3_token = self.login("u3", "pass")
-
- room_1 = self.helper.create_room_as(u1, tok=u1_token)
- self.helper.send_state(
- room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
- )
-
- # Begin the ingestion by creating the temp tables. This will also store
- # the position that the deltas should begin at, once they take over.
- self.hs.config.stats_enabled = True
- self.handler.stats_enabled = True
- self.store.db_pool.updates._all_done = False
- self.get_success(
- self.store.db_pool.simple_update_one(
- table="stats_incremental_position",
- keyvalues={},
- updatevalues={"stream_id": 0},
- )
- )
-
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {"update_name": "populate_stats_prepare", "progress_json": "{}"},
- )
- )
-
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
-
- # Now, before the table is actually ingested, add some more events.
- self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
- self.helper.join(room=room_1, user=u2, tok=u2_token)
-
- # orig_delta_processor = self.store.
-
- # Now do the initial ingestion.
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- "background_updates",
- {
- "update_name": "populate_stats_cleanup",
- "progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
- },
- )
- )
-
- self.store.db_pool.updates._all_done = False
- while not self.get_success(
- self.store.db_pool.updates.has_completed_background_updates()
- ):
- self.get_success(
- self.store.db_pool.updates.do_next_background_update(100), by=0.1
- )
-
- self.reactor.advance(86401)
-
- # Now add some more events, triggering ingestion. Because of the stream
- # position being set to before the events sent in the middle, a simpler
- # implementation would reprocess those events, and say there were four
- # users, not three.
- self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token)
- self.helper.join(room=room_1, user=u3, tok=u3_token)
-
- # self.handler.notify_new_event()
-
- # We need to let the delta processor advanceâŠ
- self.reactor.advance(10 * 60)
-
- # Get the slices! There should be two -- day 1, and day 2.
- r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
-
- self.assertEqual(len(r), 2)
-
- # The oldest has 2 joined members
- self.assertEqual(r[-1]["joined_members"], 2)
-
- # The newest has 3
- self.assertEqual(r[0]["joined_members"], 3)
-
def test_create_user(self):
"""
When we create a user, it should have statistics already ready.
@@ -296,22 +183,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertIsNotNone(r1stats)
self.assertIsNotNone(r2stats)
- # contains the default things you'd expect in a fresh room
- self.assertEqual(
- r1stats["total_events"],
- EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM,
- "Wrong number of total_events in new room's stats!"
- " You may need to update this if more state events are added to"
- " the room creation process.",
- )
- self.assertEqual(
- r2stats["total_events"],
- EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM,
- "Wrong number of total_events in new room's stats!"
- " You may need to update this if more state events are added to"
- " the room creation process.",
- )
-
self.assertEqual(
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
)
@@ -327,24 +198,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(r2stats["invited_members"], 0)
self.assertEqual(r2stats["banned_members"], 0)
- def test_send_message_increments_total_events(self):
- """
- When we send a message, it increments total_events.
- """
-
- self._perform_background_initial_update()
-
- u1 = self.register_user("u1", "pass")
- u1token = self.login("u1", "pass")
- r1 = self.helper.create_room_as(u1, tok=u1token)
- r1stats_ante = self._get_current_stats("room", r1)
-
- self.helper.send(r1, "hiss", tok=u1token)
-
- r1stats_post = self._get_current_stats("room", r1)
-
- self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
-
def test_updating_profile_information_does_not_increase_joined_members_count(self):
"""
Check that the joined_members count does not increase when a user changes their
@@ -378,7 +231,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
def test_send_state_event_nonoverwriting(self):
"""
- When we send a non-overwriting state event, it increments total_events AND current_state_events
+ When we send a non-overwriting state event, it increments current_state_events
"""
self._perform_background_initial_update()
@@ -399,44 +252,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post = self._get_current_stats("room", r1)
- self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
1,
)
- def test_send_state_event_overwriting(self):
- """
- When we send an overwriting state event, it increments total_events ONLY
- """
-
- self._perform_background_initial_update()
-
- u1 = self.register_user("u1", "pass")
- u1token = self.login("u1", "pass")
- r1 = self.helper.create_room_as(u1, tok=u1token)
-
- self.helper.send_state(
- r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby"
- )
-
- r1stats_ante = self._get_current_stats("room", r1)
-
- self.helper.send_state(
- r1, "cat.hissing", {"value": False}, tok=u1token, state_key="tabby"
- )
-
- r1stats_post = self._get_current_stats("room", r1)
-
- self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
- self.assertEqual(
- r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
- 0,
- )
-
def test_join_first_time(self):
"""
- When a user joins a room for the first time, total_events, current_state_events and
+ When a user joins a room for the first time, current_state_events and
joined_members should increase by exactly 1.
"""
@@ -455,7 +278,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post = self._get_current_stats("room", r1)
- self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
1,
@@ -466,7 +288,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
def test_join_after_leave(self):
"""
- When a user joins a room after being previously left, total_events and
+ When a user joins a room after being previously left,
joined_members should increase by exactly 1.
current_state_events should not increase.
left_members should decrease by exactly 1.
@@ -490,7 +312,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post = self._get_current_stats("room", r1)
- self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
0,
@@ -504,7 +325,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
def test_invited(self):
"""
- When a user invites another user, current_state_events, total_events and
+ When a user invites another user, current_state_events and
invited_members should increase by exactly 1.
"""
@@ -522,7 +343,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post = self._get_current_stats("room", r1)
- self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
1,
@@ -533,7 +353,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
def test_join_after_invite(self):
"""
- When a user joins a room after being invited, total_events and
+ When a user joins a room after being invited and
joined_members should increase by exactly 1.
current_state_events should not increase.
invited_members should decrease by exactly 1.
@@ -556,7 +376,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post = self._get_current_stats("room", r1)
- self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
0,
@@ -570,7 +389,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
def test_left(self):
"""
- When a user leaves a room after joining, total_events and
+ When a user leaves a room after joining and
left_members should increase by exactly 1.
current_state_events should not increase.
joined_members should decrease by exactly 1.
@@ -593,7 +412,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post = self._get_current_stats("room", r1)
- self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
0,
@@ -607,7 +425,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
def test_banned(self):
"""
- When a user is banned from a room after joining, total_events and
+ When a user is banned from a room after joining and
left_members should increase by exactly 1.
current_state_events should not increase.
banned_members should decrease by exactly 1.
@@ -630,7 +448,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post = self._get_current_stats("room", r1)
- self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
0,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f58afbc244..fa3cff598e 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -38,6 +38,9 @@ U_ONION = UserID.from_string("@onion:farm")
# Test room id
ROOM_ID = "a-room"
+# Room we're not in
+OTHER_ROOM_ID = "another-room"
+
def _expect_edu_transaction(edu_type, content, origin="test"):
return {
@@ -115,6 +118,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
hs.get_auth().check_user_in_room = check_user_in_room
+ async def check_host_in_room(room_id, server_name):
+ return room_id == ROOM_ID
+
+ hs.get_event_auth_handler().check_host_in_room = check_host_in_room
+
def get_joined_hosts_for_room(room_id):
return {member.domain for member in self.room_members}
@@ -244,6 +252,35 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
+ def test_started_typing_remote_recv_not_in_room(self):
+ self.room_members = [U_APPLE, U_ONION]
+
+ self.assertEquals(self.event_source.get_current_key(), 0)
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/federation/v1/send/1000000",
+ _make_edu_transaction_json(
+ "m.typing",
+ content={
+ "room_id": OTHER_ROOM_ID,
+ "user_id": U_ONION.to_string(),
+ "typing": True,
+ },
+ ),
+ federation_auth_origin=b"farm",
+ )
+ self.assertEqual(channel.code, 200)
+
+ self.on_new_event.assert_not_called()
+
+ self.assertEquals(self.event_source.get_current_key(), 0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0)
+ )
+ self.assertEquals(events[0], [])
+ self.assertEquals(events[1], 0)
+
@override_config({"send_federation": True})
def test_stopped_typing(self):
self.room_members = [U_APPLE, U_BANANA, U_ONION]
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index e45980316b..a37bce08c3 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -273,7 +273,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(response.code, 200)
# Send the body
- request.write('{ "a": 1 }'.encode("ascii"))
+ request.write(b'{ "a": 1 }')
request.finish()
self.reactor.pump((0.1,))
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index ed9a884d76..d9a8b077d3 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -102,7 +102,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertNoResult(test_d)
# Send it the HTTP response
- res_json = '{ "a": 1 }'.encode("ascii")
+ res_json = b'{ "a": 1 }'
protocol.dataReceived(
b"HTTP/1.1 200 OK\r\n"
b"Server: Fake\r\n"
@@ -339,10 +339,8 @@ class FederationClientTests(HomeserverTestCase):
# Send it the HTTP response
client.dataReceived(
- (
- b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
- b"Server: Fake\r\n\r\n"
- )
+ b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
+ b"Server: Fake\r\n\r\n"
)
# Push by enough to time it out
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index fefc8099c9..437113929a 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -205,6 +205,41 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
def test_http_request_via_proxy(self):
+ """
+ Tests that requests can be made through a proxy.
+ """
+ self._do_http_request_via_proxy(auth_credentials=None)
+
+ @patch.dict(
+ os.environ,
+ {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"},
+ )
+ def test_http_request_via_proxy_with_auth(self):
+ """
+ Tests that authenticated requests can be made through a proxy.
+ """
+ self._do_http_request_via_proxy(auth_credentials="bob:pinkponies")
+
+ @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
+ def test_https_request_via_proxy(self):
+ """Tests that TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials=None)
+
+ @patch.dict(
+ os.environ,
+ {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
+ )
+ def test_https_request_via_proxy_with_auth(self):
+ """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
+
+ def _do_http_request_via_proxy(
+ self,
+ auth_credentials: Optional[str] = None,
+ ):
+ """
+ Tests that requests can be made through a proxy.
+ """
agent = ProxyAgent(self.reactor, use_proxy=True)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
@@ -229,6 +264,23 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
+
+ # Check whether auth credentials have been supplied to the proxy
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+
+ if auth_credentials is not None:
+ # Compute the correct header value for Proxy-Authorization
+ encoded_credentials = base64.b64encode(b"bob:pinkponies")
+ expected_header_value = b"Basic " + encoded_credentials
+
+ # Validate the header's value
+ self.assertIn(expected_header_value, proxy_auth_header_values)
+ else:
+ # Check that the Proxy-Authorization header has not been supplied to the proxy
+ self.assertIsNone(proxy_auth_header_values)
+
self.assertEqual(request.method, b"GET")
self.assertEqual(request.path, b"http://test.com")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
@@ -241,19 +293,6 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
- @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
- def test_https_request_via_proxy(self):
- """Tests that TLS-encrypted requests can be made through a proxy"""
- self._do_https_request_via_proxy(auth_credentials=None)
-
- @patch.dict(
- os.environ,
- {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
- )
- def test_https_request_via_proxy_with_auth(self):
- """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
- self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
-
def _do_https_request_via_proxy(
self,
auth_credentials: Optional[str] = None,
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 2c68b9a13c..81d9e2f484 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -100,9 +100,9 @@ class ModuleApiTestCase(HomeserverTestCase):
"content": content,
"sender": user_id,
}
- event = self.get_success(
+ event: EventBase = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
- ) # type: EventBase
+ )
self.assertEqual(event.sender, user_id)
self.assertEqual(event.type, "m.room.message")
self.assertEqual(event.room_id, room_id)
@@ -136,9 +136,9 @@ class ModuleApiTestCase(HomeserverTestCase):
"sender": user_id,
"state_key": "",
}
- event = self.get_success(
+ event: EventBase = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
- ) # type: EventBase
+ )
self.assertEqual(event.sender, user_id)
self.assertEqual(event.type, "m.room.power_levels")
self.assertEqual(event.room_id, room_id)
@@ -281,7 +281,7 @@ class ModuleApiTestCase(HomeserverTestCase):
)
for call in calls:
call_args = call[0]
- federation_transaction = call_args[0] # type: Transaction
+ federation_transaction: Transaction = call_args[0]
# Get the sent EDUs in this transaction
edus = federation_transaction.get_dict()["edus"]
@@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update = presence_updates[0] # type: UserPresenceState
+ presence_update: UserPresenceState = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update = presence_updates[0] # type: UserPresenceState
+ presence_update: UserPresenceState = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update = presence_updates[0] # type: UserPresenceState
+ presence_update: UserPresenceState = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 624bd1b927..e9fd991718 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -53,9 +53,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
- self.server = server_factory.buildProtocol(
+ self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
None
- ) # type: ServerReplicationStreamProtocol
+ )
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -195,7 +195,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
fetching updates for given stream.
"""
- path = request.path # type: bytes # type: ignore
+ path: bytes = request.path # type: ignore
self.assertRegex(
path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
@@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""
- servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
+ servlets: List[Callable[[HomeServer, JsonResource], None]] = []
def setUp(self):
super().setUp()
@@ -448,7 +448,7 @@ class TestReplicationDataHandler(ReplicationDataHandler):
super().__init__(hs)
# list of received (stream_name, token, row) tuples
- self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
+ self.received_rdata_rows: List[Tuple[str, int, Any]] = []
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
@@ -484,7 +484,7 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""
- transport = None # type: Optional[FakeTransport]
+ transport: Optional[FakeTransport] = None
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
@@ -550,12 +550,12 @@ class FakeRedisPubSubProtocol(Protocol):
if obj is None:
return "$-1\r\n"
if isinstance(obj, str):
- return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+ return f"${len(obj)}\r\n{obj}\r\n"
if isinstance(obj, int):
- return ":{val}\r\n".format(val=obj)
+ return f":{obj}\r\n"
if isinstance(obj, (list, tuple)):
items = "".join(self.encode(a) for a in obj)
- return "*{len}\r\n{items}".format(len=len(obj), items=items)
+ return f"*{len(obj)}\r\n{items}"
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index f51fa0a79e..666008425a 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -135,9 +135,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point = self.get_success(
+ fork_point: List[str] = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
- ) # type: List[str]
+ )
events = [
self._inject_state_event(sender=OTHER_USER)
@@ -238,7 +238,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual(row.data.event_id, pl_event.event_id)
# the state rows are unsorted
- state_rows = [] # type: List[EventsStreamCurrentStateRow]
+ state_rows: List[EventsStreamCurrentStateRow] = []
for stream_name, _, row in received_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
@@ -290,11 +290,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point = self.get_success(
+ fork_point: List[str] = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
- ) # type: List[str]
+ )
- events = [] # type: List[EventBase]
+ events: List[EventBase] = []
for user in user_ids:
events.extend(
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
@@ -355,7 +355,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual(row.data.event_id, pl_events[i].event_id)
# the state rows are unsorted
- state_rows = [] # type: List[EventsStreamCurrentStateRow]
+ state_rows: List[EventsStreamCurrentStateRow] = []
for _ in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 7f5d932f0b..38e292c1ab 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -43,7 +43,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
- row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
@@ -75,7 +75,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))
- row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
self.assertEqual("!room2:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index ecd360c2d0..3ff5afc6e5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
- row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
@@ -102,7 +102,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
- row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 76e6644353..ffa425328f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -31,7 +31,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__)
-test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory]
+test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None
class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
@@ -70,7 +70,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.reactor,
FakeSite(resource),
"GET",
- "/{}/{}".format(target, media_id),
+ f"/{target}/{media_id}",
shorthand=False,
access_token=self.access_token,
await_result=False,
@@ -113,7 +113,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(request.method, b"GET")
self.assertEqual(
request.path,
- "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
+ f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"),
)
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 5eca5c165d..f3615af97e 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -211,7 +211,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.reactor,
sync_hs_site,
"GET",
- "/sync?since={}".format(next_batch),
+ f"/sync?since={next_batch}",
access_token=access_token,
)
@@ -241,7 +241,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.reactor,
sync_hs_site,
"GET",
- "/sync?since={}".format(vector_clock_token),
+ f"/sync?since={vector_clock_token}",
access_token=access_token,
)
@@ -266,7 +266,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.reactor,
sync_hs_site,
"GET",
- "/sync?since={}".format(next_batch),
+ f"/sync?since={next_batch}",
access_token=access_token,
)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 2f7090e554..a7c6e595b9 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -66,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
# Create a new group
channel = self.make_request(
"POST",
- "/create_group".encode("ascii"),
+ b"/create_group",
access_token=self.admin_user_tok,
content={"localpart": "test"},
)
@@ -129,9 +129,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
def _get_groups_user_is_in(self, access_token):
"""Returns the list of groups the user is in (given their access token)"""
- channel = self.make_request(
- "GET", "/joined_groups".encode("ascii"), access_token=access_token
- )
+ channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ee071c2477..17ec8bfd3b 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -535,7 +535,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
)
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+ self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
def _assert_peek(self, room_id, expect_code):
"""Assert that the admin user can (or cannot) peek into the room."""
@@ -599,7 +599,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
)
)
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+ self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
class RoomTestCase(unittest.HomeserverTestCase):
@@ -1280,7 +1280,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.public_room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True
)
- self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+ self.url = f"/_synapse/admin/v1/join/{self.public_room_id}"
def test_requester_is_no_admin(self):
"""
@@ -1420,7 +1420,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
private_room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=False
)
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ url = f"/_synapse/admin/v1/join/{private_room_id}"
body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
@@ -1463,7 +1463,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Join user to room.
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ url = f"/_synapse/admin/v1/join/{private_room_id}"
body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
@@ -1493,7 +1493,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
private_room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok, is_public=False
)
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ url = f"/_synapse/admin/v1/join/{private_room_id}"
body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
@@ -1633,7 +1633,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
content={},
access_token=self.admin_user_tok,
)
@@ -1660,7 +1660,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
content={},
access_token=self.admin_user_tok,
)
@@ -1686,7 +1686,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
content={"user_id": self.second_user_id},
access_token=self.admin_user_tok,
)
@@ -1720,7 +1720,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
content={},
access_token=self.admin_user_tok,
)
@@ -1753,7 +1753,6 @@ PURGE_TABLES = [
"room_memberships",
"room_stats_state",
"room_stats_current",
- "room_stats_historical",
"room_stats_earliest_token",
"rooms",
"stream_ordering_to_exterm",
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e1fe72fc5d..28dd47a28b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -16,17 +16,19 @@ from typing import Dict
from unittest.mock import Mock
from synapse.events import EventBase
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import Requester, StateMap
+from synapse.util.frozenutils import unfreeze
from tests import unittest
thread_local = threading.local()
-class ThirdPartyRulesTestModule:
+class LegacyThirdPartyRulesTestModule:
def __init__(self, config: Dict, module_api: ModuleApi):
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
@@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule:
return config
-def current_rules_module() -> ThirdPartyRulesTestModule:
- return thread_local.rules_module
+class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ super().__init__(config, module_api)
+
+ def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ):
+ return False
+
+
+class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ super().__init__(config, module_api)
+
+ async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ d = event.get_dict()
+ content = unfreeze(event.content)
+ content["foo"] = "bar"
+ d["content"] = content
+ return d
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
@@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def default_config(self):
- config = super().default_config()
- config["third_party_event_rules"] = {
- "module": __name__ + ".ThirdPartyRulesTestModule",
- "config": {},
- }
- return config
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ load_legacy_third_party_event_rules(hs)
+
+ return hs
def prepare(self, reactor, clock, homeserver):
# Create a user and room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey")
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+ # Some tests might prevent room creation on purpose.
+ try:
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+ except Exception:
+ pass
def test_third_party_rules(self):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
@@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
- return ev.type != "foo.bar.forbidden"
+ return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check)
- current_rules_module().check_event_allowed = callback
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
+ callback
+ ]
channel = self.make_request(
"PUT",
@@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state):
ev.content = {"x": "y"}
- return True
+ return True, None
- current_rules_module().check_event_allowed = check
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event
channel = self.make_request(
@@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"500", channel.result)
+ # check_event_allowed has some error handling, so it shouldn't 500 just because a
+ # module did something bad.
+ self.assertEqual(channel.code, 200, channel.result)
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["x"], "x")
def test_modify_event(self):
"""The module can return a modified version of the event"""
@@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
async def check(ev: EventBase, state):
d = ev.get_dict()
d["content"] = {"x": "y"}
- return d
+ return True, d
- current_rules_module().check_event_allowed = check
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event
channel = self.make_request(
@@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"msgtype": "m.text",
"body": d["content"]["body"].upper(),
}
- return d
+ return True, d
- current_rules_module().check_event_allowed = check
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# Send an event, then edit it.
channel = self.make_request(
@@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.assertEqual(ev["content"]["body"], "EDITED BODY")
def test_send_event(self):
- """Tests that the module can send an event into a room via the module api"""
+ """Tests that a module can send an event into a room via the module api"""
content = {
"msgtype": "m.text",
"body": "Hello!",
@@ -233,13 +270,60 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"content": content,
"sender": self.user_id,
}
- event = self.get_success(
- current_rules_module().module_api.create_and_send_event_into_room(
- event_dict
- )
- ) # type: EventBase
+ event: EventBase = self.get_success(
+ self.hs.get_module_api().create_and_send_event_into_room(event_dict)
+ )
self.assertEquals(event.sender, self.user_id)
self.assertEquals(event.room_id, self.room_id)
self.assertEquals(event.type, "m.room.message")
self.assertEquals(event.content, content)
+
+ @unittest.override_config(
+ {
+ "third_party_event_rules": {
+ "module": __name__ + ".LegacyChangeEvents",
+ "config": {},
+ }
+ }
+ )
+ def test_legacy_check_event_allowed(self):
+ """Tests that the wrapper for legacy check_event_allowed callbacks works
+ correctly.
+ """
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id,
+ {
+ "msgtype": "m.text",
+ "body": "Original body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+
+ self.assertIn("foo", channel.json_body["content"].keys())
+ self.assertEqual(channel.json_body["content"]["foo"], "bar")
+
+ @unittest.override_config(
+ {
+ "third_party_event_rules": {
+ "module": __name__ + ".LegacyDenyNewRooms",
+ "config": {},
+ }
+ }
+ )
+ def test_legacy_on_create_room(self):
+ """Tests that the wrapper for legacy on_create_room callbacks works
+ correctly.
+ """
+ self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 605b952316..7eba69642a 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -453,7 +453,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# stick the flows results in a dict by type
- flow_results = {} # type: Dict[str, Any]
+ flow_results: Dict[str, Any] = {}
for f in channel.json_body["flows"]:
flow_type = f["type"]
self.assertNotIn(
@@ -501,7 +501,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
p.close()
# there should be a link for each href
- returned_idps = [] # type: List[str]
+ returned_idps: List[str] = []
for link in p.links:
path, query = link.split("?", 1)
self.assertEqual(path, "pick_idp")
@@ -582,7 +582,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# ... and should have set a cookie including the redirect url
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
assert cookie_headers
- cookies = {} # type: Dict[str, str]
+ cookies: Dict[str, str] = {}
for h in cookie_headers:
key, value = h.split(";")[0].split("=", maxsplit=1)
cookies[key] = value
@@ -874,9 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result = jwt.encode(
- payload, secret, self.jwt_algorithm
- ) # type: Union[str, bytes]
+ result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
if isinstance(result, bytes):
return result.decode("ascii")
return result
@@ -1084,7 +1082,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
+ result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
if isinstance(result, bytes):
return result.decode("ascii")
return result
@@ -1272,7 +1270,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
# ... with a username_mapping_session cookie
- cookies = {} # type: Dict[str,str]
+ cookies: Dict[str, str] = {}
channel.extract_cookies(cookies)
self.assertIn("username_mapping_session", cookies)
session_id = cookies["username_mapping_session"]
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e94566ffd7..3df070c936 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1206,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/join",
content={"reason": reason},
access_token=self.second_tok,
)
@@ -1220,7 +1220,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/leave",
content={"reason": reason},
access_token=self.second_tok,
)
@@ -1234,7 +1234,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/kick",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.second_tok,
)
@@ -1248,7 +1248,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/ban",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
@@ -1260,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/unban",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
@@ -1272,7 +1272,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/invite",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
@@ -1291,7 +1291,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/leave",
content={"reason": reason},
access_token=self.second_tok,
)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 69798e95c3..fc2d35596e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -19,7 +19,7 @@ import json
import re
import time
import urllib.parse
-from typing import Any, Dict, Mapping, MutableMapping, Optional
+from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
from unittest.mock import patch
import attr
@@ -53,6 +53,9 @@ class RestHelper:
tok: str = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
) -> str:
"""
Create a room.
@@ -87,6 +90,7 @@ class RestHelper:
"POST",
path,
json.dumps(content).encode("utf8"),
+ custom_headers=custom_headers,
)
assert channel.result["code"] == b"%d" % expect_code, channel.result
@@ -175,14 +179,30 @@ class RestHelper:
self.auth_user_id = temp_id
- def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
+ def send(
+ self,
+ room_id,
+ body=None,
+ txn_id=None,
+ tok=None,
+ expect_code=200,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ ):
if body is None:
body = "body_text_here"
content = {"msgtype": "m.text", "body": body}
return self.send_event(
- room_id, "m.room.message", content, txn_id, tok, expect_code
+ room_id,
+ "m.room.message",
+ content,
+ txn_id,
+ tok,
+ expect_code,
+ custom_headers=custom_headers,
)
def send_event(
@@ -193,6 +213,9 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -207,6 +230,7 @@ class RestHelper:
"PUT",
path,
json.dumps(content or {}).encode("utf8"),
+ custom_headers=custom_headers,
)
assert (
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 856aa8682f..2e2f94742e 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
prev_token = None
found_event_ids = []
- encoded_key = urllib.parse.quote_plus("đ".encode("utf-8"))
+ encoded_key = urllib.parse.quote_plus("đ".encode())
for _ in range(20):
from_token = ""
if prev_token:
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py
index 1ec6b05e5b..a76a6fef1e 100644
--- a/tests/rest/client/v2_alpha/test_report_event.py
+++ b/tests/rest/client/v2_alpha/test_report_event.py
@@ -41,7 +41,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok)
resp = self.helper.send(self.room_id, tok=self.admin_user_tok)
self.event_id = resp["event_id"]
- self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id)
+ self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
def test_reason_str_and_score_int(self):
data = {"reason": "this makes me sad", "score": -100}
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 95e7075841..2d6b49692e 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -310,7 +310,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
correctly decode it as the UTF-8 string, and use filename* in the
response.
"""
- filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
+ filename = parse.quote("\u2603".encode()).encode("ascii")
channel = self._req(
b"inline; filename*=utf-8''" + filename + self.test_image.extension
)
diff --git a/tests/server.py b/tests/server.py
index f32d8dc375..6fddd3b305 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -52,7 +52,7 @@ class FakeChannel:
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
_ip = attr.ib(type=str, default="127.0.0.1")
- _producer = None # type: Optional[Union[IPullProducer, IPushProducer]]
+ _producer: Optional[Union[IPullProducer, IPushProducer]] = None
@property
def json_body(self):
@@ -316,8 +316,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self._tcp_callbacks = {}
self._udp = []
- lookups = self.lookups = {} # type: Dict[str, str]
- self._thread_callbacks = deque() # type: Deque[Callable[[], None]]
+ self.lookups: Dict[str, str] = {}
+ self._thread_callbacks: Deque[Callable[[], None]] = deque()
+
+ lookups = self.lookups
@implementer(IResolverSimple)
class FakeResolver:
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 069db0edc4..0da42b5ac5 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -7,9 +7,7 @@ from tests import unittest
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
- self.updates = (
- self.hs.get_datastore().db_pool.updates
- ) # type: BackgroundUpdater
+ self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 41bef62ca8..43628ce44f 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -59,5 +59,5 @@ class DirectoryStoreTestCase(HomeserverTestCase):
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (self.get_success(self.store.get_association_from_room_alias(self.alias)))
+ self.get_success(self.store.get_association_from_room_alias(self.alias))
)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 792b1c44c1..7486078284 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.db_pool = self.store.db_pool # type: DatabasePool
+ self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
@@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.db_pool = self.store.db_pool # type: DatabasePool
+ self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
@@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.db_pool = self.store.db_pool # type: DatabasePool
+ self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 8a446da848..a1ba99ff14 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -45,11 +45,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
- (
- self.get_success(
- self.store.get_profile_displayname(self.u_frank.localpart)
- )
- )
+ self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
)
def test_avatar_url(self):
@@ -76,9 +72,5 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
- (
- self.get_success(
- self.store.get_profile_avatar_url(self.u_frank.localpart)
- )
- )
+ self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 54c5b470c7..e5574063f1 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -75,7 +75,7 @@ class PurgeTests(HomeserverTestCase):
token = self.get_success(
self.store.get_topological_token_for_event(last["event_id"])
)
- event = "t{}-{}".format(token.topological + 1, token.stream + 1)
+ event = f"t{token.topological + 1}-{token.stream + 1}"
# Purge everything before this topological token
f = self.get_failure(
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 70257bf210..31ce7f6252 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -49,7 +49,7 @@ class RoomStoreTestCase(HomeserverTestCase):
)
def test_get_room_unknown_room(self):
- self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
+ self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
def test_get_room_with_stats(self):
self.assertDictContainsSubset(
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 88888319cc..f73306ecc4 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -13,12 +13,13 @@
# limitations under the License.
import unittest
+from typing import Optional
from synapse import event_auth
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
-from synapse.types import get_domain_from_id
+from synapse.events import EventBase, make_event_from_dict
+from synapse.types import JsonDict, get_domain_from_id
class EventAuthTestCase(unittest.TestCase):
@@ -432,7 +433,7 @@ class EventAuthTestCase(unittest.TestCase):
TEST_ROOM_ID = "!test:room"
-def _create_event(user_id):
+def _create_event(user_id: str) -> EventBase:
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
@@ -444,7 +445,9 @@ def _create_event(user_id):
)
-def _member_event(user_id, membership, sender=None):
+def _member_event(
+ user_id: str, membership: str, sender: Optional[str] = None
+) -> EventBase:
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
@@ -458,11 +461,11 @@ def _member_event(user_id, membership, sender=None):
)
-def _join_event(user_id):
+def _join_event(user_id: str) -> EventBase:
return _member_event(user_id, "join")
-def _power_levels_event(sender, content):
+def _power_levels_event(sender: str, content: JsonDict) -> EventBase:
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
@@ -475,7 +478,7 @@ def _power_levels_event(sender, content):
)
-def _alias_event(sender, **kwargs):
+def _alias_event(sender: str, **kwargs) -> EventBase:
data = {
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
@@ -488,7 +491,7 @@ def _alias_event(sender, **kwargs):
return make_event_from_dict(data)
-def _random_state_event(sender):
+def _random_state_event(sender: str) -> EventBase:
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
@@ -501,7 +504,7 @@ def _random_state_event(sender):
)
-def _join_rules_event(sender, join_rule):
+def _join_rules_event(sender: str, join_rule: str) -> EventBase:
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
@@ -519,7 +522,7 @@ def _join_rules_event(sender, join_rule):
event_count = 0
-def _get_event_id():
+def _get_event_id() -> str:
global event_count
c = event_count
event_count += 1
diff --git a/tests/test_state.py b/tests/test_state.py
index 62f7095873..e5488df1ac 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -168,6 +168,7 @@ class StateTestCase(unittest.TestCase):
"get_state_handler",
"get_clock",
"get_state_resolution_handler",
+ "get_account_validity_handler",
"hostname",
]
)
@@ -199,7 +200,7 @@ class StateTestCase(unittest.TestCase):
self.store.register_events(graph.walk())
- context_store = {} # type: dict[str, EventContext]
+ context_store: dict[str, EventContext] = {}
for event in graph.walk():
context = yield defer.ensureDeferred(
diff --git a/tests/test_types.py b/tests/test_types.py
index d7881021d3..0d0c00d97a 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -103,6 +103,4 @@ class MapUsernameTestCase(unittest.TestCase):
def testNonAscii(self):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("tĂȘst"), "t=c3=aast")
- self.assertEqual(
- map_username_to_mxid_localpart("tĂȘst".encode("utf-8")), "t=c3=aast"
- )
+ self.assertEqual(map_username_to_mxid_localpart("tĂȘst".encode()), "t=c3=aast")
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index 1fbb38f4be..e878af5f12 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -23,13 +23,13 @@ class TestHtmlParser(HTMLParser):
super().__init__()
# a list of links found in the doc
- self.links = [] # type: List[str]
+ self.links: List[str] = []
# the values of any hidden <input>s: map from name to value
- self.hiddens = {} # type: Dict[str, Optional[str]]
+ self.hiddens: Dict[str, Optional[str]] = {}
# the values of any radio buttons: map from name to list of values
- self.radios = {} # type: Dict[str, List[Optional[str]]]
+ self.radios: Dict[str, List[Optional[str]]] = {}
def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
diff --git a/tests/unittest.py b/tests/unittest.py
index 74db7c08f1..3eec9c4d5b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -140,7 +140,7 @@ class TestCase(unittest.TestCase):
try:
self.assertEquals(attrs[key], getattr(obj, key))
except AssertionError as e:
- raise (type(e))("Assert error for '.{}':".format(key)) from e
+ raise (type(e))(f"Assert error for '.{key}':") from e
def assert_dict(self, required, actual):
"""Does a partial assert of a dict.
@@ -520,7 +520,7 @@ class HomeserverTestCase(TestCase):
if not isinstance(deferred, Deferred):
return d
- results = [] # type: list
+ results: list = []
deferred.addBoth(results.append)
self.pump(by=by)
@@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase):
user_id = channel.json_body["user_id"]
return user_id
- def login(self, username, password, device_id=None):
+ def login(
+ self,
+ username,
+ password,
+ device_id=None,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ ):
"""
Log in a user, and get an access token. Requires the Login API be
registered.
@@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase):
body["device_id"] = device_id
channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
+ "POST",
+ "/_matrix/client/r0/login",
+ json.dumps(body).encode("utf8"),
+ custom_headers=custom_headers,
)
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 0277998cbe..39947a166b 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -174,7 +174,7 @@ class DescriptorTestCase(unittest.TestCase):
return self.result
obj = Cls()
- callbacks = set() # type: Set[str]
+ callbacks: Set[str] = set()
# set off an asynchronous request
obj.result = origin_d = defer.Deferred()
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index e712eb42ea..3c0ddd4f18 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase):
)
def test_empty_input(self):
- parts = chunk_seq([], 5) # type: Iterable[Sequence]
+ parts: Iterable[Sequence] = chunk_seq([], 5)
self.assertEqual(
list(parts),
@@ -56,13 +56,13 @@ class SortTopologically(TestCase):
def test_empty(self):
"Test that an empty graph works correctly"
- graph = {} # type: Dict[int, List[int]]
+ graph: Dict[int, List[int]] = {}
self.assertEqual(list(sorted_topologically([], graph)), [])
def test_handle_empty_graph(self):
"Test that a graph where a node doesn't have an entry is treated as empty"
- graph = {} # type: Dict[int, List[int]]
+ graph: Dict[int, List[int]] = {}
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
@@ -70,7 +70,7 @@ class SortTopologically(TestCase):
def test_disconnected(self):
"Test that a graph with no edges work"
- graph = {1: [], 2: []} # type: Dict[int, List[int]]
+ graph: Dict[int, List[int]] = {1: [], 2: []}
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
@@ -78,19 +78,19 @@ class SortTopologically(TestCase):
def test_linear(self):
"Test that a simple `4 -> 3 -> 2 -> 1` graph works"
- graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+ graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_subset(self):
"Test that only sorting a subset of the graph works"
- graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+ graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
def test_fork(self):
"Test that a forked graph works"
- graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]]
+ graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
# Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
# always get the same one.
@@ -98,12 +98,12 @@ class SortTopologically(TestCase):
def test_duplicates(self):
"Test that a graph with duplicate edges work"
- graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]]
+ graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_multiple_paths(self):
"Test that a graph with multiple paths between two nodes work"
- graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]]
+ graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
|