summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2021-10-13 10:38:22 +0100
committerGitHub <noreply@github.com>2021-10-13 09:38:22 +0000
commitb83e82255630f2ba7ea959b68e5a82b4d938f000 (patch)
tree8e5ee3373e3686a21b3ed7756639af8e553a5195 /tests/storage
parentFix formatting string when oEmbed errors occur. (#11061) (diff)
downloadsynapse-b83e82255630f2ba7ea959b68e5a82b4d938f000.tar.xz
Stop user directory from failing if it encounters users not in the `users` table. (#11053)
The following scenarios would halt the user directory updater:

- user joins room
- user leaves room
- user present in room which switches from private to public, or vice versa.

for two classes of users:

- appservice senders
- users missing from the user table.

If this happened, the user directory would be stuck, unable to make forward progress.

Exclude both cases from the user directory, so that we ignore them.

Co-authored-by: Eric Eastwood <erice@element.io>
Co-authored-by: reivilibre <oliverw@matrix.org>
Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_client_ips.py43
-rw-r--r--tests/storage/test_state.py513
-rw-r--r--tests/storage/test_user_directory.py17
3 files changed, 570 insertions, 3 deletions
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index dada4f98c9..0e4013ebea 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -147,6 +147,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
     @parameterized.expand([(False,), (True,)])
+    def test_get_last_client_ip_by_device(self, after_persisting: bool):
+        """Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
+        self.reactor.advance(12345678)
+
+        user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id,
+                "display name",
+            )
+        )
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", device_id
+            )
+        )
+
+        if after_persisting:
+            # Trigger the storage loop
+            self.reactor.advance(10)
+
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(user_id, device_id)
+        )
+
+        self.assertEqual(
+            result,
+            {
+                (user_id, device_id): {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "ip": "ip",
+                    "user_agent": "user_agent",
+                    "last_seen": 12345678000,
+                },
+            },
+        )
+
+    @parameterized.expand([(False,), (True,)])
     def test_get_user_ip_and_agents(self, after_persisting: bool):
         """Test `get_user_ip_and_agents` for persisted and unpersisted data"""
         self.reactor.advance(12345678)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 32060f2abd..70d52b088c 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -21,7 +21,7 @@ from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
 from synapse.types import RoomID, UserID
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, TestCase
 
 logger = logging.getLogger(__name__)
 
@@ -105,7 +105,6 @@ class StateStoreTestCase(HomeserverTestCase):
         self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
 
     def test_get_state_for_event(self):
-
         # this defaults to a linear DAG as each new injection defaults to whatever
         # forward extremities are currently in the DB for this room.
         e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -483,3 +482,513 @@ class StateStoreTestCase(HomeserverTestCase):
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+
+class StateFilterDifferenceTestCase(TestCase):
+    def assert_difference(
+        self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
+    ):
+        self.assertEqual(
+            minuend.approx_difference(subtrahend),
+            expected,
+            f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
+        )
+
+    def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), both a and b do not have the
+        include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+            StateFilter.freeze({EventTypes.Create: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:spqr"}},
+                include_others=False,
+            ),
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.CanonicalAlias: {""}},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_include_other_minus_no_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), only a has the include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Create: None,
+                    EventTypes.Member: set(),
+                    EventTypes.CanonicalAlias: set(),
+                },
+                include_others=True,
+            ),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        # This also shows that the resultant state filter is normalised.
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.Create: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter(types=frozendict(), include_others=True),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.CanonicalAlias: {""},
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+        )
+
+    def test_state_filter_difference_include_other_minus_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), both a and b have the include_others
+        flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=True,
+            ),
+            StateFilter(types=frozendict(), include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                    EventTypes.Create: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.Create: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.Create: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_no_include_other_minus_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), only b has the include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=True,
+            ),
+            StateFilter(types=frozendict(), include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:spqr"}},
+                include_others=True,
+            ),
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_simple_cases(self):
+        """
+        Tests some very simple cases of the StateFilter approx_difference,
+        that are not explicitly tested by the more in-depth tests.
+        """
+
+        self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
+
+        self.assert_difference(
+            StateFilter.all(),
+            StateFilter.none(),
+            StateFilter.all(),
+        )
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 9f483ad681..be3ed64f5e 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -256,7 +256,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
         self.assertEqual(users, {u1, u2, u3})
 
-    # The next three tests (test_population_excludes_*) all set up
+    # The next four tests (test_population_excludes_*) all set up
     #   - A normal user included in the user dir
     #   - A public and private room created by that user
     #   - A user excluded from the room dir, belonging to both rooms
@@ -364,6 +364,21 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         # Check the AS user is not in the directory.
         self._check_room_sharing_tables(user, public, private)
 
+    def test_population_excludes_appservice_sender(self) -> None:
+        user = self.register_user("user", "pass")
+        token = self.login(user, "pass")
+
+        # Join the AS sender to rooms owned by the normal user.
+        public, private = self._create_rooms_and_inject_memberships(
+            user, token, self.appservice.sender
+        )
+
+        # Rebuild the directory.
+        self._purge_and_rebuild_user_dir()
+
+        # Check the AS sender is not in the directory.
+        self._check_room_sharing_tables(user, public, private)
+
     def test_population_conceals_private_nickname(self) -> None:
         # Make a private room, and set a nickname within
         user = self.register_user("aaaa", "pass")