summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py4
-rw-r--r--tests/config/test_tls.py25
-rw-r--r--tests/handlers/test_e2e_room_keys.py47
-rw-r--r--tests/handlers/test_federation.py81
-rw-r--r--tests/handlers/test_presence.py39
-rw-r--r--tests/handlers/test_roomlist.py39
-rw-r--r--tests/handlers/test_typing.py2
-rw-r--r--tests/patch_inline_callbacks.py94
-rw-r--r--tests/rest/admin/test_admin.py2
-rw-r--r--tests/rest/client/v1/test_rooms.py9
-rw-r--r--tests/rest/client/v2_alpha/test_account.py68
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py2
-rw-r--r--tests/storage/test_end_to_end_keys.py12
-rw-r--r--tests/storage/test_event_federation.py2
-rw-r--r--tests/storage/test_monthly_active_users.py58
15 files changed, 291 insertions, 193 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
index f7fc502f01..ed805db1c2 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -16,9 +16,9 @@
 
 from twisted.trial import util
 
-import tests.patch_inline_callbacks
+from synapse.util.patch_inline_callbacks import do_patch
 
 # attempt to do the patch before we load any synapse code
-tests.patch_inline_callbacks.do_patch()
+do_patch()
 
 util.DEFAULT_TIMEOUT_DURATION = 20
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index b02780772a..1be6ff563b 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -21,17 +21,24 @@ import yaml
 
 from OpenSSL import SSL
 
+from synapse.config._base import Config, RootConfig
 from synapse.config.tls import ConfigError, TlsConfig
 from synapse.crypto.context_factory import ClientTLSOptionsFactory
 
 from tests.unittest import TestCase
 
 
-class TestConfig(TlsConfig):
+class FakeServer(Config):
+    section = "server"
+
     def has_tls_listener(self):
         return False
 
 
+class TestConfig(RootConfig):
+    config_classes = [FakeServer, TlsConfig]
+
+
 class TLSConfigTests(TestCase):
     def test_warn_self_signed(self):
         """
@@ -202,13 +209,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         conf = TestConfig()
         conf.read_config(
             yaml.safe_load(
-                TestConfig().generate_config_section(
+                TestConfig().generate_config(
                     "/config_dir_path",
                     "my_super_secure_server",
                     "/data_dir_path",
-                    "/tls_cert_path",
-                    "tls_private_key",
-                    None,  # This is the acme_domain
+                    tls_certificate_path="/tls_cert_path",
+                    tls_private_key_path="tls_private_key",
+                    acme_domain=None,  # This is the acme_domain
                 )
             ),
             "/config_dir_path",
@@ -223,13 +230,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
         conf = TestConfig()
         conf.read_config(
             yaml.safe_load(
-                TestConfig().generate_config_section(
+                TestConfig().generate_config(
                     "/config_dir_path",
                     "my_super_secure_server",
                     "/data_dir_path",
-                    "/tls_cert_path",
-                    "tls_private_key",
-                    "my_supe_secure_server",  # This is the acme_domain
+                    tls_certificate_path="/tls_cert_path",
+                    tls_private_key_path="tls_private_key",
+                    acme_domain="my_supe_secure_server",  # This is the acme_domain
                 )
             ),
             "/config_dir_path",
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index c4503c1611..0bb96674a2 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -187,9 +187,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, 404)
 
     @defer.inlineCallbacks
-    def test_update_bad_version(self):
-        """Check that we get a 400 if the version in the body is missing or
-        doesn't match
+    def test_update_omitted_version(self):
+        """Check that the update succeeds if the version is missing from the body
         """
         version = yield self.handler.create_version(
             self.local_user,
@@ -197,19 +196,35 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = None
-        try:
-            yield self.handler.update_version(
-                self.local_user,
-                version,
-                {
-                    "algorithm": "m.megolm_backup.v1",
-                    "auth_data": "revised_first_version_auth_data",
-                },
-            )
-        except errors.SynapseError as e:
-            res = e.code
-        self.assertEqual(res, 400)
+        yield self.handler.update_version(
+            self.local_user,
+            version,
+            {
+                "algorithm": "m.megolm_backup.v1",
+                "auth_data": "revised_first_version_auth_data",
+            },
+        )
+
+        # check we can retrieve it as the current version
+        res = yield self.handler.get_version_info(self.local_user)
+        self.assertDictEqual(
+            res,
+            {
+                "algorithm": "m.megolm_backup.v1",
+                "auth_data": "revised_first_version_auth_data",
+                "version": version,
+            },
+        )
+
+    @defer.inlineCallbacks
+    def test_update_bad_version(self):
+        """Check that we get a 400 if the version in the body doesn't match
+        """
+        version = yield self.handler.create_version(
+            self.local_user,
+            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        )
+        self.assertEqual(version, "1")
 
         res = None
         try:
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
new file mode 100644
index 0000000000..d56220f403
--- /dev/null
+++ b/tests/handlers/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError, Codes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class FederationTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver(http_client=None)
+        self.handler = hs.get_handlers().federation_handler
+        self.store = hs.get_datastore()
+        return hs
+
+    def test_exchange_revoked_invite(self):
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+        # Send a 3PID invite event with an empty body so it's considered as a revoked one.
+        invite_token = "sometoken"
+        self.helper.send_state(
+            room_id=room_id,
+            event_type=EventTypes.ThirdPartyInvite,
+            state_key=invite_token,
+            body={},
+            tok=tok,
+        )
+
+        d = self.handler.on_exchange_third_party_invite_request(
+            room_id=room_id,
+            event_dict={
+                "type": EventTypes.Member,
+                "room_id": room_id,
+                "sender": user_id,
+                "state_key": "@someone:example.org",
+                "content": {
+                    "membership": "invite",
+                    "third_party_invite": {
+                        "display_name": "alice",
+                        "signed": {
+                            "mxid": "@alice:localhost",
+                            "token": invite_token,
+                            "signatures": {
+                                "magic.forest": {
+                                    "ed25519:3": "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIsNffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg"
+                                }
+                            },
+                        },
+                    },
+                },
+            },
+        )
+
+        failure = self.get_failure(d, AuthError).value
+
+        self.assertEqual(failure.code, 403, failure)
+        self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
+        self.assertEqual(failure.msg, "You are not invited to this room.")
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index f70c6e7d65..d4293b4312 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, Membership, PresenceState
 from synapse.events import room_version_to_event_format
 from synapse.events.builder import EventBuilder
 from synapse.handlers.presence import (
+    EXTERNAL_PROCESS_EXPIRY,
     FEDERATION_PING_INTERVAL,
     FEDERATION_TIMEOUT,
     IDLE_TIMER,
@@ -413,6 +414,44 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         self.assertEquals(state, new_state)
 
 
+class PresenceHandlerTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.presence_handler = hs.get_presence_handler()
+        self.clock = hs.get_clock()
+
+    def test_external_process_timeout(self):
+        """Test that if an external process doesn't update the records for a while
+        we time out their syncing users presence.
+        """
+        process_id = 1
+        user_id = "@test:server"
+
+        # Notify handler that a user is now syncing.
+        self.get_success(
+            self.presence_handler.update_external_syncs_row(
+                process_id, user_id, True, self.clock.time_msec()
+            )
+        )
+
+        # Check that if we wait a while without telling the handler the user has
+        # stopped syncing that their presence state doesn't get timed out.
+        self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2)
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        self.assertEqual(state.state, PresenceState.ONLINE)
+
+        # Check that if the external process timeout fires, then the syncing
+        # user gets timed out
+        self.reactor.advance(EXTERNAL_PROCESS_EXPIRY)
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        self.assertEqual(state.state, PresenceState.OFFLINE)
+
+
 class PresenceJoinTestCase(unittest.HomeserverTestCase):
     """Tests remote servers get told about presence of users in the room when
     they join and when new local users join.
diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py
deleted file mode 100644
index 61eebb6985..0000000000
--- a/tests/handlers/test_roomlist.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.handlers.room_list import RoomListNextBatch
-
-import tests.unittest
-import tests.utils
-
-
-class RoomListTestCase(tests.unittest.TestCase):
-    """ Tests RoomList's RoomListNextBatch. """
-
-    def setUp(self):
-        pass
-
-    def test_check_read_batch_tokens(self):
-        batch_token = RoomListNextBatch(
-            stream_ordering="abcdef",
-            public_room_stream_id="123",
-            current_limit=20,
-            direction_is_forward=True,
-        ).to_token()
-        next_batch = RoomListNextBatch.from_token(batch_token)
-        self.assertEquals(next_batch.stream_ordering, "abcdef")
-        self.assertEquals(next_batch.public_room_stream_id, "123")
-        self.assertEquals(next_batch.current_limit, 20)
-        self.assertEquals(next_batch.direction_is_forward, True)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1f2ef5d01f..67f1013051 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -139,7 +139,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             defer.succeed(1)
         )
 
-        self.datastore.get_current_state_deltas.return_value = None
+        self.datastore.get_current_state_deltas.return_value = (0, None)
 
         self.datastore.get_to_device_stream_token = lambda: 0
         self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
deleted file mode 100644
index 220884311c..0000000000
--- a/tests/patch_inline_callbacks.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import functools
-import sys
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-
-
-def do_patch():
-    """
-    Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
-    """
-
-    from synapse.logging.context import LoggingContext
-
-    orig_inline_callbacks = defer.inlineCallbacks
-
-    def new_inline_callbacks(f):
-
-        orig = orig_inline_callbacks(f)
-
-        @functools.wraps(f)
-        def wrapped(*args, **kwargs):
-            start_context = LoggingContext.current_context()
-
-            try:
-                res = orig(*args, **kwargs)
-            except Exception:
-                if LoggingContext.current_context() != start_context:
-                    err = "%s changed context from %s to %s on exception" % (
-                        f,
-                        start_context,
-                        LoggingContext.current_context(),
-                    )
-                    print(err, file=sys.stderr)
-                    raise Exception(err)
-                raise
-
-            if not isinstance(res, Deferred) or res.called:
-                if LoggingContext.current_context() != start_context:
-                    err = "%s changed context from %s to %s" % (
-                        f,
-                        start_context,
-                        LoggingContext.current_context(),
-                    )
-                    # print the error to stderr because otherwise all we
-                    # see in travis-ci is the 500 error
-                    print(err, file=sys.stderr)
-                    raise Exception(err)
-                return res
-
-            if LoggingContext.current_context() != LoggingContext.sentinel:
-                err = (
-                    "%s returned incomplete deferred in non-sentinel context "
-                    "%s (start was %s)"
-                ) % (f, LoggingContext.current_context(), start_context)
-                print(err, file=sys.stderr)
-                raise Exception(err)
-
-            def check_ctx(r):
-                if LoggingContext.current_context() != start_context:
-                    err = "%s completion of %s changed context from %s to %s" % (
-                        "Failure" if isinstance(r, Failure) else "Success",
-                        f,
-                        start_context,
-                        LoggingContext.current_context(),
-                    )
-                    print(err, file=sys.stderr)
-                    raise Exception(err)
-                return r
-
-            res.addBoth(check_ctx)
-            return res
-
-        return wrapped
-
-    defer.inlineCallbacks = new_inline_callbacks
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 5877bb2133..d3a4f717f7 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -62,7 +62,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         self.device_handler.check_device_registered = Mock(return_value="FAKE")
 
         self.datastore = Mock(return_value=Mock())
-        self.datastore.get_current_state_deltas = Mock(return_value=[])
+        self.datastore.get_current_state_deltas = Mock(return_value=(0, []))
 
         self.secrets = Mock()
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index fe741637f5..2f2ca74611 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -484,6 +484,15 @@ class RoomsCreateTestCase(RoomBase):
         self.render(request)
         self.assertEquals(400, channel.code)
 
+    def test_post_room_invitees_invalid_mxid(self):
+        # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
+        # Note the trailing space in the MXID here!
+        request, channel = self.make_request(
+            "POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
+        )
+        self.render(request)
+        self.assertEquals(400, channel.code)
+
 
 class RoomTopicTestCase(RoomBase):
     """ Tests /rooms/$room_id/topic REST events. """
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 920de41de4..0f51895b81 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -23,8 +23,8 @@ from email.parser import Parser
 import pkg_resources
 
 import synapse.rest.admin
-from synapse.api.constants import LoginType
-from synapse.rest.client.v1 import login
+from synapse.api.constants import LoginType, Membership
+from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import account, register
 
 from tests import unittest
@@ -244,16 +244,66 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
         account.register_servlets,
+        room.register_servlets,
     ]
 
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver()
-        return hs
+        self.hs = self.setup_test_homeserver()
+        return self.hs
 
     def test_deactivate_account(self):
         user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
 
+        self.deactivate(user_id, tok)
+
+        store = self.hs.get_datastore()
+
+        # Check that the user has been marked as deactivated.
+        self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
+
+        # Check that this access token has been invalidated.
+        request, channel = self.make_request("GET", "account/whoami")
+        self.render(request)
+        self.assertEqual(request.code, 401)
+
+    @unittest.INFO
+    def test_pending_invites(self):
+        """Tests that deactivating a user rejects every pending invite for them."""
+        store = self.hs.get_datastore()
+
+        inviter_id = self.register_user("inviter", "test")
+        inviter_tok = self.login("inviter", "test")
+
+        invitee_id = self.register_user("invitee", "test")
+        invitee_tok = self.login("invitee", "test")
+
+        # Make @inviter:test invite @invitee:test in a new room.
+        room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok)
+        self.helper.invite(
+            room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok
+        )
+
+        # Make sure the invite is here.
+        pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+        self.assertEqual(len(pending_invites), 1, pending_invites)
+        self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
+
+        # Deactivate @invitee:test.
+        self.deactivate(invitee_id, invitee_tok)
+
+        # Check that the invite isn't there anymore.
+        pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+        self.assertEqual(len(pending_invites), 0, pending_invites)
+
+        # Check that the membership of @invitee:test in the room is now "leave".
+        memberships = self.get_success(
+            store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
+        )
+        self.assertEqual(len(memberships), 1, memberships)
+        self.assertEqual(memberships[0].room_id, room_id, memberships)
+
+    def deactivate(self, user_id, tok):
         request_data = json.dumps(
             {
                 "auth": {
@@ -269,13 +319,3 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         )
         self.render(request)
         self.assertEqual(request.code, 200)
-
-        store = self.hs.get_datastore()
-
-        # Check that the user has been marked as deactivated.
-        self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
-
-        # Check that this access token has been invalidated.
-        request, channel = self.make_request("GET", "account/whoami")
-        self.render(request)
-        self.assertEqual(request.code, 401)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index f42a8efbf4..e0e9e94fbf 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -92,7 +92,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         )
         self.render(request)
 
-        self.assertEqual(channel.result["code"], b"400")
+        self.assertEqual(channel.result["code"], b"404")
         self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
 
     # Currently invalid params do not have an appropriate errcode
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index c8ece15284..398d546280 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
-        self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev)
+        self.assertDictContainsSubset(json, dev)
 
     @defer.inlineCallbacks
     def test_reupload_key(self):
@@ -68,7 +68,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
         self.assertDictContainsSubset(
-            {"keys": json, "device_display_name": "display_name"}, dev
+            {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
         )
 
     @defer.inlineCallbacks
@@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         yield self.store.store_device("user2", "device1", None)
         yield self.store.store_device("user2", "device2", None)
 
-        yield self.store.set_e2e_device_keys("user1", "device1", now, "json11")
-        yield self.store.set_e2e_device_keys("user1", "device2", now, "json12")
-        yield self.store.set_e2e_device_keys("user2", "device1", now, "json21")
-        yield self.store.set_e2e_device_keys("user2", "device2", now, "json22")
+        yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
+        yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
+        yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
+        yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
 
         res = yield self.store.get_e2e_device_keys(
             (("user1", "device1"), ("user2", "device2"))
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index b58386994e..2fe50377f8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -57,7 +57,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
                     "(event_id, algorithm, hash) "
                     "VALUES (?, 'sha256', ?)"
                 ),
-                (event_id, b"ffff"),
+                (event_id, bytearray(b"ffff")),
             )
 
         for i in range(0, 11):
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 1494650d10..90a63dc477 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -50,6 +50,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
             {"medium": "email", "address": user2_email},
             {"medium": "email", "address": user3_email},
         ]
+        self.hs.config.mau_limits_reserved_threepids = threepids
         # -1 because user3 is a support user and does not count
         user_num = len(threepids) - 1
 
@@ -84,6 +85,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.hs.config.max_mau_value = 0
 
         self.reactor.advance(FORTY_DAYS)
+        self.hs.config.max_mau_value = 5
 
         self.store.reap_monthly_active_users()
         self.pump()
@@ -147,9 +149,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.reap_monthly_active_users()
         self.pump()
         count = self.store.get_monthly_active_count()
-        self.assertEquals(
-            self.get_success(count), initial_users - self.hs.config.max_mau_value
-        )
+        self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
 
         self.reactor.advance(FORTY_DAYS)
         self.store.reap_monthly_active_users()
@@ -158,6 +158,44 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         count = self.store.get_monthly_active_count()
         self.assertEquals(self.get_success(count), 0)
 
+    def test_reap_monthly_active_users_reserved_users(self):
+        """ Tests that reaping correctly handles reaping where reserved users are
+        present"""
+
+        self.hs.config.max_mau_value = 5
+        initial_users = 5
+        reserved_user_number = initial_users - 1
+        threepids = []
+        for i in range(initial_users):
+            user = "@user%d:server" % i
+            email = "user%d@example.com" % i
+            self.get_success(self.store.upsert_monthly_active_user(user))
+            threepids.append({"medium": "email", "address": email})
+            # Need to ensure that the most recent entries in the
+            # monthly_active_users table are reserved
+            now = int(self.hs.get_clock().time_msec())
+            if i != 0:
+                self.get_success(
+                    self.store.register_user(user_id=user, password_hash=None)
+                )
+                self.get_success(
+                    self.store.user_add_threepid(user, "email", email, now, now)
+                )
+
+        self.hs.config.mau_limits_reserved_threepids = threepids
+        self.store.runInteraction(
+            "initialise", self.store._initialise_reserved_users, threepids
+        )
+        count = self.store.get_monthly_active_count()
+        self.assertTrue(self.get_success(count), initial_users)
+
+        users = self.store.get_registered_reserved_users()
+        self.assertEquals(len(self.get_success(users)), reserved_user_number)
+
+        self.get_success(self.store.reap_monthly_active_users())
+        count = self.store.get_monthly_active_count()
+        self.assertEquals(self.get_success(count), self.hs.config.max_mau_value)
+
     def test_populate_monthly_users_is_guest(self):
         # Test that guest users are not added to mau list
         user_id = "@user_id:host"
@@ -192,12 +230,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     def test_get_reserved_real_user_account(self):
         # Test no reserved users, or reserved threepids
-        count = self.store.get_registered_reserved_users_count()
-        self.assertEquals(self.get_success(count), 0)
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEquals(len(users), 0)
         # Test reserved users but no registered users
 
         user1 = "@user1:example.com"
         user2 = "@user2:example.com"
+
         user1_email = "user1@example.com"
         user2_email = "user2@example.com"
         threepids = [
@@ -210,8 +249,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
 
         self.pump()
-        count = self.store.get_registered_reserved_users_count()
-        self.assertEquals(self.get_success(count), 0)
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEquals(len(users), 0)
 
         # Test reserved registed users
         self.store.register_user(user_id=user1, password_hash=None)
@@ -221,8 +260,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         now = int(self.hs.get_clock().time_msec())
         self.store.user_add_threepid(user1, "email", user1_email, now, now)
         self.store.user_add_threepid(user2, "email", user2_email, now, now)
-        count = self.store.get_registered_reserved_users_count()
-        self.assertEquals(self.get_success(count), len(threepids))
+
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEquals(len(users), len(threepids))
 
     def test_support_user_not_add_to_mau_limits(self):
         support_user_id = "@support:test"