diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index db65253773..0120b4688b 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -63,7 +63,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
- sender="@as:test",
+ # Note: this user does not match the regex above, so that tests
+ # can distinguish the sender from the AS user.
+ sender="@as_main:test",
)
mock_load_appservices = Mock(return_value=[self.appservice])
@@ -122,7 +124,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
{(alice, bob, private), (bob, alice, private)},
)
- # The next three tests (test_population_excludes_*) all setup
+ # The next four tests (test_excludes_*) all setup
# - A normal user included in the user dir
# - A public and private room created by that user
# - A user excluded from the room dir, belonging to both rooms
@@ -179,6 +181,34 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
self._check_only_one_user_in_directory(user, public)
+ def test_excludes_appservice_sender(self) -> None:
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+ room = self.helper.create_room_as(user, is_public=True, tok=token)
+ self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
+ self._check_only_one_user_in_directory(user, room)
+
+ def test_user_not_in_users_table(self) -> None:
+ """Unclear how it happens, but on matrix.org we've seen join events
+ for users who aren't in the users table. Test that we don't fall over
+ when processing such a user.
+ """
+ user1 = self.register_user("user1", "pass")
+ token1 = self.login(user1, "pass")
+ room = self.helper.create_room_as(user1, is_public=True, tok=token1)
+
+ # Inject a join event for a user who doesn't exist
+ self.get_success(inject_member_event(self.hs, room, "@not-a-user:test", "join"))
+
+ # Another new user registers and joins the room
+ user2 = self.register_user("user2", "pass")
+ token2 = self.login(user2, "pass")
+ self.helper.join(room, user2, tok=token2)
+
+ # The dodgy event should not have stopped us from processing user2's join.
+ in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
+ self.assertEqual(set(in_public), {(user1, room), (user2, room)})
+
def _create_rooms_and_inject_memberships(
self, creator: str, token: str, joiner: str
) -> Tuple[str, str]:
@@ -230,7 +260,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
profile = self.get_success(self.store.get_user_in_directory(support_user_id))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
display_name = "display_name"
profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
@@ -264,7 +294,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is not in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
# update profile after deactivation
self.get_success(
@@ -273,7 +303,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is furthermore not in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
def test_handle_local_profile_change_with_appservice_user(self) -> None:
# create user
@@ -283,7 +313,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is not in directory
profile = self.get_success(self.store.get_user_in_directory(as_user_id))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
# update profile
profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
@@ -293,7 +323,28 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is still not in directory
profile = self.get_success(self.store.get_user_in_directory(as_user_id))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
+
+ def test_handle_local_profile_change_with_appservice_sender(self) -> None:
+ # profile is not in directory
+ profile = self.get_success(
+ self.store.get_user_in_directory(self.appservice.sender)
+ )
+ self.assertIsNone(profile)
+
+ # update profile
+ profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
+ self.get_success(
+ self.handler.handle_local_profile_change(
+ self.appservice.sender, profile_info
+ )
+ )
+
+ # profile is still not in directory
+ profile = self.get_success(
+ self.store.get_user_in_directory(self.appservice.sender)
+ )
+ self.assertIsNone(profile)
def test_handle_user_deactivated_support_user(self) -> None:
s_user_id = "@support:test"
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index dada4f98c9..0e4013ebea 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -147,6 +147,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
+ def test_get_last_client_ip_by_device(self, after_persisting: bool):
+ """Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
+ self.reactor.advance(12345678)
+
+ user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id,
+ "display name",
+ )
+ )
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", device_id
+ )
+ )
+
+ if after_persisting:
+ # Trigger the storage loop
+ self.reactor.advance(10)
+
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, device_id)
+ )
+
+ self.assertEqual(
+ result,
+ {
+ (user_id, device_id): {
+ "user_id": user_id,
+ "device_id": device_id,
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 12345678000,
+ },
+ },
+ )
+
+ @parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool):
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
self.reactor.advance(12345678)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 32060f2abd..70d52b088c 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -21,7 +21,7 @@ from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, TestCase
logger = logging.getLogger(__name__)
@@ -105,7 +105,6 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
def test_get_state_for_event(self):
-
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -483,3 +482,513 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+
+class StateFilterDifferenceTestCase(TestCase):
+ def assert_difference(
+ self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
+ ):
+ self.assertEqual(
+ minuend.approx_difference(subtrahend),
+ expected,
+ f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
+ )
+
+ def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b do not have the
+ include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Create: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.CanonicalAlias: {""}},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_no_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only a has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Create: None,
+ EventTypes.Member: set(),
+ EventTypes.CanonicalAlias: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ # This also shows that the resultant state filter is normalised.
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter(types=frozendict(), include_others=True),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b have the include_others
+ flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Create: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_no_include_other_minus_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only b has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=True,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_simple_cases(self):
+ """
+ Tests some very simple cases of the StateFilter approx_difference,
+ that are not explicitly tested by the more in-depth tests.
+ """
+
+ self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
+
+ self.assert_difference(
+ StateFilter.all(),
+ StateFilter.none(),
+ StateFilter.all(),
+ )
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 9f483ad681..be3ed64f5e 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -256,7 +256,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
self.assertEqual(users, {u1, u2, u3})
- # The next three tests (test_population_excludes_*) all set up
+ # The next four tests (test_population_excludes_*) all set up
# - A normal user included in the user dir
# - A public and private room created by that user
# - A user excluded from the room dir, belonging to both rooms
@@ -364,6 +364,21 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Check the AS user is not in the directory.
self._check_room_sharing_tables(user, public, private)
+ def test_population_excludes_appservice_sender(self) -> None:
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+
+ # Join the AS sender to rooms owned by the normal user.
+ public, private = self._create_rooms_and_inject_memberships(
+ user, token, self.appservice.sender
+ )
+
+ # Rebuild the directory.
+ self._purge_and_rebuild_user_dir()
+
+ # Check the AS sender is not in the directory.
+ self._check_room_sharing_tables(user, public, private)
+
def test_population_conceals_private_nickname(self) -> None:
# Make a private room, and set a nickname within
user = self.register_user("aaaa", "pass")
|