summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_appservice.py2
-rw-r--r--tests/storage/test_background_update.py4
-rw-r--r--tests/storage/test_base.py5
-rw-r--r--tests/storage/test_cleanup_extrems.py248
-rw-r--r--tests/storage/test_client_ips.py8
-rw-r--r--tests/storage/test_end_to_end_keys.py1
-rw-r--r--tests/storage/test_keys.py70
-rw-r--r--tests/storage/test_monthly_active_users.py17
-rw-r--r--tests/storage/test_redaction.py6
-rw-r--r--tests/storage/test_registration.py2
-rw-r--r--tests/storage/test_roommember.py2
-rw-r--r--tests/storage/test_state.py83
-rw-r--r--tests/storage/test_user_directory.py4
13 files changed, 350 insertions, 102 deletions
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 3f0083831b..25a6c89ef5 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -340,7 +340,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
 
         # we aren't testing store._base stuff here, so mock this out
-        self.store._get_events = Mock(return_value=events)
+        self.store.get_events_as_list = Mock(return_value=events)
 
         yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
         yield self._insert_txn(service.id, 10, events)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 5568a607c7..fbb9302694 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -9,9 +9,7 @@ from tests.utils import setup_test_homeserver
 class BackgroundUpdateTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield setup_test_homeserver(
-            self.addCleanup
-        )
+        hs = yield setup_test_homeserver(self.addCleanup)
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
 
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index f18db8c384..c778de1f0c 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -56,10 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
         hs = TestHomeServer(
-            "test",
-            db_pool=self.db_pool,
-            config=config,
-            database_engine=fake_engine,
+            "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
         )
 
         self.datastore = SQLBaseStore(None, hs)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
new file mode 100644
index 0000000000..6dda66ecd3
--- /dev/null
+++ b/tests/storage/test_cleanup_extrems.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path
+
+from synapse.api.constants import EventTypes
+from synapse.storage import prepare_database
+from synapse.types import Requester, UserID
+
+from tests.unittest import HomeserverTestCase
+
+
+class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
+    """Test the background update to clean forward extremities table.
+    """
+
+    def prepare(self, reactor, clock, homeserver):
+        self.store = homeserver.get_datastore()
+        self.event_creator = homeserver.get_event_creation_handler()
+        self.room_creator = homeserver.get_room_creation_handler()
+
+        # Create a test user and room
+        self.user = UserID("alice", "test")
+        self.requester = Requester(self.user, None, False, None, None)
+        info = self.get_success(self.room_creator.create_room(self.requester, {}))
+        self.room_id = info["room_id"]
+
+    def create_and_send_event(self, soft_failed=False, prev_event_ids=None):
+        """Create and send an event.
+
+        Args:
+            soft_failed (bool): Whether to create a soft failed event or not
+            prev_event_ids (list[str]|None): Explicitly set the prev events,
+                or if None just use the default
+
+        Returns:
+            str: The new event's ID.
+        """
+        prev_events_and_hashes = None
+        if prev_event_ids:
+            prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
+
+        event, context = self.get_success(
+            self.event_creator.create_event(
+                self.requester,
+                {
+                    "type": EventTypes.Message,
+                    "room_id": self.room_id,
+                    "sender": self.user.to_string(),
+                    "content": {"body": "", "msgtype": "m.text"},
+                },
+                prev_events_and_hashes=prev_events_and_hashes,
+            )
+        )
+
+        if soft_failed:
+            event.internal_metadata.soft_failed = True
+
+        self.get_success(
+            self.event_creator.send_nonmember_event(self.requester, event, context)
+        )
+
+        return event.event_id
+
+    def add_extremity(self, event_id):
+        """Add the given event as an extremity to the room.
+        """
+        self.get_success(
+            self.store._simple_insert(
+                table="event_forward_extremities",
+                values={"room_id": self.room_id, "event_id": event_id},
+                desc="test_add_extremity",
+            )
+        )
+
+        self.store.get_latest_event_ids_in_room.invalidate((self.room_id,))
+
+    def run_background_update(self):
+        """Re run the background update to clean up the extremities.
+        """
+        # Make sure we don't clash with in progress updates.
+        self.assertTrue(self.store._all_done, "Background updates are still ongoing")
+
+        schema_path = os.path.join(
+            prepare_database.dir_path,
+            "schema",
+            "delta",
+            "54",
+            "delete_forward_extremities.sql",
+        )
+
+        def run_delta_file(txn):
+            prepare_database.executescript(txn, schema_path)
+
+        self.get_success(
+            self.store.runInteraction("test_delete_forward_extremities", run_delta_file)
+        )
+
+        # Ugh, have to reset this flag
+        self.store._all_done = False
+
+        while not self.get_success(self.store.has_completed_background_updates()):
+            self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+    def test_soft_failed_extremities_handled_correctly(self):
+        """Test that extremities are correctly calculated in the presence of
+        soft failed events.
+
+        Tests a graph like:
+
+            A <- SF1 <- SF2 <- B
+
+        Where SF* are soft failed.
+        """
+
+        # Create the room graph
+        event_id_1 = self.create_and_send_event()
+        event_id_2 = self.create_and_send_event(True, [event_id_1])
+        event_id_3 = self.create_and_send_event(True, [event_id_2])
+        event_id_4 = self.create_and_send_event(False, [event_id_3])
+
+        # Check the latest events are as expected
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+
+        self.assertEqual(latest_event_ids, [event_id_4])
+
+    def test_basic_cleanup(self):
+        """Test that extremities are correctly calculated in the presence of
+        soft failed events.
+
+        Tests a graph like:
+
+            A <- SF1 <- B
+
+        Where SF* are soft failed, and with extremities of A and B
+        """
+        # Create the room graph
+        event_id_a = self.create_and_send_event()
+        event_id_sf1 = self.create_and_send_event(True, [event_id_a])
+        event_id_b = self.create_and_send_event(False, [event_id_sf1])
+
+        # Add the new extremity and check the latest events are as expected
+        self.add_extremity(event_id_a)
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+
+        # Run the background update and check it did the right thing
+        self.run_background_update()
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertEqual(latest_event_ids, [event_id_b])
+
+    def test_chain_of_fail_cleanup(self):
+        """Test that extremities are correctly calculated in the presence of
+        soft failed events.
+
+        Tests a graph like:
+
+            A <- SF1 <- SF2 <- B
+
+        Where SF* are soft failed, and with extremities of A and B
+        """
+        # Create the room graph
+        event_id_a = self.create_and_send_event()
+        event_id_sf1 = self.create_and_send_event(True, [event_id_a])
+        event_id_sf2 = self.create_and_send_event(True, [event_id_sf1])
+        event_id_b = self.create_and_send_event(False, [event_id_sf2])
+
+        # Add the new extremity and check the latest events are as expected
+        self.add_extremity(event_id_a)
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+
+        # Run the background update and check it did the right thing
+        self.run_background_update()
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertEqual(latest_event_ids, [event_id_b])
+
+    def test_forked_graph_cleanup(self):
+        r"""Test that extremities are correctly calculated in the presence of
+        soft failed events.
+
+        Tests a graph like, where time flows down the page:
+
+                A     B
+               / \   /
+              /   \ /
+            SF1   SF2
+             |     |
+            SF3    |
+           /   \   |
+           |    \  |
+           C     SF4
+
+        Where SF* are soft failed, and with them A, B and C marked as
+        extremities. This should resolve to B and C being marked as extremity.
+        """
+        # Create the room graph
+        event_id_a = self.create_and_send_event()
+        event_id_b = self.create_and_send_event()
+        event_id_sf1 = self.create_and_send_event(True, [event_id_a])
+        event_id_sf2 = self.create_and_send_event(True, [event_id_a, event_id_b])
+        event_id_sf3 = self.create_and_send_event(True, [event_id_sf1])
+        self.create_and_send_event(True, [event_id_sf2, event_id_sf3])  # SF4
+        event_id_c = self.create_and_send_event(False, [event_id_sf3])
+
+        # Add the new extremity and check the latest events are as expected
+        self.add_extremity(event_id_a)
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertEqual(
+            set(latest_event_ids), set((event_id_a, event_id_b, event_id_c))
+        )
+
+        # Run the background update and check it did the right thing
+        self.run_background_update()
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c]))
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 858efe4992..b62eae7abc 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -18,8 +18,9 @@ from mock import Mock
 
 from twisted.internet import defer
 
+import synapse.rest.admin
 from synapse.http.site import XForwardedForRequest
-from synapse.rest.client.v1 import admin, login
+from synapse.rest.client.v1 import login
 
 from tests import unittest
 
@@ -205,7 +206,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
 class ClientIpAuthTestCase(unittest.HomeserverTestCase):
 
-    servlets = [admin.register_servlets, login.register_servlets]
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+    ]
 
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver()
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 11fb8c0c19..cd2bcd4ca3 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -20,7 +20,6 @@ import tests.utils
 
 
 class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 6bfaa00fe9..e07ff01201 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -17,6 +17,8 @@ import signedjson.key
 
 from twisted.internet.defer import Deferred
 
+from synapse.storage.keys import FetchKeyResult
+
 import tests.unittest
 
 KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -31,23 +33,34 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
     def test_get_server_verify_keys(self):
         store = self.hs.get_datastore()
 
-        d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
-        self.get_success(d)
-        d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
+        key_id_1 = "ed25519:key1"
+        key_id_2 = "ed25519:KEY_ID_2"
+        d = store.store_server_verify_keys(
+            "from_server",
+            10,
+            [
+                ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+            ],
+        )
         self.get_success(d)
 
         d = store.get_server_verify_keys(
-            [
-                ("server1", "ed25519:key1"),
-                ("server1", "ed25519:key2"),
-                ("server1", "ed25519:key3"),
-            ]
+            [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
         )
         res = self.get_success(d)
 
         self.assertEqual(len(res.keys()), 3)
-        self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
-        self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
+        res1 = res[("server1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.verify_key.version, "key1")
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("server1", key_id_2)]
+        self.assertEqual(res2.verify_key, KEY_2)
+        # version comes from the ID it was stored with
+        self.assertEqual(res2.verify_key.version, "KEY_ID_2")
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # non-existent result gives None
         self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -60,32 +73,51 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         key_id_1 = "ed25519:key1"
         key_id_2 = "ed25519:key2"
 
-        d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
-        self.get_success(d)
-        d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
+        d = store.store_server_verify_keys(
+            "from_server",
+            0,
+            [
+                ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+            ],
+        )
         self.get_success(d)
 
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, KEY_2)
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # we should be able to look up the same thing again without a db hit
         res = store.get_server_verify_keys([("srv1", key_id_1)])
         if isinstance(res, Deferred):
             res = self.successResultOf(res)
         self.assertEqual(len(res.keys()), 1)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+        self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
 
         new_key_2 = signedjson.key.get_verify_key(
             signedjson.key.generate_signing_key("key2")
         )
-        d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
+        d = store.store_server_verify_keys(
+            "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
+        )
         self.get_success(d)
 
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], new_key_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, new_key_2)
+        self.assertEqual(res2.valid_until_ts, 300)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index d6569a82bb..f458c03054 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -56,8 +56,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.register(user_id=user1, token="123", password_hash=None)
         self.store.register(user_id=user2, token="456", password_hash=None)
         self.store.register(
-            user_id=user3, token="789",
-            password_hash=None, user_type=UserTypes.SUPPORT
+            user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT
         )
         self.pump()
 
@@ -173,9 +172,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
     def test_populate_monthly_users_should_update(self):
         self.store.upsert_monthly_active_user = Mock()
 
-        self.store.is_trial_user = Mock(
-            return_value=defer.succeed(False)
-        )
+        self.store.is_trial_user = Mock(return_value=defer.succeed(False))
 
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(None)
@@ -187,13 +184,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
     def test_populate_monthly_users_should_not_update(self):
         self.store.upsert_monthly_active_user = Mock()
 
-        self.store.is_trial_user = Mock(
-            return_value=defer.succeed(False)
-        )
+        self.store.is_trial_user = Mock(return_value=defer.succeed(False))
         self.store.user_last_seen_monthly_active = Mock(
-            return_value=defer.succeed(
-                self.hs.get_clock().time_msec()
-            )
+            return_value=defer.succeed(self.hs.get_clock().time_msec())
         )
         self.store.populate_monthly_active_users('user_id')
         self.pump()
@@ -243,7 +236,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
             user_id=support_user_id,
             token="123",
             password_hash=None,
-            user_type=UserTypes.SUPPORT
+            user_type=UserTypes.SUPPORT,
         )
 
         self.store.upsert_monthly_active_user(support_user_id)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 0fc5019e9f..4823d44dec 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -60,7 +60,7 @@ class RedactionTestCase(unittest.TestCase):
                 "state_key": user.to_string(),
                 "room_id": room.to_string(),
                 "content": content,
-            }
+            },
         )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
@@ -83,7 +83,7 @@ class RedactionTestCase(unittest.TestCase):
                 "state_key": user.to_string(),
                 "room_id": room.to_string(),
                 "content": {"body": body, "msgtype": u"message"},
-            }
+            },
         )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
@@ -105,7 +105,7 @@ class RedactionTestCase(unittest.TestCase):
                 "room_id": room.to_string(),
                 "content": {"reason": reason},
                 "redacts": event_id,
-            }
+            },
         )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index cb3cc4d2e5..c0e0155bb4 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -116,7 +116,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
             user_id=SUPPORT_USER,
             token="456",
             password_hash=None,
-            user_type=UserTypes.SUPPORT
+            user_type=UserTypes.SUPPORT,
         )
         res = yield self.store.is_support_user(SUPPORT_USER)
         self.assertTrue(res)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 063387863e..73ed943f5a 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -58,7 +58,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
                 "state_key": user.to_string(),
                 "room_id": room.to_string(),
                 "content": {"membership": membership},
-            }
+            },
         )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 78e260a7fa..b6169436de 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
 
 
 class StateStoreTestCase(tests.unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
@@ -57,7 +56,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
                 "state_key": state_key,
                 "room_id": room.to_string(),
                 "content": content,
-            }
+            },
         )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
@@ -83,15 +82,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
         )
 
-        state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id])
+        state_group_map = yield self.store.get_state_groups_ids(
+            self.room, [e2.event_id]
+        )
         self.assertEqual(len(state_group_map), 1)
         state_map = list(state_group_map.values())[0]
         self.assertDictEqual(
             state_map,
-            {
-                (EventTypes.Create, ''): e1.event_id,
-                (EventTypes.Name, ''): e2.event_id,
-            },
+            {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id},
         )
 
     @defer.inlineCallbacks
@@ -103,15 +101,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
         )
 
-        state_group_map = yield self.store.get_state_groups(
-            self.room, [e2.event_id])
+        state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
 
-        self.assertEqual(
-            {ev.event_id for ev in state_list},
-            {e1.event_id, e2.event_id},
-        )
+        self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
 
     @defer.inlineCallbacks
     def test_get_state_for_event(self):
@@ -147,9 +141,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield self.store.get_state_for_event(
-            e5.event_id,
-        )
+        state = yield self.store.get_state_for_event(e5.event_id)
 
         self.assertIsNotNone(e4)
 
@@ -194,7 +186,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_filter=StateFilter(
                 types={EventTypes.Member: {self.u_alice.to_string()}},
                 include_others=True,
-            )
+            ),
         )
 
         self.assertStateMapEqual(
@@ -208,9 +200,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check that we can grab everything except members
         state = yield self.store.get_state_for_event(
-            e5.event_id, state_filter=StateFilter(
-                types={EventTypes.Member: set()},
-                include_others=True,
+            e5.event_id,
+            state_filter=StateFilter(
+                types={EventTypes.Member: set()}, include_others=True
             ),
         )
 
@@ -229,10 +221,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
         (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache, group,
+            self.store._state_group_cache,
+            group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()},
-                include_others=True,
+                types={EventTypes.Member: set()}, include_others=True
             ),
         )
 
@@ -249,8 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()},
-                include_others=True,
+                types={EventTypes.Member: set()}, include_others=True
             ),
         )
 
@@ -263,8 +254,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None},
-                include_others=True,
+                types={EventTypes.Member: None}, include_others=True
             ),
         )
 
@@ -281,8 +271,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None},
-                include_others=True,
+                types={EventTypes.Member: None}, include_others=True
             ),
         )
 
@@ -302,8 +291,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}},
-                include_others=True,
+                types={EventTypes.Member: {e5.state_key}}, include_others=True
             ),
         )
 
@@ -320,8 +308,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}},
-                include_others=True,
+                types={EventTypes.Member: {e5.state_key}}, include_others=True
             ),
         )
 
@@ -334,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}},
-                include_others=False,
+                types={EventTypes.Member: {e5.state_key}}, include_others=False
             ),
         )
 
@@ -384,10 +370,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # with types=[]
         room_id = self.room.to_string()
         (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache, group,
+            self.store._state_group_cache,
+            group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()},
-                include_others=True,
+                types={EventTypes.Member: set()}, include_others=True
             ),
         )
 
@@ -399,8 +385,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()},
-                include_others=True,
+                types={EventTypes.Member: set()}, include_others=True
             ),
         )
 
@@ -413,8 +398,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None},
-                include_others=True,
+                types={EventTypes.Member: None}, include_others=True
             ),
         )
 
@@ -425,8 +409,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None},
-                include_others=True,
+                types={EventTypes.Member: None}, include_others=True
             ),
         )
 
@@ -445,8 +428,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}},
-                include_others=True,
+                types={EventTypes.Member: {e5.state_key}}, include_others=True
             ),
         )
 
@@ -457,8 +439,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}},
-                include_others=True,
+                types={EventTypes.Member: {e5.state_key}}, include_others=True
             ),
         )
 
@@ -471,8 +452,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}},
-                include_others=False,
+                types={EventTypes.Member: {e5.state_key}}, include_others=False
             ),
         )
 
@@ -483,8 +463,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.store._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}},
-                include_others=False,
+                types={EventTypes.Member: {e5.state_key}}, include_others=False
             ),
         )
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index fd3361404f..d7d244ce97 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -36,9 +36,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
         yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
         yield self.store.update_profile_in_user_dir(BOB, "bob", None)
         yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
-        yield self.store.add_users_in_public_rooms(
-            "!room:id", (ALICE, BOB)
-        )
+        yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
 
     @defer.inlineCallbacks
     def test_search_user_dir(self):