summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2018-08-06 13:33:54 +0100
committerErik Johnston <erik@matrix.org>2018-08-06 13:33:54 +0100
commit49a316395831a0676160833c39daa0bb63ceea09 (patch)
treef1c4db4121d7890a7ba390c01b0e81e329f09b22 /tests
parentMerge branch 'release-v0.33.1' of github.com:matrix-org/synapse into matrix-o... (diff)
parentReturn M_NOT_FOUND when a profile could not be found. (#3596) (diff)
downloadsynapse-49a316395831a0676160833c39daa0bb63ceea09.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py35
-rw-r--r--tests/handlers/test_auth.py80
-rw-r--r--tests/handlers/test_register.py51
-rw-r--r--tests/handlers/test_typing.py1
-rw-r--r--tests/replication/slave/storage/_base.py37
-rw-r--r--tests/storage/test__init__.py65
-rw-r--r--tests/storage/test_state.py319
-rw-r--r--tests/util/caches/test_descriptors.py101
-rw-r--r--tests/utils.py9
9 files changed, 659 insertions, 39 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 5f158ec4b9..a82d737e71 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase):
         self.auth = Auth(self.hs)
 
         self.test_user = "@foo:bar"
-        self.test_token = "_test_token_"
+        self.test_token = b"_test_token_"
 
         # this is overridden for the appservice tests
         self.store.get_app_service_by_token = Mock(return_value=None)
@@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
@@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
@@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "192.168.10.10"
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(requester.user.to_string(), self.test_user)
@@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "131.111.8.42"
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
-        request.args["access_token"] = [self.test_token]
+        request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
-        masquerading_user_id = "@doppelganger:matrix.org"
+        masquerading_user_id = b"@doppelganger:matrix.org"
         app_service = Mock(
             token="foobar", url="a_url", sender=self.test_user,
             ip_range_whitelist=None,
@@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
-        request.args["access_token"] = [self.test_token]
-        request.args["user_id"] = [masquerading_user_id]
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
-        self.assertEquals(requester.user.to_string(), masquerading_user_id)
+        self.assertEquals(
+            requester.user.to_string(),
+            masquerading_user_id.decode('utf8')
+        )
 
     def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
-        masquerading_user_id = "@doppelganger:matrix.org"
+        masquerading_user_id = b"@doppelganger:matrix.org"
         app_service = Mock(
             token="foobar", url="a_url", sender=self.test_user,
             ip_range_whitelist=None,
@@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase):
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
-        request.args["access_token"] = [self.test_token]
-        request.args["user_id"] = [masquerading_user_id]
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
@@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase):
 
         # check the token works
         request = Mock(args={})
-        request.args["access_token"] = [token]
+        request.args[b"access_token"] = [token.encode('ascii')]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         self.assertEqual(UserID.from_string(USER_ID), requester.user)
@@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase):
 
         # the token should *not* work now
         request = Mock(args={})
-        request.args["access_token"] = [guest_tok]
+        request.args[b"access_token"] = [guest_tok.encode('ascii')]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         with self.assertRaises(AuthError) as cm:
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 2e5e8e4dec..55eab9e9cf 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from mock import Mock
 
 import pymacaroons
 
@@ -19,6 +20,7 @@ from twisted.internet import defer
 
 import synapse
 import synapse.api.errors
+from synapse.api.errors import AuthError
 from synapse.handlers.auth import AuthHandler
 
 from tests import unittest
@@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase):
         self.hs.handlers = AuthHandlers(self.hs)
         self.auth_handler = self.hs.handlers.auth_handler
         self.macaroon_generator = self.hs.get_macaroon_generator()
+        # MAU tests
+        self.hs.config.max_mau_value = 50
+        self.small_number_of_users = 1
+        self.large_number_of_users = 100
 
     def test_token_is_a_macaroon(self):
         token = self.macaroon_generator.generate_access_token("some_user")
@@ -71,38 +77,37 @@ class AuthTestCase(unittest.TestCase):
         v.satisfy_general(verify_nonce)
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
+    @defer.inlineCallbacks
     def test_short_term_login_token_gives_user_id(self):
         self.hs.clock.now = 1000
 
         token = self.macaroon_generator.generate_short_term_login_token(
             "a_user", 5000
         )
-
-        self.assertEqual(
-            "a_user",
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                token
-            )
+        user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            token
         )
+        self.assertEqual("a_user", user_id)
 
         # when we advance the clock, the token should be rejected
         self.hs.clock.now = 6000
         with self.assertRaises(synapse.api.errors.AuthError):
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 token
             )
 
+    @defer.inlineCallbacks
     def test_short_term_login_token_cannot_replace_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token(
             "a_user", 5000
         )
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
+        user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            macaroon.serialize()
+        )
         self.assertEqual(
-            "a_user",
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            )
+            "a_user", user_id
         )
 
         # add another "user_id" caveat, which might allow us to override the
@@ -110,6 +115,57 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("user_id = b_user")
 
         with self.assertRaises(synapse.api.errors.AuthError):
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 macaroon.serialize()
             )
+
+    @defer.inlineCallbacks
+    def test_mau_limits_disabled(self):
+        self.hs.config.limit_usage_by_mau = False
+        # Ensure does not throw exception
+        yield self.auth_handler.get_access_token_for_user_id('user_a')
+
+        yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self._get_macaroon().serialize()
+        )
+
+    @defer.inlineCallbacks
+    def test_mau_limits_exceeded(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.hs.get_datastore().count_monthly_users = Mock(
+            return_value=defer.succeed(self.large_number_of_users)
+        )
+
+        with self.assertRaises(AuthError):
+            yield self.auth_handler.get_access_token_for_user_id('user_a')
+
+        self.hs.get_datastore().count_monthly_users = Mock(
+            return_value=defer.succeed(self.large_number_of_users)
+        )
+        with self.assertRaises(AuthError):
+            yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            )
+
+    @defer.inlineCallbacks
+    def test_mau_limits_not_exceeded(self):
+        self.hs.config.limit_usage_by_mau = True
+
+        self.hs.get_datastore().count_monthly_users = Mock(
+            return_value=defer.succeed(self.small_number_of_users)
+        )
+        # Ensure does not raise exception
+        yield self.auth_handler.get_access_token_for_user_id('user_a')
+
+        self.hs.get_datastore().count_monthly_users = Mock(
+            return_value=defer.succeed(self.small_number_of_users)
+        )
+        yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self._get_macaroon().serialize()
+        )
+
+    def _get_macaroon(self):
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "user_a", 5000
+        )
+        return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 025fa1be81..0937d71cf6 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,6 +17,7 @@ from mock import Mock
 
 from twisted.internet import defer
 
+from synapse.api.errors import RegistrationError
 from synapse.handlers.register import RegistrationHandler
 from synapse.types import UserID, create_requester
 
@@ -77,3 +78,53 @@ class RegistrationTestCase(unittest.TestCase):
             requester, local_part, display_name)
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
+
+    @defer.inlineCallbacks
+    def test_cannot_register_when_mau_limits_exceeded(self):
+        local_part = "someone"
+        display_name = "someone"
+        requester = create_requester("@as:test")
+        store = self.hs.get_datastore()
+        self.hs.config.limit_usage_by_mau = False
+        self.hs.config.max_mau_value = 50
+        lots_of_users = 100
+        small_number_users = 1
+
+        store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+
+        # Ensure does not throw exception
+        yield self.handler.get_or_create_user(requester, 'a', display_name)
+
+        self.hs.config.limit_usage_by_mau = True
+
+        with self.assertRaises(RegistrationError):
+            yield self.handler.get_or_create_user(requester, 'b', display_name)
+
+        store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
+
+        self._macaroon_mock_generator("another_secret")
+
+        # Ensure does not throw exception
+        yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
+
+        self._macaroon_mock_generator("another another secret")
+        store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+
+        with self.assertRaises(RegistrationError):
+            yield self.handler.register(localpart=local_part)
+
+        self._macaroon_mock_generator("another another secret")
+        store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+
+        with self.assertRaises(RegistrationError):
+            yield self.handler.register_saml2(local_part)
+
+    def _macaroon_mock_generator(self, secret):
+        """
+        Reset macaroon generator in the case where the test creates multiple users
+        """
+        macaroon_generator = Mock(
+            generate_access_token=Mock(return_value=secret))
+        self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
+        self.hs.handlers = RegistrationHandlers(self.hs)
+        self.handler = self.hs.get_handlers().registration_handler
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index b08856f763..2c263af1a3 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"):
                 "content": content,
             }
         ],
-        "pdu_failures": [],
     }
 
 
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 8708c8a196..a103e7be80 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -11,23 +11,44 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import tempfile
 
 from mock import Mock, NonCallableMock
 
 from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
 
 from synapse.replication.tcp.client import (
     ReplicationClientFactory,
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
 
 
+class TestReplicationClientHandler(ReplicationClientHandler):
+    """Overrides on_rdata so that we can wait for it to happen"""
+    def __init__(self, store):
+        super(TestReplicationClientHandler, self).__init__(store)
+        self._rdata_awaiters = []
+
+    def await_replication(self):
+        d = Deferred()
+        self._rdata_awaiters.append(d)
+        return make_deferred_yieldable(d)
+
+    def on_rdata(self, stream_name, token, rows):
+        awaiters = self._rdata_awaiters
+        self._rdata_awaiters = []
+        super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
+        with PreserveLoggingContext():
+            for a in awaiters:
+                a.callback(None)
+
+
 class BaseSlavedStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
@@ -52,7 +73,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.addCleanup(listener.stopListening)
         self.streamer = server_factory.streamer
 
-        self.replication_handler = ReplicationClientHandler(self.slaved_store)
+        self.replication_handler = TestReplicationClientHandler(self.slaved_store)
         client_factory = ReplicationClientFactory(
             self.hs, "client_name", self.replication_handler
         )
@@ -60,12 +81,14 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.addCleanup(client_factory.stopTrying)
         self.addCleanup(client_connector.disconnect)
 
-    @defer.inlineCallbacks
     def replicate(self):
-        yield self.streamer.on_notifier_poke()
-        d = self.replication_handler.await_sync("replication_test")
-        self.streamer.send_sync_to_all_connections("replication_test")
-        yield d
+        """Tell the master side of replication that something has happened, and then
+        wait for the replication to occur.
+        """
+        # xxx: should we be more specific in what we wait for?
+        d = self.replication_handler.await_replication()
+        self.streamer.on_notifier_poke()
+        return d
 
     @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):
diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py
new file mode 100644
index 0000000000..f19cb1265c
--- /dev/null
+++ b/tests/storage/test__init__.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.utils
+
+
+class InitTestCase(tests.unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(InitTestCase, self).__init__(*args, **kwargs)
+        self.store = None  # type: synapse.storage.DataStore
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        hs = yield tests.utils.setup_test_homeserver()
+
+        hs.config.max_mau_value = 50
+        hs.config.limit_usage_by_mau = True
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def test_count_monthly_users(self):
+        count = yield self.store.count_monthly_users()
+        self.assertEqual(0, count)
+
+        yield self._insert_user_ips("@user:server1")
+        yield self._insert_user_ips("@user:server2")
+
+        count = yield self.store.count_monthly_users()
+        self.assertEqual(2, count)
+
+    @defer.inlineCallbacks
+    def _insert_user_ips(self, user):
+        """
+        Helper function to populate user_ips without using batch insertion infra
+        args:
+            user (str):  specify username i.e. @user:server.com
+        """
+        yield self.store._simple_upsert(
+            table="user_ips",
+            keyvalues={
+                "user_id": user,
+                "access_token": "access_token",
+                "ip": "ip",
+                "user_agent": "user_agent",
+                "device_id": "device_id",
+            },
+            values={
+                "last_seen": self.clock.time_msec(),
+            }
+        )
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
new file mode 100644
index 0000000000..7a76d67b8c
--- /dev/null
+++ b/tests/storage/test_state.py
@@ -0,0 +1,319 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.types import RoomID, UserID
+
+import tests.unittest
+import tests.utils
+
+logger = logging.getLogger(__name__)
+
+
+class StateStoreTestCase(tests.unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super(StateStoreTestCase, self).__init__(*args, **kwargs)
+        self.store = None  # type: synapse.storage.DataStore
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        hs = yield tests.utils.setup_test_homeserver()
+
+        self.store = hs.get_datastore()
+        self.event_builder_factory = hs.get_event_builder_factory()
+        self.event_creation_handler = hs.get_event_creation_handler()
+
+        self.u_alice = UserID.from_string("@alice:test")
+        self.u_bob = UserID.from_string("@bob:test")
+
+        self.room = RoomID.from_string("!abc123:test")
+
+        yield self.store.store_room(
+            self.room.to_string(),
+            room_creator_user_id="@creator:text",
+            is_public=True
+        )
+
+    @defer.inlineCallbacks
+    def inject_state_event(self, room, sender, typ, state_key, content):
+        builder = self.event_builder_factory.new({
+            "type": typ,
+            "sender": sender.to_string(),
+            "state_key": state_key,
+            "room_id": room.to_string(),
+            "content": content,
+        })
+
+        event, context = yield self.event_creation_handler.create_new_client_event(
+            builder
+        )
+
+        yield self.store.persist_event(event, context)
+
+        defer.returnValue(event)
+
+    def assertStateMapEqual(self, s1, s2):
+        for t in s1:
+            # just compare event IDs for simplicity
+            self.assertEqual(s1[t].event_id, s2[t].event_id)
+        self.assertEqual(len(s1), len(s2))
+
+    @defer.inlineCallbacks
+    def test_get_state_for_event(self):
+
+        # this defaults to a linear DAG as each new injection defaults to whatever
+        # forward extremities are currently in the DB for this room.
+        e1 = yield self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Create, '', {},
+        )
+        e2 = yield self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Name, '', {
+                "name": "test room"
+            },
+        )
+        e3 = yield self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), {
+                "membership": Membership.JOIN
+            },
+        )
+        e4 = yield self.inject_state_event(
+            self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
+                "membership": Membership.JOIN
+            },
+        )
+        e5 = yield self.inject_state_event(
+            self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {
+                "membership": Membership.LEAVE
+            },
+        )
+
+        # check we get the full state as of the final event
+        state = yield self.store.get_state_for_event(
+            e5.event_id, None, filtered_types=None
+        )
+
+        self.assertIsNotNone(e4)
+
+        self.assertStateMapEqual({
+            (e1.type, e1.state_key): e1,
+            (e2.type, e2.state_key): e2,
+            (e3.type, e3.state_key): e3,
+            # e4 is overwritten by e5
+            (e5.type, e5.state_key): e5,
+        }, state)
+
+        # check we can filter to the m.room.name event (with a '' state key)
+        state = yield self.store.get_state_for_event(
+            e5.event_id, [(EventTypes.Name, '')], filtered_types=None
+        )
+
+        self.assertStateMapEqual({
+            (e2.type, e2.state_key): e2,
+        }, state)
+
+        # check we can filter to the m.room.name event (with a wildcard None state key)
+        state = yield self.store.get_state_for_event(
+            e5.event_id, [(EventTypes.Name, None)], filtered_types=None
+        )
+
+        self.assertStateMapEqual({
+            (e2.type, e2.state_key): e2,
+        }, state)
+
+        # check we can grab the m.room.member events (with a wildcard None state key)
+        state = yield self.store.get_state_for_event(
+            e5.event_id, [(EventTypes.Member, None)], filtered_types=None
+        )
+
+        self.assertStateMapEqual({
+            (e3.type, e3.state_key): e3,
+            (e5.type, e5.state_key): e5,
+        }, state)
+
+        # check we can use filter_types to grab a specific room member
+        # without filtering out the other event types
+        state = yield self.store.get_state_for_event(
+            e5.event_id, [(EventTypes.Member, self.u_alice.to_string())],
+            filtered_types=[EventTypes.Member],
+        )
+
+        self.assertStateMapEqual({
+            (e1.type, e1.state_key): e1,
+            (e2.type, e2.state_key): e2,
+            (e3.type, e3.state_key): e3,
+        }, state)
+
+        # check that types=[], filtered_types=[EventTypes.Member]
+        # doesn't return all members
+        state = yield self.store.get_state_for_event(
+            e5.event_id, [], filtered_types=[EventTypes.Member],
+        )
+
+        self.assertStateMapEqual({
+            (e1.type, e1.state_key): e1,
+            (e2.type, e2.state_key): e2,
+        }, state)
+
+        #######################################################
+        # _get_some_state_from_cache tests against a full cache
+        #######################################################
+
+        room_id = self.room.to_string()
+        group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+        group = group_ids.keys()[0]
+
+        # test _get_some_state_from_cache correctly filters out members with types=[]
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            group, [], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual({
+            (e1.type, e1.state_key): e1.event_id,
+            (e2.type, e2.state_key): e2.event_id,
+        }, state_dict)
+
+        # test _get_some_state_from_cache correctly filters in members with wildcard types
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual({
+            (e1.type, e1.state_key): e1.event_id,
+            (e2.type, e2.state_key): e2.event_id,
+            (e3.type, e3.state_key): e3.event_id,
+            # e4 is overwritten by e5
+            (e5.type, e5.state_key): e5.event_id,
+        }, state_dict)
+
+        # test _get_some_state_from_cache correctly filters in members with specific types
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual({
+            (e1.type, e1.state_key): e1.event_id,
+            (e2.type, e2.state_key): e2.event_id,
+            (e5.type, e5.state_key): e5.event_id,
+        }, state_dict)
+
+        # test _get_some_state_from_cache correctly filters in members with specific types
+        # and no filtered_types
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual({
+            (e5.type, e5.state_key): e5.event_id,
+        }, state_dict)
+
+        #######################################################
+        # deliberately remove e2 (room name) from the _state_group_cache
+
+        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
+
+        self.assertEqual(is_all, True)
+        self.assertEqual(known_absent, set())
+        self.assertDictEqual(state_dict_ids, {
+            (e1.type, e1.state_key): e1.event_id,
+            (e2.type, e2.state_key): e2.event_id,
+            (e3.type, e3.state_key): e3.event_id,
+            # e4 is overwritten by e5
+            (e5.type, e5.state_key): e5.event_id,
+        })
+
+        state_dict_ids.pop((e2.type, e2.state_key))
+        self.store._state_group_cache.invalidate(group)
+        self.store._state_group_cache.update(
+            sequence=self.store._state_group_cache.sequence,
+            key=group,
+            value=state_dict_ids,
+            # list fetched keys so it knows it's partial
+            fetched_keys=(
+                (e1.type, e1.state_key),
+                (e3.type, e3.state_key),
+                (e5.type, e5.state_key),
+            )
+        )
+
+        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
+
+        self.assertEqual(is_all, False)
+        self.assertEqual(known_absent, set([
+            (e1.type, e1.state_key),
+            (e3.type, e3.state_key),
+            (e5.type, e5.state_key),
+        ]))
+        self.assertDictEqual(state_dict_ids, {
+            (e1.type, e1.state_key): e1.event_id,
+            (e3.type, e3.state_key): e3.event_id,
+            (e5.type, e5.state_key): e5.event_id,
+        })
+
+        ############################################
+        # test that things work with a partial cache
+
+        # test _get_some_state_from_cache correctly filters out members with types=[]
+        room_id = self.room.to_string()
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            group, [], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, False)
+        self.assertDictEqual({
+            (e1.type, e1.state_key): e1.event_id,
+        }, state_dict)
+
+        # test _get_some_state_from_cache correctly filters in members wildcard types
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, False)
+        self.assertDictEqual({
+            (e1.type, e1.state_key): e1.event_id,
+            (e3.type, e3.state_key): e3.event_id,
+            # e4 is overwritten by e5
+            (e5.type, e5.state_key): e5.event_id,
+        }, state_dict)
+
+        # test _get_some_state_from_cache correctly filters in members with specific types
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
+        )
+
+        self.assertEqual(is_all, False)
+        self.assertDictEqual({
+            (e1.type, e1.state_key): e1.event_id,
+            (e5.type, e5.state_key): e5.event_id,
+        }, state_dict)
+
+        # test _get_some_state_from_cache correctly filters in members with specific types
+        # and no filtered_types
+        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+            group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+        )
+
+        self.assertEqual(is_all, True)
+        self.assertDictEqual({
+            (e5.type, e5.state_key): e5.event_id,
+        }, state_dict)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 8176a7dabd..ca8a7c907f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
         r = yield obj.fn(2, 3)
         self.assertEqual(r, 'chips')
         obj.mock.assert_not_called()
+
+
+class CachedListDescriptorTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def test_cache(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                defer.returnValue(self.mock(args1, arg2))
+
+        with logcontext.LoggingContext() as c1:
+            c1.request = "c1"
+            obj = Cls()
+            obj.mock.return_value = {10: 'fish', 20: 'chips'}
+            d1 = obj.list_fn([10, 20], 2)
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                logcontext.LoggingContext.sentinel,
+            )
+            r = yield d1
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                c1
+            )
+            obj.mock.assert_called_once_with([10, 20], 2)
+            self.assertEqual(r, {10: 'fish', 20: 'chips'})
+            obj.mock.reset_mock()
+
+            # a call with different params should call the mock again
+            obj.mock.return_value = {30: 'peas'}
+            r = yield obj.list_fn([20, 30], 2)
+            obj.mock.assert_called_once_with([30], 2)
+            self.assertEqual(r, {20: 'chips', 30: 'peas'})
+            obj.mock.reset_mock()
+
+            # all the values should now be cached
+            r = yield obj.fn(10, 2)
+            self.assertEqual(r, 'fish')
+            r = yield obj.fn(20, 2)
+            self.assertEqual(r, 'chips')
+            r = yield obj.fn(30, 2)
+            self.assertEqual(r, 'peas')
+            r = yield obj.list_fn([10, 20, 30], 2)
+            obj.mock.assert_not_called()
+            self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        """Make sure that invalidation callbacks are called."""
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                defer.returnValue(self.mock(args1, arg2))
+
+        obj = Cls()
+        invalidate0 = mock.Mock()
+        invalidate1 = mock.Mock()
+
+        # cache miss
+        obj.mock.return_value = {10: 'fish', 20: 'chips'}
+        r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+        obj.mock.assert_called_once_with([10, 20], 2)
+        self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+        obj.mock.reset_mock()
+
+        # cache hit
+        r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+        obj.mock.assert_not_called()
+        self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+        invalidate0.assert_not_called()
+        invalidate1.assert_not_called()
+
+        # now if we invalidate the keys, both invalidations should get called
+        obj.fn.invalidate((10, 2))
+        invalidate0.assert_called_once()
+        invalidate1.assert_called_once()
diff --git a/tests/utils.py b/tests/utils.py
index c3dbff8507..9bff3ff3b9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -193,7 +193,7 @@ class MockHttpResource(HttpServer):
         self.prefix = prefix
 
     def trigger_get(self, path):
-        return self.trigger("GET", path, None)
+        return self.trigger(b"GET", path, None)
 
     @patch('twisted.web.http.Request')
     @defer.inlineCallbacks
@@ -227,7 +227,7 @@ class MockHttpResource(HttpServer):
 
         headers = {}
         if federation_auth:
-            headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="]
+            headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
         mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
 
         # return the right path if the event requires it
@@ -241,6 +241,9 @@ class MockHttpResource(HttpServer):
         except Exception:
             pass
 
+        if isinstance(path, bytes):
+            path = path.decode('utf8')
+
         for (method, pattern, func) in self.callbacks:
             if http_method != method:
                 continue
@@ -249,7 +252,7 @@ class MockHttpResource(HttpServer):
             if matcher:
                 try:
                     args = [
-                        urlparse.unquote(u).decode("UTF-8")
+                        urlparse.unquote(u)
                         for u in matcher.groups()
                     ]