summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test__base.py7
-rw-r--r--tests/storage/test__init__.py65
-rw-r--r--tests/storage/test_appservice.py175
-rw-r--r--tests/storage/test_background_update.py21
-rw-r--r--tests/storage/test_base.py46
-rw-r--r--tests/storage/test_client_ips.py71
-rw-r--r--tests/storage/test_devices.py71
-rw-r--r--tests/storage/test_directory.py24
-rw-r--r--tests/storage/test_end_to_end_keys.py64
-rw-r--r--tests/storage/test_event_federation.py41
-rw-r--r--tests/storage/test_event_push_actions.py80
-rw-r--r--tests/storage/test_keys.py13
-rw-r--r--tests/storage/test_monthly_active_users.py131
-rw-r--r--tests/storage/test_presence.py71
-rw-r--r--tests/storage/test_profile.py20
-rw-r--r--tests/storage/test_redaction.py74
-rw-r--r--tests/storage/test_registration.py41
-rw-r--r--tests/storage/test_room.py51
-rw-r--r--tests/storage/test_roommember.py35
-rw-r--r--tests/storage/test_state.py342
-rw-r--r--tests/storage/test_user_directory.py42
21 files changed, 793 insertions, 692 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 6d6f00c5c5..52eb05bfbf 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -18,14 +18,13 @@ from mock import Mock
 
 from twisted.internet import defer
 
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import Cache, cached
 
 from tests import unittest
 
 
 class CacheTestCase(unittest.TestCase):
-
     def setUp(self):
         self.cache = Cache("test")
 
@@ -97,7 +96,6 @@ class CacheTestCase(unittest.TestCase):
 
 
 class CacheDecoratorTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def test_passthrough(self):
         class A(object):
@@ -180,8 +178,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
             yield a.func(k)
 
         self.assertTrue(
-            callcount[0] >= 14,
-            msg="Expected callcount >= 14, got %d" % (callcount[0])
+            callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
         )
 
     def test_prefill(self):
diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py
deleted file mode 100644
index f19cb1265c..0000000000
--- a/tests/storage/test__init__.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-import tests.utils
-
-
-class InitTestCase(tests.unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super(InitTestCase, self).__init__(*args, **kwargs)
-        self.store = None  # type: synapse.storage.DataStore
-
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver()
-
-        hs.config.max_mau_value = 50
-        hs.config.limit_usage_by_mau = True
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
-
-    @defer.inlineCallbacks
-    def test_count_monthly_users(self):
-        count = yield self.store.count_monthly_users()
-        self.assertEqual(0, count)
-
-        yield self._insert_user_ips("@user:server1")
-        yield self._insert_user_ips("@user:server2")
-
-        count = yield self.store.count_monthly_users()
-        self.assertEqual(2, count)
-
-    @defer.inlineCallbacks
-    def _insert_user_ips(self, user):
-        """
-        Helper function to populate user_ips without using batch insertion infra
-        args:
-            user (str):  specify username i.e. @user:server.com
-        """
-        yield self.store._simple_upsert(
-            table="user_ips",
-            keyvalues={
-                "user_id": user,
-                "access_token": "access_token",
-                "ip": "ip",
-                "user_agent": "user_agent",
-                "device_id": "device_id",
-            },
-            values={
-                "last_seen": self.clock.time_msec(),
-            }
-        )
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 099861b27c..c893990454 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -34,7 +34,6 @@ from tests.utils import setup_test_homeserver
 
 
 class ApplicationServiceStoreTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
         self.as_yaml_files = []
@@ -44,6 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
             password_providers=[],
         )
         hs = yield setup_test_homeserver(
+            self.addCleanup,
             config=config,
             federation_sender=Mock(),
             federation_client=Mock(),
@@ -53,11 +53,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self.as_url = "some_url"
         self.as_id = "as1"
         self._add_appservice(
-            self.as_token,
-            self.as_id,
-            self.as_url,
-            "some_hs_token",
-            "bob"
+            self.as_token, self.as_id, self.as_url, "some_hs_token", "bob"
         )
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
@@ -73,8 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
                 pass
 
     def _add_appservice(self, as_token, id, url, hs_token, sender):
-        as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token,
-                       id=id, sender_localpart=sender, namespaces={})
+        as_yaml = dict(
+            url=url,
+            as_token=as_token,
+            hs_token=hs_token,
+            id=id,
+            sender_localpart=sender,
+            namespaces={},
+        )
         # use the token as the filename
         with open(as_token, 'w') as outfile:
             outfile.write(yaml.dump(as_yaml))
@@ -85,24 +87,13 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self.assertEquals(service, None)
 
     def test_retrieval_of_service(self):
-        stored_service = self.store.get_app_service_by_token(
-            self.as_token
-        )
+        stored_service = self.store.get_app_service_by_token(self.as_token)
         self.assertEquals(stored_service.token, self.as_token)
         self.assertEquals(stored_service.id, self.as_id)
         self.assertEquals(stored_service.url, self.as_url)
-        self.assertEquals(
-            stored_service.namespaces[ApplicationService.NS_ALIASES],
-            []
-        )
-        self.assertEquals(
-            stored_service.namespaces[ApplicationService.NS_ROOMS],
-            []
-        )
-        self.assertEquals(
-            stored_service.namespaces[ApplicationService.NS_USERS],
-            []
-        )
+        self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], [])
+        self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], [])
+        self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], [])
 
     def test_retrieval_of_all_services(self):
         services = self.store.get_app_services()
@@ -110,7 +101,6 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
 
 
 class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
         self.as_yaml_files = []
@@ -121,6 +111,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
             password_providers=[],
         )
         hs = yield setup_test_homeserver(
+            self.addCleanup,
             config=config,
             federation_sender=Mock(),
             federation_client=Mock(),
@@ -128,26 +119,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.db_pool = hs.get_db_pool()
 
         self.as_list = [
-            {
-                "token": "token1",
-                "url": "https://matrix-as.org",
-                "id": "id_1"
-            },
-            {
-                "token": "alpha_tok",
-                "url": "https://alpha.com",
-                "id": "id_alpha"
-            },
-            {
-                "token": "beta_tok",
-                "url": "https://beta.com",
-                "id": "id_beta"
-            },
-            {
-                "token": "gamma_tok",
-                "url": "https://gamma.com",
-                "id": "id_gamma"
-            },
+            {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
+            {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
+            {"token": "beta_tok", "url": "https://beta.com", "id": "id_beta"},
+            {"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"},
         ]
         for s in self.as_list:
             yield self._add_service(s["url"], s["token"], s["id"])
@@ -157,8 +132,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.store = TestTransactionStore(None, hs)
 
     def _add_service(self, url, as_token, id):
-        as_yaml = dict(url=url, as_token=as_token, hs_token="something",
-                       id=id, sender_localpart="a_sender", namespaces={})
+        as_yaml = dict(
+            url=url,
+            as_token=as_token,
+            hs_token="something",
+            id=id,
+            sender_localpart="a_sender",
+            namespaces={},
+        )
         # use the token as the filename
         with open(as_token, 'w') as outfile:
             outfile.write(yaml.dump(as_yaml))
@@ -168,21 +149,21 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         return self.db_pool.runQuery(
             "INSERT INTO application_services_state(as_id, state, last_txn) "
             "VALUES(?,?,?)",
-            (id, state, txn)
+            (id, state, txn),
         )
 
     def _insert_txn(self, as_id, txn_id, events):
         return self.db_pool.runQuery(
             "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
             "VALUES(?,?,?)",
-            (as_id, txn_id, json.dumps([e.event_id for e in events]))
+            (as_id, txn_id, json.dumps([e.event_id for e in events])),
         )
 
     def _set_last_txn(self, as_id, txn_id):
         return self.db_pool.runQuery(
             "INSERT INTO application_services_state(as_id, last_txn, state) "
             "VALUES(?,?,?)",
-            (as_id, txn_id, ApplicationServiceState.UP)
+            (as_id, txn_id, ApplicationServiceState.UP),
         )
 
     @defer.inlineCallbacks
@@ -193,24 +174,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_appservice_state_up(self):
-        yield self._set_state(
-            self.as_list[0]["id"], ApplicationServiceState.UP
-        )
+        yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
         service = Mock(id=self.as_list[0]["id"])
         state = yield self.store.get_appservice_state(service)
         self.assertEquals(ApplicationServiceState.UP, state)
 
     @defer.inlineCallbacks
     def test_get_appservice_state_down(self):
-        yield self._set_state(
-            self.as_list[0]["id"], ApplicationServiceState.UP
-        )
-        yield self._set_state(
-            self.as_list[1]["id"], ApplicationServiceState.DOWN
-        )
-        yield self._set_state(
-            self.as_list[2]["id"], ApplicationServiceState.DOWN
-        )
+        yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
+        yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
+        yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
         service = Mock(id=self.as_list[1]["id"])
         state = yield self.store.get_appservice_state(service)
         self.assertEquals(ApplicationServiceState.DOWN, state)
@@ -225,34 +198,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_set_appservices_state_down(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(
-            service,
-            ApplicationServiceState.DOWN
-        )
+        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
         rows = yield self.db_pool.runQuery(
             "SELECT as_id FROM application_services_state WHERE state=?",
-            (ApplicationServiceState.DOWN,)
+            (ApplicationServiceState.DOWN,),
         )
         self.assertEquals(service.id, rows[0][0])
 
     @defer.inlineCallbacks
     def test_set_appservices_state_multiple_up(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(
-            service,
-            ApplicationServiceState.UP
-        )
-        yield self.store.set_appservice_state(
-            service,
-            ApplicationServiceState.DOWN
-        )
-        yield self.store.set_appservice_state(
-            service,
-            ApplicationServiceState.UP
-        )
+        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
         rows = yield self.db_pool.runQuery(
             "SELECT as_id FROM application_services_state WHERE state=?",
-            (ApplicationServiceState.UP,)
+            (ApplicationServiceState.UP,),
         )
         self.assertEquals(service.id, rows[0][0])
 
@@ -319,14 +280,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         res = yield self.db_pool.runQuery(
             "SELECT last_txn FROM application_services_state WHERE as_id=?",
-            (service.id,)
+            (service.id,),
         )
         self.assertEquals(1, len(res))
         self.assertEquals(txn_id, res[0][0])
 
         res = yield self.db_pool.runQuery(
-            "SELECT * FROM application_services_txns WHERE txn_id=?",
-            (txn_id,)
+            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
         )
         self.assertEquals(0, len(res))
 
@@ -340,17 +300,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
 
         res = yield self.db_pool.runQuery(
-            "SELECT last_txn, state FROM application_services_state WHERE "
-            "as_id=?",
-            (service.id,)
+            "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?",
+            (service.id,),
         )
         self.assertEquals(1, len(res))
         self.assertEquals(txn_id, res[0][0])
         self.assertEquals(ApplicationServiceState.UP, res[0][1])
 
         res = yield self.db_pool.runQuery(
-            "SELECT * FROM application_services_txns WHERE txn_id=?",
-            (txn_id,)
+            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
         )
         self.assertEquals(0, len(res))
 
@@ -382,12 +340,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_appservices_by_state_single(self):
-        yield self._set_state(
-            self.as_list[0]["id"], ApplicationServiceState.DOWN
-        )
-        yield self._set_state(
-            self.as_list[1]["id"], ApplicationServiceState.UP
-        )
+        yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
+        yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
 
         services = yield self.store.get_appservices_by_state(
             ApplicationServiceState.DOWN
@@ -397,18 +351,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_appservices_by_state_multiple(self):
-        yield self._set_state(
-            self.as_list[0]["id"], ApplicationServiceState.DOWN
-        )
-        yield self._set_state(
-            self.as_list[1]["id"], ApplicationServiceState.UP
-        )
-        yield self._set_state(
-            self.as_list[2]["id"], ApplicationServiceState.DOWN
-        )
-        yield self._set_state(
-            self.as_list[3]["id"], ApplicationServiceState.UP
-        )
+        yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
+        yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
+        yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
+        yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
 
         services = yield self.store.get_appservices_by_state(
             ApplicationServiceState.DOWN
@@ -416,20 +362,17 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.assertEquals(2, len(services))
         self.assertEquals(
             set([self.as_list[2]["id"], self.as_list[0]["id"]]),
-            set([services[0].id, services[1].id])
+            set([services[0].id, services[1].id]),
         )
 
 
 # required for ApplicationServiceTransactionStoreTestCase tests
-class TestTransactionStore(ApplicationServiceTransactionStore,
-                           ApplicationServiceStore):
-
+class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
     def __init__(self, db_conn, hs):
         super(TestTransactionStore, self).__init__(db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
-
     def _write_config(self, suffix, **kwargs):
         vals = {
             "id": "id" + suffix,
@@ -452,10 +395,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f2 = self._write_config(suffix="2")
 
         config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1,
-            password_providers=[]
+            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
         )
         hs = yield setup_test_homeserver(
+            self.addCleanup,
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
@@ -470,10 +413,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f2 = self._write_config(id="id", suffix="2")
 
         config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1,
-            password_providers=[]
+            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
         )
         hs = yield setup_test_homeserver(
+            self.addCleanup,
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
@@ -494,10 +437,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f2 = self._write_config(as_token="as_token", suffix="2")
 
         config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1,
-            password_providers=[]
+            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
         )
         hs = yield setup_test_homeserver(
+            self.addCleanup,
             config=config,
             datastore=Mock(),
             federation_sender=Mock(),
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index ab1f310572..81403727c5 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -7,10 +7,11 @@ from tests.utils import setup_test_homeserver
 
 
 class BackgroundUpdateTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield setup_test_homeserver()  # type: synapse.server.HomeServer
+        hs = yield setup_test_homeserver(
+            self.addCleanup
+        )  # type: synapse.server.HomeServer
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
 
@@ -51,9 +52,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
         yield self.store.start_background_update("test_update", {"my_key": 1})
 
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(
-            duration_ms * desired_count
-        )
+        result = yield self.store.do_next_background_update(duration_ms * desired_count)
         self.assertIsNotNone(result)
         self.update_handler.assert_called_once_with(
             {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
@@ -67,18 +66,12 @@ class BackgroundUpdateTestCase(unittest.TestCase):
 
         self.update_handler.side_effect = update
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(
-            duration_ms * desired_count
-        )
+        result = yield self.store.do_next_background_update(duration_ms * desired_count)
         self.assertIsNotNone(result)
-        self.update_handler.assert_called_once_with(
-            {"my_key": 2}, desired_count
-        )
+        self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
 
         # third step: we don't expect to be called any more
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(
-            duration_ms * desired_count
-        )
+        result = yield self.store.do_next_background_update(duration_ms * desired_count)
         self.assertIsNone(result)
         self.assertFalse(self.update_handler.called)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 1d1234ee39..7cb5f0e4cf 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -40,10 +40,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
 
         def runInteraction(func, *args, **kwargs):
             return defer.succeed(func(self.mock_txn, *args, **kwargs))
+
         self.db_pool.runInteraction = runInteraction
 
         def runWithConnection(func, *args, **kwargs):
             return defer.succeed(func(self.mock_conn, *args, **kwargs))
+
         self.db_pool.runWithConnection = runWithConnection
 
         config = Mock()
@@ -63,8 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
 
         yield self.datastore._simple_insert(
-            table="tablename",
-            values={"columname": "Value"}
+            table="tablename", values={"columname": "Value"}
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -78,12 +79,11 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         yield self.datastore._simple_insert(
             table="tablename",
             # Use OrderedDict() so we can assert on the SQL generated
-            values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)])
+            values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
         )
 
         self.mock_txn.execute.assert_called_with(
-            "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)",
-            (1, 2, 3,)
+            "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", (1, 2, 3)
         )
 
     @defer.inlineCallbacks
@@ -92,9 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
 
         value = yield self.datastore._simple_select_one_onecol(
-            table="tablename",
-            keyvalues={"keycol": "TheKey"},
-            retcol="retcol"
+            table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
         )
 
         self.assertEquals("Value", value)
@@ -110,13 +108,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         ret = yield self.datastore._simple_select_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
-            retcols=["colA", "colB", "colC"]
+            retcols=["colA", "colB", "colC"],
         )
 
         self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
         self.mock_txn.execute.assert_called_with(
-            "SELECT colA, colB, colC FROM tablename WHERE keycol = ?",
-            ["TheKey"]
+            "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
         )
 
     @defer.inlineCallbacks
@@ -128,7 +125,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             table="tablename",
             keyvalues={"keycol": "Not here"},
             retcols=["colA"],
-            allow_none=True
+            allow_none=True,
         )
 
         self.assertFalse(ret)
@@ -137,20 +134,15 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_select_list(self):
         self.mock_txn.rowcount = 3
         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
-        self.mock_txn.description = (
-            ("colA", None, None, None, None, None, None),
-        )
+        self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 
         ret = yield self.datastore._simple_select_list(
-            table="tablename",
-            keyvalues={"keycol": "A set"},
-            retcols=["colA"],
+            table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
         )
 
         self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
         self.mock_txn.execute.assert_called_with(
-            "SELECT colA FROM tablename WHERE keycol = ?",
-            ["A set"]
+            "SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
         )
 
     @defer.inlineCallbacks
@@ -160,12 +152,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         yield self.datastore._simple_update_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
-            updatevalues={"columnname": "New Value"}
+            updatevalues={"columnname": "New Value"},
         )
 
         self.mock_txn.execute.assert_called_with(
             "UPDATE tablename SET columnname = ? WHERE keycol = ?",
-            ["New Value", "TheKey"]
+            ["New Value", "TheKey"],
         )
 
     @defer.inlineCallbacks
@@ -175,13 +167,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         yield self.datastore._simple_update_one(
             table="tablename",
             keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
-            updatevalues=OrderedDict([("colC", 3), ("colD", 4)])
+            updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
         )
 
         self.mock_txn.execute.assert_called_with(
-            "UPDATE tablename SET colC = ?, colD = ? WHERE"
-            " colA = ? AND colB = ?",
-            [3, 4, 1, 2]
+            "UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?",
+            [3, 4, 1, 2],
         )
 
     @defer.inlineCallbacks
@@ -189,8 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
 
         yield self.datastore._simple_delete_one(
-            table="tablename",
-            keyvalues={"keycol": "Go away"},
+            table="tablename", keyvalues={"keycol": "Go away"}
         )
 
         self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index bd6fda6cb1..c2e88bdbaf 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from mock import Mock
 
 from twisted.internet import defer
 
@@ -27,17 +28,16 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver()
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
+        self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+        self.store = self.hs.get_datastore()
+        self.clock = self.hs.get_clock()
 
     @defer.inlineCallbacks
     def test_insert_new_client_ip(self):
         self.clock.now = 12345678
         user_id = "@user:id"
         yield self.store.insert_client_ip(
-            user_id,
-            "access_token", "ip", "user_agent", "device_id",
+            user_id, "access_token", "ip", "user_agent", "device_id"
         )
 
         result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
@@ -52,5 +52,64 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
             },
-            r
+            r,
         )
+
+    @defer.inlineCallbacks
+    def test_disabled_monthly_active_user(self):
+        self.hs.config.limit_usage_by_mau = False
+        self.hs.config.max_mau_value = 50
+        user_id = "@user:server"
+        yield self.store.insert_client_ip(
+            user_id, "access_token", "ip", "user_agent", "device_id"
+        )
+        active = yield self.store.user_last_seen_monthly_active(user_id)
+        self.assertFalse(active)
+
+    @defer.inlineCallbacks
+    def test_adding_monthly_active_user_when_full(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.hs.config.max_mau_value = 50
+        lots_of_users = 100
+        user_id = "@user:server"
+
+        self.store.get_monthly_active_count = Mock(
+            return_value=defer.succeed(lots_of_users)
+        )
+        yield self.store.insert_client_ip(
+            user_id, "access_token", "ip", "user_agent", "device_id"
+        )
+        active = yield self.store.user_last_seen_monthly_active(user_id)
+        self.assertFalse(active)
+
+    @defer.inlineCallbacks
+    def test_adding_monthly_active_user_when_space(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.hs.config.max_mau_value = 50
+        user_id = "@user:server"
+        active = yield self.store.user_last_seen_monthly_active(user_id)
+        self.assertFalse(active)
+
+        yield self.store.insert_client_ip(
+            user_id, "access_token", "ip", "user_agent", "device_id"
+        )
+        active = yield self.store.user_last_seen_monthly_active(user_id)
+        self.assertTrue(active)
+
+    @defer.inlineCallbacks
+    def test_updating_monthly_active_user_when_space(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.hs.config.max_mau_value = 50
+        user_id = "@user:server"
+
+        active = yield self.store.user_last_seen_monthly_active(user_id)
+        self.assertFalse(active)
+
+        yield self.store.insert_client_ip(
+            user_id, "access_token", "ip", "user_agent", "device_id"
+        )
+        yield self.store.insert_client_ip(
+            user_id, "access_token", "ip", "user_agent", "device_id"
+        )
+        active = yield self.store.user_last_seen_monthly_active(user_id)
+        self.assertTrue(active)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index a54cc6bc32..aef4dfaf57 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -28,68 +28,64 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver()
+        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
 
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
     def test_store_new_device(self):
-        yield self.store.store_device(
-            "user_id", "device_id", "display_name"
-        )
+        yield self.store.store_device("user_id", "device_id", "display_name")
 
         res = yield self.store.get_device("user_id", "device_id")
-        self.assertDictContainsSubset({
-            "user_id": "user_id",
-            "device_id": "device_id",
-            "display_name": "display_name",
-        }, res)
+        self.assertDictContainsSubset(
+            {
+                "user_id": "user_id",
+                "device_id": "device_id",
+                "display_name": "display_name",
+            },
+            res,
+        )
 
     @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield self.store.store_device(
-            "user_id", "device1", "display_name 1"
-        )
-        yield self.store.store_device(
-            "user_id", "device2", "display_name 2"
-        )
-        yield self.store.store_device(
-            "user_id2", "device3", "display_name 3"
-        )
+        yield self.store.store_device("user_id", "device1", "display_name 1")
+        yield self.store.store_device("user_id", "device2", "display_name 2")
+        yield self.store.store_device("user_id2", "device3", "display_name 3")
 
         res = yield self.store.get_devices_by_user("user_id")
         self.assertEqual(2, len(res.keys()))
-        self.assertDictContainsSubset({
-            "user_id": "user_id",
-            "device_id": "device1",
-            "display_name": "display_name 1",
-        }, res["device1"])
-        self.assertDictContainsSubset({
-            "user_id": "user_id",
-            "device_id": "device2",
-            "display_name": "display_name 2",
-        }, res["device2"])
+        self.assertDictContainsSubset(
+            {
+                "user_id": "user_id",
+                "device_id": "device1",
+                "display_name": "display_name 1",
+            },
+            res["device1"],
+        )
+        self.assertDictContainsSubset(
+            {
+                "user_id": "user_id",
+                "device_id": "device2",
+                "display_name": "display_name 2",
+            },
+            res["device2"],
+        )
 
     @defer.inlineCallbacks
     def test_update_device(self):
-        yield self.store.store_device(
-            "user_id", "device_id", "display_name 1"
-        )
+        yield self.store.store_device("user_id", "device_id", "display_name 1")
 
         res = yield self.store.get_device("user_id", "device_id")
         self.assertEqual("display_name 1", res["display_name"])
 
         # do a no-op first
-        yield self.store.update_device(
-            "user_id", "device_id",
-        )
+        yield self.store.update_device("user_id", "device_id")
         res = yield self.store.get_device("user_id", "device_id")
         self.assertEqual("display_name 1", res["display_name"])
 
         # do the update
         yield self.store.update_device(
-            "user_id", "device_id",
-            new_display_name="display_name 2",
+            "user_id", "device_id", new_display_name="display_name 2"
         )
 
         # check it worked
@@ -100,7 +96,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
     def test_update_unknown_device(self):
         with self.assertRaises(synapse.api.errors.StoreError) as cm:
             yield self.store.update_device(
-                "user_id", "unknown_device_id",
-                new_display_name="display_name 2",
+                "user_id", "unknown_device_id", new_display_name="display_name 2"
             )
         self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 129ebaf343..b4510c1c8d 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -24,10 +24,9 @@ from tests.utils import setup_test_homeserver
 
 
 class DirectoryStoreTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield setup_test_homeserver()
+        hs = yield setup_test_homeserver(self.addCleanup)
 
         self.store = DirectoryStore(None, hs)
 
@@ -37,38 +36,29 @@ class DirectoryStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_room_to_alias(self):
         yield self.store.create_room_alias_association(
-            room_alias=self.alias,
-            room_id=self.room.to_string(),
-            servers=["test"],
+            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
         )
 
         self.assertEquals(
             ["#my-room:test"],
-            (yield self.store.get_aliases_for_room(self.room.to_string()))
+            (yield self.store.get_aliases_for_room(self.room.to_string())),
         )
 
     @defer.inlineCallbacks
     def test_alias_to_room(self):
         yield self.store.create_room_alias_association(
-            room_alias=self.alias,
-            room_id=self.room.to_string(),
-            servers=["test"],
+            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
         )
 
         self.assertObjectHasAttributes(
-            {
-                "room_id": self.room.to_string(),
-                "servers": ["test"],
-            },
-            (yield self.store.get_association_from_room_alias(self.alias))
+            {"room_id": self.room.to_string(), "servers": ["test"]},
+            (yield self.store.get_association_from_room_alias(self.alias)),
         )
 
     @defer.inlineCallbacks
     def test_delete_alias(self):
         yield self.store.create_room_alias_association(
-            room_alias=self.alias,
-            room_id=self.room.to_string(),
-            servers=["test"],
+            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
         )
 
         room_id = yield self.store.delete_room_alias(self.alias)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 84ce492a2c..8f0aaece40 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -26,8 +26,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver()
-
+        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
@@ -35,70 +34,49 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.store_device(
-            "user", "device", None
-        )
+        yield self.store.store_device("user", "device", None)
 
-        yield self.store.set_e2e_device_keys(
-            "user", "device", now, json)
+        yield self.store.set_e2e_device_keys("user", "device", now, json)
 
         res = yield self.store.get_e2e_device_keys((("user", "device"),))
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
-        self.assertDictContainsSubset({
-            "keys": json,
-            "device_display_name": None,
-        }, dev)
+        self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev)
 
     @defer.inlineCallbacks
     def test_get_key_with_device_name(self):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.set_e2e_device_keys(
-            "user", "device", now, json)
-        yield self.store.store_device(
-            "user", "device", "display_name"
-        )
+        yield self.store.set_e2e_device_keys("user", "device", now, json)
+        yield self.store.store_device("user", "device", "display_name")
 
         res = yield self.store.get_e2e_device_keys((("user", "device"),))
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
-        self.assertDictContainsSubset({
-            "keys": json,
-            "device_display_name": "display_name",
-        }, dev)
+        self.assertDictContainsSubset(
+            {"keys": json, "device_display_name": "display_name"}, dev
+        )
 
     @defer.inlineCallbacks
     def test_multiple_devices(self):
         now = 1470174257070
 
-        yield self.store.store_device(
-            "user1", "device1", None
-        )
-        yield self.store.store_device(
-            "user1", "device2", None
-        )
-        yield self.store.store_device(
-            "user2", "device1", None
-        )
-        yield self.store.store_device(
-            "user2", "device2", None
-        )
+        yield self.store.store_device("user1", "device1", None)
+        yield self.store.store_device("user1", "device2", None)
+        yield self.store.store_device("user2", "device1", None)
+        yield self.store.store_device("user2", "device2", None)
 
-        yield self.store.set_e2e_device_keys(
-            "user1", "device1", now, 'json11')
-        yield self.store.set_e2e_device_keys(
-            "user1", "device2", now, 'json12')
-        yield self.store.set_e2e_device_keys(
-            "user2", "device1", now, 'json21')
-        yield self.store.set_e2e_device_keys(
-            "user2", "device2", now, 'json22')
-
-        res = yield self.store.get_e2e_device_keys((("user1", "device1"),
-                                                    ("user2", "device2")))
+        yield self.store.set_e2e_device_keys("user1", "device1", now, 'json11')
+        yield self.store.set_e2e_device_keys("user1", "device2", now, 'json12')
+        yield self.store.set_e2e_device_keys("user2", "device1", now, 'json21')
+        yield self.store.set_e2e_device_keys("user2", "device2", now, 'json22')
+
+        res = yield self.store.get_e2e_device_keys(
+            (("user1", "device1"), ("user2", "device2"))
+        )
         self.assertIn("user1", res)
         self.assertIn("device1", res["user1"])
         self.assertNotIn("device2", res["user1"])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 30683e7888..2fdf34fdf6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -22,7 +22,7 @@ import tests.utils
 class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver()
+        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
@@ -33,23 +33,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
         def insert_event(txn, i):
             event_id = '$event_%i:local' % i
 
-            txn.execute((
-                "INSERT INTO events ("
-                "   room_id, event_id, type, depth, topological_ordering,"
-                "   content, processed, outlier) "
-                "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
-            ), (room_id, event_id, i, i, True, False))
+            txn.execute(
+                (
+                    "INSERT INTO events ("
+                    "   room_id, event_id, type, depth, topological_ordering,"
+                    "   content, processed, outlier) "
+                    "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
+                ),
+                (room_id, event_id, i, i, True, False),
+            )
 
-            txn.execute((
-                'INSERT INTO event_forward_extremities (room_id, event_id) '
-                'VALUES (?, ?)'
-            ), (room_id, event_id))
+            txn.execute(
+                (
+                    'INSERT INTO event_forward_extremities (room_id, event_id) '
+                    'VALUES (?, ?)'
+                ),
+                (room_id, event_id),
+            )
 
-            txn.execute((
-                'INSERT INTO event_reference_hashes '
-                '(event_id, algorithm, hash) '
-                "VALUES (?, 'sha256', ?)"
-            ), (event_id, 'ffff'))
+            txn.execute(
+                (
+                    'INSERT INTO event_reference_hashes '
+                    '(event_id, algorithm, hash) '
+                    "VALUES (?, 'sha256', ?)"
+                ),
+                (event_id, b'ffff'),
+            )
 
         for i in range(0, 11):
             yield self.store.runInteraction("insert", insert_event, i)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 8430fc7ba6..b114c6fb1d 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -24,15 +24,16 @@ USER_ID = "@user:example.com"
 
 PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}]
 HIGHLIGHT = [
-    "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
+    "notify",
+    {"set_tweak": "sound", "value": "default"},
+    {"set_tweak": "highlight"},
 ]
 
 
 class EventPushActionsStoreTestCase(tests.unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver()
+        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
@@ -55,12 +56,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
             counts = yield self.store.runInteraction(
-                "", self.store._get_unread_counts_by_pos_txn,
-                room_id, user_id, 0
+                "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
             )
             self.assertEquals(
                 counts,
-                {"notify_count": noitf_count, "highlight_count": highlight_count}
+                {"notify_count": noitf_count, "highlight_count": highlight_count},
             )
 
         @defer.inlineCallbacks
@@ -72,11 +72,13 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.depth = stream
 
             yield self.store.add_push_actions_to_staging(
-                event.event_id, {user_id: action},
+                event.event_id, {user_id: action}
             )
             yield self.store.runInteraction(
-                "", self.store._set_push_actions_for_event_and_users_txn,
-                [(event, None)], [(event, None)],
+                "",
+                self.store._set_push_actions_for_event_and_users_txn,
+                [(event, None)],
+                [(event, None)],
             )
 
         def _rotate(stream):
@@ -86,8 +88,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         def _mark_read(stream, depth):
             return self.store.runInteraction(
-                "", self.store._remove_old_push_actions_before_txn,
-                room_id, user_id, stream
+                "",
+                self.store._remove_old_push_actions_before_txn,
+                room_id,
+                user_id,
+                stream,
             )
 
         yield _assert_counts(0, 0)
@@ -112,9 +117,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _rotate(7)
 
         yield self.store._simple_delete(
-            table="event_push_actions",
-            keyvalues={"1": 1},
-            desc="",
+            table="event_push_actions", keyvalues={"1": 1}, desc=""
         )
 
         yield _assert_counts(1, 0)
@@ -132,18 +135,21 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return self.store._simple_insert("events", {
-                "stream_ordering": so,
-                "received_ts": ts,
-                "event_id": "event%i" % so,
-                "type": "",
-                "room_id": "",
-                "content": "",
-                "processed": True,
-                "outlier": False,
-                "topological_ordering": 0,
-                "depth": 0,
-            })
+            return self.store._simple_insert(
+                "events",
+                {
+                    "stream_ordering": so,
+                    "received_ts": ts,
+                    "event_id": "event%i" % so,
+                    "type": "",
+                    "room_id": "",
+                    "content": "",
+                    "processed": True,
+                    "outlier": False,
+                    "topological_ordering": 0,
+                    "depth": 0,
+                },
+            )
 
         # start with the base case where there are no events in the table
         r = yield self.store.find_first_stream_ordering_after_ts(11)
@@ -160,31 +166,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         # add a bunch of dummy events to the events table
         for (stream_ordering, ts) in (
-                (3, 110),
-                (4, 120),
-                (5, 120),
-                (10, 130),
-                (20, 140),
+            (3, 110),
+            (4, 120),
+            (5, 120),
+            (10, 130),
+            (20, 140),
         ):
             yield add_event(stream_ordering, ts)
 
         r = yield self.store.find_first_stream_ordering_after_ts(110)
-        self.assertEqual(r, 3,
-                         "First event after 110ms should be 3, was %i" % r)
+        self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
 
         # 4 and 5 are both after 120: we want 4 rather than 5
         r = yield self.store.find_first_stream_ordering_after_ts(120)
-        self.assertEqual(r, 4,
-                         "First event after 120ms should be 4, was %i" % r)
+        self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
 
         r = yield self.store.find_first_stream_ordering_after_ts(129)
-        self.assertEqual(r, 10,
-                         "First event after 129ms should be 10, was %i" % r)
+        self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
 
         # check we can get the last event
         r = yield self.store.find_first_stream_ordering_after_ts(140)
-        self.assertEqual(r, 20,
-                         "First event after 14ms should be 20, was %i" % r)
+        self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
 
         # off the end
         r = yield self.store.find_first_stream_ordering_after_ts(160)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 3a3d002782..47f4a8ceac 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -28,7 +28,7 @@ class KeyStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver()
+        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
@@ -39,15 +39,12 @@ class KeyStoreTestCase(tests.unittest.TestCase):
         key2 = signedjson.key.decode_verify_key_base64(
             "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
         )
-        yield self.store.store_server_verify_key(
-            "server1", "from_server", 0, key1
-        )
-        yield self.store.store_server_verify_key(
-            "server1", "from_server", 0, key2
-        )
+        yield self.store.store_server_verify_key("server1", "from_server", 0, key1)
+        yield self.store.store_server_verify_key("server1", "from_server", 0, key2)
 
         res = yield self.store.get_server_verify_keys(
-            "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
+            "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]
+        )
 
         self.assertEqual(len(res.keys()), 2)
         self.assertEqual(res["ed25519:key1"].version, "key1")
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
new file mode 100644
index 0000000000..f2ed866ae7
--- /dev/null
+++ b/tests/storage/test_monthly_active_users.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+from tests.utils import setup_test_homeserver
+
+FORTY_DAYS = 40 * 24 * 60 * 60
+
+
+class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.hs = yield setup_test_homeserver(self.addCleanup)
+        self.store = self.hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def test_initialise_reserved_users(self):
+        self.hs.config.max_mau_value = 5
+        user1 = "@user1:server"
+        user1_email = "user1@matrix.org"
+        user2 = "@user2:server"
+        user2_email = "user2@matrix.org"
+        threepids = [
+            {'medium': 'email', 'address': user1_email},
+            {'medium': 'email', 'address': user2_email},
+        ]
+        user_num = len(threepids)
+
+        yield self.store.register(user_id=user1, token="123", password_hash=None)
+
+        yield self.store.register(user_id=user2, token="456", password_hash=None)
+
+        now = int(self.hs.get_clock().time_msec())
+        yield self.store.user_add_threepid(user1, "email", user1_email, now, now)
+        yield self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        yield self.store.initialise_reserved_users(threepids)
+
+        active_count = yield self.store.get_monthly_active_count()
+
+        # Test total counts
+        self.assertEquals(active_count, user_num)
+
+        # Test user is marked as active
+
+        timestamp = yield self.store.user_last_seen_monthly_active(user1)
+        self.assertTrue(timestamp)
+        timestamp = yield self.store.user_last_seen_monthly_active(user2)
+        self.assertTrue(timestamp)
+
+        # Test that users are never removed from the db.
+        self.hs.config.max_mau_value = 0
+
+        self.hs.get_clock().advance_time(FORTY_DAYS)
+
+        yield self.store.reap_monthly_active_users()
+
+        active_count = yield self.store.get_monthly_active_count()
+        self.assertEquals(active_count, user_num)
+
+        # Test that regalar users are removed from the db
+        ru_count = 2
+        yield self.store.upsert_monthly_active_user("@ru1:server")
+        yield self.store.upsert_monthly_active_user("@ru2:server")
+        active_count = yield self.store.get_monthly_active_count()
+
+        self.assertEqual(active_count, user_num + ru_count)
+        self.hs.config.max_mau_value = user_num
+        yield self.store.reap_monthly_active_users()
+
+        active_count = yield self.store.get_monthly_active_count()
+        self.assertEquals(active_count, user_num)
+
+    @defer.inlineCallbacks
+    def test_can_insert_and_count_mau(self):
+        count = yield self.store.get_monthly_active_count()
+        self.assertEqual(0, count)
+
+        yield self.store.upsert_monthly_active_user("@user:server")
+        count = yield self.store.get_monthly_active_count()
+
+        self.assertEqual(1, count)
+
+    @defer.inlineCallbacks
+    def test_user_last_seen_monthly_active(self):
+        user_id1 = "@user1:server"
+        user_id2 = "@user2:server"
+        user_id3 = "@user3:server"
+
+        result = yield self.store.user_last_seen_monthly_active(user_id1)
+        self.assertFalse(result == 0)
+        yield self.store.upsert_monthly_active_user(user_id1)
+        yield self.store.upsert_monthly_active_user(user_id2)
+        result = yield self.store.user_last_seen_monthly_active(user_id1)
+        self.assertTrue(result > 0)
+        result = yield self.store.user_last_seen_monthly_active(user_id3)
+        self.assertFalse(result == 0)
+
+    @defer.inlineCallbacks
+    def test_reap_monthly_active_users(self):
+        self.hs.config.max_mau_value = 5
+        initial_users = 10
+        for i in range(initial_users):
+            yield self.store.upsert_monthly_active_user("@user%d:server" % i)
+        count = yield self.store.get_monthly_active_count()
+        self.assertTrue(count, initial_users)
+        yield self.store.reap_monthly_active_users()
+        count = yield self.store.get_monthly_active_count()
+        self.assertEquals(count, initial_users - self.hs.config.max_mau_value)
+
+        self.hs.get_clock().advance_time(FORTY_DAYS)
+        yield self.store.reap_monthly_active_users()
+        count = yield self.store.get_monthly_active_count()
+        self.assertEquals(count, 0)
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index 3276b39504..b5b58ff660 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -24,10 +24,9 @@ from tests.utils import MockClock, setup_test_homeserver
 
 
 class PresenceStoreTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield setup_test_homeserver(clock=MockClock())
+        hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock())
 
         self.store = PresenceStore(None, hs)
 
@@ -38,16 +37,19 @@ class PresenceStoreTestCase(unittest.TestCase):
     def test_presence_list(self):
         self.assertEquals(
             [],
-            (yield self.store.get_presence_list(
-                observer_localpart=self.u_apple.localpart,
-            ))
+            (
+                yield self.store.get_presence_list(
+                    observer_localpart=self.u_apple.localpart
+                )
+            ),
         )
         self.assertEquals(
             [],
-            (yield self.store.get_presence_list(
-                observer_localpart=self.u_apple.localpart,
-                accepted=True,
-            ))
+            (
+                yield self.store.get_presence_list(
+                    observer_localpart=self.u_apple.localpart, accepted=True
+                )
+            ),
         )
 
         yield self.store.add_presence_list_pending(
@@ -57,16 +59,19 @@ class PresenceStoreTestCase(unittest.TestCase):
 
         self.assertEquals(
             [{"observed_user_id": "@banana:test", "accepted": 0}],
-            (yield self.store.get_presence_list(
-                observer_localpart=self.u_apple.localpart,
-            ))
+            (
+                yield self.store.get_presence_list(
+                    observer_localpart=self.u_apple.localpart
+                )
+            ),
         )
         self.assertEquals(
             [],
-            (yield self.store.get_presence_list(
-                observer_localpart=self.u_apple.localpart,
-                accepted=True,
-            ))
+            (
+                yield self.store.get_presence_list(
+                    observer_localpart=self.u_apple.localpart, accepted=True
+                )
+            ),
         )
 
         yield self.store.set_presence_list_accepted(
@@ -76,16 +81,19 @@ class PresenceStoreTestCase(unittest.TestCase):
 
         self.assertEquals(
             [{"observed_user_id": "@banana:test", "accepted": 1}],
-            (yield self.store.get_presence_list(
-                observer_localpart=self.u_apple.localpart,
-            ))
+            (
+                yield self.store.get_presence_list(
+                    observer_localpart=self.u_apple.localpart
+                )
+            ),
         )
         self.assertEquals(
             [{"observed_user_id": "@banana:test", "accepted": 1}],
-            (yield self.store.get_presence_list(
-                observer_localpart=self.u_apple.localpart,
-                accepted=True,
-            ))
+            (
+                yield self.store.get_presence_list(
+                    observer_localpart=self.u_apple.localpart, accepted=True
+                )
+            ),
         )
 
         yield self.store.del_presence_list(
@@ -95,14 +103,17 @@ class PresenceStoreTestCase(unittest.TestCase):
 
         self.assertEquals(
             [],
-            (yield self.store.get_presence_list(
-                observer_localpart=self.u_apple.localpart,
-            ))
+            (
+                yield self.store.get_presence_list(
+                    observer_localpart=self.u_apple.localpart
+                )
+            ),
         )
         self.assertEquals(
             [],
-            (yield self.store.get_presence_list(
-                observer_localpart=self.u_apple.localpart,
-                accepted=True,
-            ))
+            (
+                yield self.store.get_presence_list(
+                    observer_localpart=self.u_apple.localpart, accepted=True
+                )
+            ),
         )
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 2c95e5e95a..a1f6618bf9 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -24,10 +24,9 @@ from tests.utils import setup_test_homeserver
 
 
 class ProfileStoreTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield setup_test_homeserver()
+        hs = yield setup_test_homeserver(self.addCleanup)
 
         self.store = ProfileStore(None, hs)
 
@@ -35,24 +34,17 @@ class ProfileStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_displayname(self):
-        yield self.store.create_profile(
-            self.u_frank.localpart
-        )
+        yield self.store.create_profile(self.u_frank.localpart)
 
-        yield self.store.set_profile_displayname(
-            self.u_frank.localpart, "Frank"
-        )
+        yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
 
         self.assertEquals(
-            "Frank",
-            (yield self.store.get_profile_displayname(self.u_frank.localpart))
+            "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
         )
 
     @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield self.store.create_profile(
-            self.u_frank.localpart
-        )
+        yield self.store.create_profile(self.u_frank.localpart)
 
         yield self.store.set_profile_avatar_url(
             self.u_frank.localpart, "http://my.site/here"
@@ -60,5 +52,5 @@ class ProfileStoreTestCase(unittest.TestCase):
 
         self.assertEquals(
             "http://my.site/here",
-            (yield self.store.get_profile_avatar_url(self.u_frank.localpart))
+            (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
         )
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 475ec900c4..02bf975fbf 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -22,16 +22,14 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.types import RoomID, UserID
 
 from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
 
 
 class RedactionTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver(
-            resource_for_federation=Mock(),
-            http_client=None,
+            self.addCleanup, resource_for_federation=Mock(), http_client=None
         )
 
         self.store = hs.get_datastore()
@@ -43,20 +41,25 @@ class RedactionTestCase(unittest.TestCase):
 
         self.room1 = RoomID.from_string("!abc123:test")
 
+        yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
+
         self.depth = 1
 
     @defer.inlineCallbacks
-    def inject_room_member(self, room, user, membership, replaces_state=None,
-                           extra_content={}):
+    def inject_room_member(
+        self, room, user, membership, replaces_state=None, extra_content={}
+    ):
         content = {"membership": membership}
         content.update(extra_content)
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Member,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": content,
-        })
+        builder = self.event_builder_factory.new(
+            {
+                "type": EventTypes.Member,
+                "sender": user.to_string(),
+                "state_key": user.to_string(),
+                "room_id": room.to_string(),
+                "content": content,
+            }
+        )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
             builder
@@ -70,13 +73,15 @@ class RedactionTestCase(unittest.TestCase):
     def inject_message(self, room, user, body):
         self.depth += 1
 
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Message,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"body": body, "msgtype": u"message"},
-        })
+        builder = self.event_builder_factory.new(
+            {
+                "type": EventTypes.Message,
+                "sender": user.to_string(),
+                "state_key": user.to_string(),
+                "room_id": room.to_string(),
+                "content": {"body": body, "msgtype": u"message"},
+            }
+        )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
             builder
@@ -88,14 +93,16 @@ class RedactionTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def inject_redaction(self, room, event_id, user, reason):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Redaction,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"reason": reason},
-            "redacts": event_id,
-        })
+        builder = self.event_builder_factory.new(
+            {
+                "type": EventTypes.Redaction,
+                "sender": user.to_string(),
+                "state_key": user.to_string(),
+                "room_id": room.to_string(),
+                "content": {"reason": reason},
+                "redacts": event_id,
+            }
+        )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
             builder
@@ -105,9 +112,7 @@ class RedactionTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_redact(self):
-        yield self.inject_room_member(
-            self.room1, self.u_alice, Membership.JOIN
-        )
+        yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
         msg_event = yield self.inject_message(self.room1, self.u_alice, u"t")
 
@@ -157,13 +162,10 @@ class RedactionTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_redact_join(self):
-        yield self.inject_room_member(
-            self.room1, self.u_alice, Membership.JOIN
-        )
+        yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
 
         msg_event = yield self.inject_room_member(
-            self.room1, self.u_bob, Membership.JOIN,
-            extra_content={"blue": "red"},
+            self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
         )
 
         event = yield self.store.get_event(msg_event.event_id)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 7821ea3fa3..3dfb7b903a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -21,19 +21,15 @@ from tests.utils import setup_test_homeserver
 
 
 class RegistrationStoreTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield setup_test_homeserver()
+        hs = yield setup_test_homeserver(self.addCleanup)
         self.db_pool = hs.get_db_pool()
 
         self.store = hs.get_datastore()
 
         self.user_id = "@my-user:test"
-        self.tokens = [
-            "AbCdEfGhIjKlMnOpQrStUvWxYz",
-            "BcDeFgHiJkLmNoPqRsTuVwXyZa"
-        ]
+        self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"]
         self.pwhash = "{xx1}123456789"
         self.device_id = "akgjhdjklgshg"
 
@@ -50,35 +46,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 "consent_version": None,
                 "consent_server_notice_sent": None,
                 "appservice_id": None,
+                "creation_ts": 1000,
             },
-            (yield self.store.get_user_by_id(self.user_id))
+            (yield self.store.get_user_by_id(self.user_id)),
         )
 
         result = yield self.store.get_user_by_access_token(self.tokens[0])
 
-        self.assertDictContainsSubset(
-            {
-                "name": self.user_id,
-            },
-            result
-        )
+        self.assertDictContainsSubset({"name": self.user_id}, result)
 
         self.assertTrue("token_id" in result)
 
     @defer.inlineCallbacks
     def test_add_tokens(self):
         yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
-        yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
-                                                  self.device_id)
+        yield self.store.add_access_token_to_user(
+            self.user_id, self.tokens[1], self.device_id
+        )
 
         result = yield self.store.get_user_by_access_token(self.tokens[1])
 
         self.assertDictContainsSubset(
-            {
-                "name": self.user_id,
-                "device_id": self.device_id,
-            },
-            result
+            {"name": self.user_id, "device_id": self.device_id}, result
         )
 
         self.assertTrue("token_id" in result)
@@ -87,12 +76,13 @@ class RegistrationStoreTestCase(unittest.TestCase):
     def test_user_delete_access_tokens(self):
         # add some tokens
         yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
-        yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
-                                                  self.device_id)
+        yield self.store.add_access_token_to_user(
+            self.user_id, self.tokens[1], self.device_id
+        )
 
         # now delete some
         yield self.store.user_delete_access_tokens(
-            self.user_id, device_id=self.device_id,
+            self.user_id, device_id=self.device_id
         )
 
         # check they were deleted
@@ -107,8 +97,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
         yield self.store.user_delete_access_tokens(self.user_id)
 
         user = yield self.store.get_user_by_access_token(self.tokens[0])
-        self.assertIsNone(user,
-                          "access token was not deleted without device_id")
+        self.assertIsNone(user, "access token was not deleted without device_id")
 
 
 class TokenGenerator:
@@ -117,4 +106,4 @@ class TokenGenerator:
 
     def generate(self, user_id):
         self._last_issued_token += 1
-        return u"%s-%d" % (user_id, self._last_issued_token,)
+        return u"%s-%d" % (user_id, self._last_issued_token)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index ae8ae94b6d..a1ea23b068 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -24,10 +24,9 @@ from tests.utils import setup_test_homeserver
 
 
 class RoomStoreTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield setup_test_homeserver()
+        hs = yield setup_test_homeserver(self.addCleanup)
 
         # We can't test RoomStore on its own without the DirectoryStore, for
         # management of the 'room_aliases' table
@@ -40,7 +39,7 @@ class RoomStoreTestCase(unittest.TestCase):
         yield self.store.store_room(
             self.room.to_string(),
             room_creator_user_id=self.u_creator.to_string(),
-            is_public=True
+            is_public=True,
         )
 
     @defer.inlineCallbacks
@@ -49,17 +48,16 @@ class RoomStoreTestCase(unittest.TestCase):
             {
                 "room_id": self.room.to_string(),
                 "creator": self.u_creator.to_string(),
-                "is_public": True
+                "is_public": True,
             },
-            (yield self.store.get_room(self.room.to_string()))
+            (yield self.store.get_room(self.room.to_string())),
         )
 
 
 class RoomEventsStoreTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
-        hs = setup_test_homeserver()
+        hs = setup_test_homeserver(self.addCleanup)
 
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
@@ -69,18 +67,13 @@ class RoomEventsStoreTestCase(unittest.TestCase):
         self.room = RoomID.from_string("!abcde:test")
 
         yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id="@creator:text",
-            is_public=True
+            self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
         )
 
     @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
         yield self.store.persist_event(
-            self.event_factory.create_event(
-                room_id=self.room.to_string(),
-                **kwargs
-            )
+            self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
         )
 
     @defer.inlineCallbacks
@@ -88,22 +81,15 @@ class RoomEventsStoreTestCase(unittest.TestCase):
         name = u"A-Room-Name"
 
         yield self.inject_room_event(
-            etype=EventTypes.Name,
-            name=name,
-            content={"name": name},
-            depth=1,
+            etype=EventTypes.Name, name=name, content={"name": name}, depth=1
         )
 
-        state = yield self.store.get_current_state(
-            room_id=self.room.to_string()
-        )
+        state = yield self.store.get_current_state(room_id=self.room.to_string())
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(
-            {"type": "m.room.name",
-             "room_id": self.room.to_string(),
-             "name": name},
-            state[0]
+            {"type": "m.room.name", "room_id": self.room.to_string(), "name": name},
+            state[0],
         )
 
     @defer.inlineCallbacks
@@ -111,22 +97,15 @@ class RoomEventsStoreTestCase(unittest.TestCase):
         topic = u"A place for things"
 
         yield self.inject_room_event(
-            etype=EventTypes.Topic,
-            topic=topic,
-            content={"topic": topic},
-            depth=1,
+            etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
         )
 
-        state = yield self.store.get_current_state(
-            room_id=self.room.to_string()
-        )
+        state = yield self.store.get_current_state(room_id=self.room.to_string())
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(
-            {"type": "m.room.topic",
-             "room_id": self.room.to_string(),
-             "topic": topic},
-            state[0]
+            {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic},
+            state[0],
         )
 
     # Not testing the various 'level' methods for now because there's lots
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index c5fd54f67e..978c66133d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -22,16 +22,14 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.types import RoomID, UserID
 
 from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
 
 
 class RoomMemberStoreTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver(
-            resource_for_federation=Mock(),
-            http_client=None,
+            self.addCleanup, resource_for_federation=Mock(), http_client=None
         )
         # We can't test the RoomMemberStore on its own without the other event
         # storage logic
@@ -47,15 +45,19 @@ class RoomMemberStoreTestCase(unittest.TestCase):
 
         self.room = RoomID.from_string("!abc123:test")
 
+        yield create_room(hs, self.room.to_string(), self.u_alice.to_string())
+
     @defer.inlineCallbacks
     def inject_room_member(self, room, user, membership, replaces_state=None):
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Member,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"membership": membership},
-        })
+        builder = self.event_builder_factory.new(
+            {
+                "type": EventTypes.Member,
+                "sender": user.to_string(),
+                "state_key": user.to_string(),
+                "room_id": room.to_string(),
+                "content": {"membership": membership},
+            }
+        )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
             builder
@@ -71,9 +73,12 @@ class RoomMemberStoreTestCase(unittest.TestCase):
 
         self.assertEquals(
             [self.room.to_string()],
-            [m.room_id for m in (
-                yield self.store.get_rooms_for_user_where_membership_is(
-                    self.u_alice.to_string(), [Membership.JOIN]
+            [
+                m.room_id
+                for m in (
+                    yield self.store.get_rooms_for_user_where_membership_is(
+                        self.u_alice.to_string(), [Membership.JOIN]
+                    )
                 )
-            )]
+            ],
         )
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 7a76d67b8c..d717b9f94e 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -33,7 +33,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver()
+        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
 
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
@@ -45,20 +45,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.room = RoomID.from_string("!abc123:test")
 
         yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id="@creator:text",
-            is_public=True
+            self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
         )
 
     @defer.inlineCallbacks
     def inject_state_event(self, room, sender, typ, state_key, content):
-        builder = self.event_builder_factory.new({
-            "type": typ,
-            "sender": sender.to_string(),
-            "state_key": state_key,
-            "room_id": room.to_string(),
-            "content": content,
-        })
+        builder = self.event_builder_factory.new(
+            {
+                "type": typ,
+                "sender": sender.to_string(),
+                "state_key": state_key,
+                "room_id": room.to_string(),
+                "content": content,
+            }
+        )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
             builder
@@ -80,27 +80,31 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # this defaults to a linear DAG as each new injection defaults to whatever
         # forward extremities are currently in the DB for this room.
         e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, '', {},
+            self.room, self.u_alice, EventTypes.Create, '', {}
         )
         e2 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Name, '', {
-                "name": "test room"
-            },
+            self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
         )
         e3 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), {
-                "membership": Membership.JOIN
-            },
+            self.room,
+            self.u_alice,
+            EventTypes.Member,
+            self.u_alice.to_string(),
+            {"membership": Membership.JOIN},
         )
         e4 = yield self.inject_state_event(
-            self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
-                "membership": Membership.JOIN
-            },
+            self.room,
+            self.u_bob,
+            EventTypes.Member,
+            self.u_bob.to_string(),
+            {"membership": Membership.JOIN},
         )
         e5 = yield self.inject_state_event(
-            self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
-                "membership": Membership.LEAVE
-            },
+            self.room,
+            self.u_bob,
+            EventTypes.Member,
+            self.u_bob.to_string(),
+            {"membership": Membership.LEAVE},
         )
 
         # check we get the full state as of the final event
@@ -110,65 +114,66 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.assertIsNotNone(e4)
 
-        self.assertStateMapEqual({
-            (e1.type, e1.state_key): e1,
-            (e2.type, e2.state_key): e2,
-            (e3.type, e3.state_key): e3,
-            # e4 is overwritten by e5
-            (e5.type, e5.state_key): e5,
-        }, state)
+        self.assertStateMapEqual(
+            {
+                (e1.type, e1.state_key): e1,
+                (e2.type, e2.state_key): e2,
+                (e3.type, e3.state_key): e3,
+                # e4 is overwritten by e5
+                (e5.type, e5.state_key): e5,
+            },
+            state,
+        )
 
         # check we can filter to the m.room.name event (with a '' state key)
         state = yield self.store.get_state_for_event(
             e5.event_id, [(EventTypes.Name, '')], filtered_types=None
         )
 
-        self.assertStateMapEqual({
-            (e2.type, e2.state_key): e2,
-        }, state)
+        self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
         state = yield self.store.get_state_for_event(
             e5.event_id, [(EventTypes.Name, None)], filtered_types=None
         )
 
-        self.assertStateMapEqual({
-            (e2.type, e2.state_key): e2,
-        }, state)
+        self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
         state = yield self.store.get_state_for_event(
             e5.event_id, [(EventTypes.Member, None)], filtered_types=None
         )
 
-        self.assertStateMapEqual({
-            (e3.type, e3.state_key): e3,
-            (e5.type, e5.state_key): e5,
-        }, state)
+        self.assertStateMapEqual(
+            {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
+        )
 
-        # check we can use filter_types to grab a specific room member
+        # check we can use filtered_types to grab a specific room member
         # without filtering out the other event types
         state = yield self.store.get_state_for_event(
-            e5.event_id, [(EventTypes.Member, self.u_alice.to_string())],
+            e5.event_id,
+            [(EventTypes.Member, self.u_alice.to_string())],
             filtered_types=[EventTypes.Member],
         )
 
-        self.assertStateMapEqual({
-            (e1.type, e1.state_key): e1,
-            (e2.type, e2.state_key): e2,
-            (e3.type, e3.state_key): e3,
-        }, state)
+        self.assertStateMapEqual(
+            {
+                (e1.type, e1.state_key): e1,
+                (e2.type, e2.state_key): e2,
+                (e3.type, e3.state_key): e3,
+            },
+            state,
+        )
 
         # check that types=[], filtered_types=[EventTypes.Member]
         # doesn't return all members
         state = yield self.store.get_state_for_event(
-            e5.event_id, [], filtered_types=[EventTypes.Member],
+            e5.event_id, [], filtered_types=[EventTypes.Member]
         )
 
-        self.assertStateMapEqual({
-            (e1.type, e1.state_key): e1,
-            (e2.type, e2.state_key): e2,
-        }, state)
+        self.assertStateMapEqual(
+            {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
+        )
 
         #######################################################
         # _get_some_state_from_cache tests against a full cache
@@ -176,70 +181,122 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         room_id = self.room.to_string()
         group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
-        group = group_ids.keys()[0]
+        group = list(group_ids.keys())[0]
 
         # test _get_some_state_from_cache correctly filters out members with types=[]
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_cache,
             group, [], filtered_types=[EventTypes.Member]
         )
 
         self.assertEqual(is_all, True)
-        self.assertDictEqual({
-            (e1.type, e1.state_key): e1.event_id,
-            (e2.type, e2.state_key): e2.event_id,
-        }, state_dict)
+        self.assertDictEqual(
+            {
+                (e1.type, e1.state_key): e1.event_id,
+                (e2.type, e2.state_key): e2.event_id,
+            },
+            state_dict,
+        )
+
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_members_cache,
+            group, [], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual(
+            {},
+            state_dict,
+        )
 
         # test _get_some_state_from_cache correctly filters in members with wildcard types
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_cache,
+            group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual(
+            {
+                (e1.type, e1.state_key): e1.event_id,
+                (e2.type, e2.state_key): e2.event_id,
+            },
+            state_dict,
+        )
+
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_members_cache,
             group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
         )
 
         self.assertEqual(is_all, True)
-        self.assertDictEqual({
-            (e1.type, e1.state_key): e1.event_id,
-            (e2.type, e2.state_key): e2.event_id,
-            (e3.type, e3.state_key): e3.event_id,
-            # e4 is overwritten by e5
-            (e5.type, e5.state_key): e5.event_id,
-        }, state_dict)
+        self.assertDictEqual(
+            {
+                (e3.type, e3.state_key): e3.event_id,
+                # e4 is overwritten by e5
+                (e5.type, e5.state_key): e5.event_id,
+            },
+            state_dict,
+        )
 
         # test _get_some_state_from_cache correctly filters in members with specific types
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
-            group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
+            self.store._state_group_cache,
+            group,
+            [(EventTypes.Member, e5.state_key)],
+            filtered_types=[EventTypes.Member],
         )
 
         self.assertEqual(is_all, True)
-        self.assertDictEqual({
-            (e1.type, e1.state_key): e1.event_id,
-            (e2.type, e2.state_key): e2.event_id,
-            (e5.type, e5.state_key): e5.event_id,
-        }, state_dict)
+        self.assertDictEqual(
+            {
+                (e1.type, e1.state_key): e1.event_id,
+                (e2.type, e2.state_key): e2.event_id,
+            },
+            state_dict,
+        )
+
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_members_cache,
+            group,
+            [(EventTypes.Member, e5.state_key)],
+            filtered_types=[EventTypes.Member],
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual(
+            {
+                (e5.type, e5.state_key): e5.event_id,
+            },
+            state_dict,
+        )
 
         # test _get_some_state_from_cache correctly filters in members with specific types
         # and no filtered_types
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_members_cache,
             group, [(EventTypes.Member, e5.state_key)], filtered_types=None
         )
 
         self.assertEqual(is_all, True)
-        self.assertDictEqual({
-            (e5.type, e5.state_key): e5.event_id,
-        }, state_dict)
+        self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
 
         #######################################################
         # deliberately remove e2 (room name) from the _state_group_cache
 
-        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
+        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
+            group
+        )
 
         self.assertEqual(is_all, True)
         self.assertEqual(known_absent, set())
-        self.assertDictEqual(state_dict_ids, {
-            (e1.type, e1.state_key): e1.event_id,
-            (e2.type, e2.state_key): e2.event_id,
-            (e3.type, e3.state_key): e3.event_id,
-            # e4 is overwritten by e5
-            (e5.type, e5.state_key): e5.event_id,
-        })
+        self.assertDictEqual(
+            state_dict_ids,
+            {
+                (e1.type, e1.state_key): e1.event_id,
+                (e2.type, e2.state_key): e2.event_id,
+            },
+        )
 
         state_dict_ids.pop((e2.type, e2.state_key))
         self.store._state_group_cache.invalidate(group)
@@ -250,24 +307,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
             # list fetched keys so it knows it's partial
             fetched_keys=(
                 (e1.type, e1.state_key),
-                (e3.type, e3.state_key),
-                (e5.type, e5.state_key),
-            )
+            ),
         )
 
-        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
+        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
+            group
+        )
 
         self.assertEqual(is_all, False)
-        self.assertEqual(known_absent, set([
-            (e1.type, e1.state_key),
-            (e3.type, e3.state_key),
-            (e5.type, e5.state_key),
-        ]))
-        self.assertDictEqual(state_dict_ids, {
-            (e1.type, e1.state_key): e1.event_id,
-            (e3.type, e3.state_key): e3.event_id,
-            (e5.type, e5.state_key): e5.event_id,
-        })
+        self.assertEqual(
+            known_absent,
+            set(
+                [
+                    (e1.type, e1.state_key),
+                ]
+            ),
+        )
+        self.assertDictEqual(
+            state_dict_ids,
+            {
+                (e1.type, e1.state_key): e1.event_id,
+            },
+        )
 
         ############################################
         # test that things work with a partial cache
@@ -275,45 +336,100 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_some_state_from_cache correctly filters out members with types=[]
         room_id = self.room.to_string()
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_cache,
             group, [], filtered_types=[EventTypes.Member]
         )
 
         self.assertEqual(is_all, False)
-        self.assertDictEqual({
-            (e1.type, e1.state_key): e1.event_id,
-        }, state_dict)
+        self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+
+        room_id = self.room.to_string()
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_members_cache,
+            group, [], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual({}, state_dict)
 
         # test _get_some_state_from_cache correctly filters in members wildcard types
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_cache,
             group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
         )
 
         self.assertEqual(is_all, False)
-        self.assertDictEqual({
-            (e1.type, e1.state_key): e1.event_id,
-            (e3.type, e3.state_key): e3.event_id,
-            # e4 is overwritten by e5
-            (e5.type, e5.state_key): e5.event_id,
-        }, state_dict)
+        self.assertDictEqual(
+            {
+                (e1.type, e1.state_key): e1.event_id,
+            },
+            state_dict,
+        )
+
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_members_cache,
+            group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual(
+            {
+                (e3.type, e3.state_key): e3.event_id,
+                (e5.type, e5.state_key): e5.event_id,
+            },
+            state_dict,
+        )
 
         # test _get_some_state_from_cache correctly filters in members with specific types
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
-            group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
+            self.store._state_group_cache,
+            group,
+            [(EventTypes.Member, e5.state_key)],
+            filtered_types=[EventTypes.Member],
         )
 
         self.assertEqual(is_all, False)
-        self.assertDictEqual({
-            (e1.type, e1.state_key): e1.event_id,
-            (e5.type, e5.state_key): e5.event_id,
-        }, state_dict)
+        self.assertDictEqual(
+            {
+                (e1.type, e1.state_key): e1.event_id,
+            },
+            state_dict,
+        )
+
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_members_cache,
+            group,
+            [(EventTypes.Member, e5.state_key)],
+            filtered_types=[EventTypes.Member],
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual(
+            {
+                (e5.type, e5.state_key): e5.event_id,
+            },
+            state_dict,
+        )
 
         # test _get_some_state_from_cache correctly filters in members with specific types
         # and no filtered_types
         (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_cache,
+            group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+        )
+
+        self.assertEqual(is_all, False)
+        self.assertDictEqual({}, state_dict)
+
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            self.store._state_group_members_cache,
             group, [(EventTypes.Member, e5.state_key)], filtered_types=None
         )
 
         self.assertEqual(is_all, True)
-        self.assertDictEqual({
-            (e5.type, e5.state_key): e5.event_id,
-        }, state_dict)
+        self.assertDictEqual(
+            {
+                (e5.type, e5.state_key): e5.event_id,
+            },
+            state_dict,
+        )
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 23fad12bca..b46e0ea7e2 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -29,7 +29,7 @@ BOBBY = "@bobby:a"
 class UserDirectoryStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
-        self.hs = yield setup_test_homeserver()
+        self.hs = yield setup_test_homeserver(self.addCleanup)
         self.store = UserDirectoryStore(None, self.hs)
 
         # alice and bob are both in !room_id. bobby is not but shares
@@ -39,20 +39,12 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
             {
                 ALICE: ProfileInfo(None, "alice"),
                 BOB: ProfileInfo(None, "bob"),
-                BOBBY: ProfileInfo(None, "bobby")
+                BOBBY: ProfileInfo(None, "bobby"),
             },
         )
-        yield self.store.add_users_to_public_room(
-            "!room:id",
-            [ALICE, BOB],
-        )
+        yield self.store.add_users_to_public_room("!room:id", [ALICE, BOB])
         yield self.store.add_users_who_share_room(
-            "!room:id",
-            False,
-            (
-                (ALICE, BOB),
-                (BOB, ALICE),
-            ),
+            "!room:id", False, ((ALICE, BOB), (BOB, ALICE))
         )
 
     @defer.inlineCallbacks
@@ -62,11 +54,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
         r = yield self.store.search_user_dir(ALICE, "bob", 10)
         self.assertFalse(r["limited"])
         self.assertEqual(1, len(r["results"]))
-        self.assertDictEqual(r["results"][0], {
-            "user_id": BOB,
-            "display_name": "bob",
-            "avatar_url": None,
-        })
+        self.assertDictEqual(
+            r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
+        )
 
     @defer.inlineCallbacks
     def test_search_user_dir_all_users(self):
@@ -75,15 +65,13 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
             r = yield self.store.search_user_dir(ALICE, "bob", 10)
             self.assertFalse(r["limited"])
             self.assertEqual(2, len(r["results"]))
-            self.assertDictEqual(r["results"][0], {
-                "user_id": BOB,
-                "display_name": "bob",
-                "avatar_url": None,
-            })
-            self.assertDictEqual(r["results"][1], {
-                "user_id": BOBBY,
-                "display_name": "bobby",
-                "avatar_url": None,
-            })
+            self.assertDictEqual(
+                r["results"][0],
+                {"user_id": BOB, "display_name": "bob", "avatar_url": None},
+            )
+            self.assertDictEqual(
+                r["results"][1],
+                {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
+            )
         finally:
             self.hs.config.user_directory_search_all_users = False