diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index dada4f98c9..0e4013ebea 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -147,6 +147,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
+ def test_get_last_client_ip_by_device(self, after_persisting: bool):
+ """Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
+ self.reactor.advance(12345678)
+
+ user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(
+ self.store.store_device(
+ user_id,
+ device_id,
+ "display name",
+ )
+ )
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", device_id
+ )
+ )
+
+ if after_persisting:
+ # Trigger the storage loop
+ self.reactor.advance(10)
+
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, device_id)
+ )
+
+ self.assertEqual(
+ result,
+ {
+ (user_id, device_id): {
+ "user_id": user_id,
+ "device_id": device_id,
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 12345678000,
+ },
+ },
+ )
+
+ @parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool):
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
self.reactor.advance(12345678)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 32060f2abd..70d52b088c 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -21,7 +21,7 @@ from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, TestCase
logger = logging.getLogger(__name__)
@@ -105,7 +105,6 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
def test_get_state_for_event(self):
-
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -483,3 +482,513 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+
+class StateFilterDifferenceTestCase(TestCase):
+ def assert_difference(
+ self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
+ ):
+ self.assertEqual(
+ minuend.approx_difference(subtrahend),
+ expected,
+ f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
+ )
+
+ def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b do not have the
+ include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Create: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.CanonicalAlias: {""}},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_no_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only a has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Create: None,
+ EventTypes.Member: set(),
+ EventTypes.CanonicalAlias: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ # This also shows that the resultant state filter is normalised.
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter(types=frozendict(), include_others=True),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b have the include_others
+ flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Create: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_no_include_other_minus_include_other(self):
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only b has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=True,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_simple_cases(self):
+ """
+ Tests some very simple cases of the StateFilter approx_difference,
+ that are not explicitly tested by the more in-depth tests.
+ """
+
+ self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
+
+ self.assert_difference(
+ StateFilter.all(),
+ StateFilter.none(),
+ StateFilter.all(),
+ )
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 9f483ad681..be3ed64f5e 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -256,7 +256,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
self.assertEqual(users, {u1, u2, u3})
- # The next three tests (test_population_excludes_*) all set up
+ # The next four tests (test_population_excludes_*) all set up
# - A normal user included in the user dir
# - A public and private room created by that user
# - A user excluded from the room dir, belonging to both rooms
@@ -364,6 +364,21 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Check the AS user is not in the directory.
self._check_room_sharing_tables(user, public, private)
+ def test_population_excludes_appservice_sender(self) -> None:
+ user = self.register_user("user", "pass")
+ token = self.login(user, "pass")
+
+ # Join the AS sender to rooms owned by the normal user.
+ public, private = self._create_rooms_and_inject_memberships(
+ user, token, self.appservice.sender
+ )
+
+ # Rebuild the directory.
+ self._purge_and_rebuild_user_dir()
+
+ # Check the AS sender is not in the directory.
+ self._check_room_sharing_tables(user, public, private)
+
def test_population_conceals_private_nickname(self) -> None:
# Make a private room, and set a nickname within
user = self.register_user("aaaa", "pass")
|