summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py100
-rw-r--r--synapse/config/logger.py14
-rw-r--r--synapse/crypto/keyring.py27
-rw-r--r--synapse/metrics/__init__.py11
-rw-r--r--synapse/rest/client/v2_alpha/account.py1
-rw-r--r--synapse/storage/schema/delta/23/drop_state_index.sql16
-rw-r--r--tests/api/test_auth.py145
-rw-r--r--tests/rest/client/v1/test_presence.py4
-rw-r--r--tests/rest/client/v1/test_rooms.py14
-rw-r--r--tests/rest/client/v1/test_typing.py2
-rw-r--r--tests/rest/client/v1/utils.py3
-rw-r--r--tests/rest/client/v2_alpha/__init__.py2
-rw-r--r--tests/test_state.py37
13 files changed, 335 insertions, 41 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 0c0d678562..df788230fa 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,6 +23,7 @@ from synapse.util.logutils import log_function
 from synapse.types import UserID, EventID
 
 import logging
+import pymacaroons
 
 logger = logging.getLogger(__name__)
 
@@ -40,6 +41,12 @@ class Auth(object):
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
+        self._KNOWN_CAVEAT_PREFIXES = set([
+            "gen = ",
+            "type = ",
+            "time < ",
+            "user_id = ",
+        ])
 
     def check(self, event, auth_events):
         """ Checks if this event is correctly authed.
@@ -65,6 +72,14 @@ class Auth(object):
                 # FIXME
                 return True
 
+            creation_event = auth_events.get((EventTypes.Create, ""), None)
+
+            if not creation_event:
+                raise SynapseError(
+                    403,
+                    "Room %r does not exist" % (event.room_id,)
+                )
+
             # FIXME: Temp hack
             if event.type == EventTypes.Aliases:
                 return True
@@ -359,7 +374,7 @@ class Auth(object):
             except KeyError:
                 pass  # normal users won't have the user_id query parameter set.
 
-            user_info = yield self.get_user_by_access_token(access_token)
+            user_info = yield self._get_user_by_access_token(access_token)
             user = user_info["user"]
             token_id = user_info["token_id"]
 
@@ -386,7 +401,7 @@ class Auth(object):
             )
 
     @defer.inlineCallbacks
-    def get_user_by_access_token(self, token):
+    def _get_user_by_access_token(self, token):
         """ Get a registered user's ID.
 
         Args:
@@ -396,6 +411,86 @@ class Auth(object):
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
+        try:
+            ret = yield self._get_user_from_macaroon(token)
+        except AuthError:
+            # TODO(daniel): Remove this fallback when all existing access tokens
+            # have been re-issued as macaroons.
+            ret = yield self._look_up_user_by_access_token(token)
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def _get_user_from_macaroon(self, macaroon_str):
+        try:
+            macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
+            self._validate_macaroon(macaroon)
+
+            user_prefix = "user_id = "
+            for caveat in macaroon.caveats:
+                if caveat.caveat_id.startswith(user_prefix):
+                    user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
+                    # This codepath exists so that we can actually return a
+                    # token ID, because we use token IDs in place of device
+                    # identifiers throughout the codebase.
+                    # TODO(daniel): Remove this fallback when device IDs are
+                    # properly implemented.
+                    ret = yield self._look_up_user_by_access_token(macaroon_str)
+                    if ret["user"] != user:
+                        logger.error(
+                            "Macaroon user (%s) != DB user (%s)",
+                            user,
+                            ret["user"]
+                        )
+                        raise AuthError(
+                            self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                            "User mismatch in macaroon",
+                            errcode=Codes.UNKNOWN_TOKEN
+                        )
+                    defer.returnValue(ret)
+            raise AuthError(
+                self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+                errcode=Codes.UNKNOWN_TOKEN
+            )
+        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+            raise AuthError(
+                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
+                errcode=Codes.UNKNOWN_TOKEN
+            )
+
+    def _validate_macaroon(self, macaroon):
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = access")
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
+        v.satisfy_general(self._verify_expiry)
+        v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+        v = pymacaroons.Verifier()
+        v.satisfy_general(self._verify_recognizes_caveats)
+        v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+    def _verify_expiry(self, caveat):
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        # TODO(daniel): Enable expiry check when clients actually know how to
+        # refresh tokens. (And remember to enable the tests)
+        return True
+        expiry = int(caveat[len(prefix):])
+        now = self.hs.get_clock().time_msec()
+        return now < expiry
+
+    def _verify_recognizes_caveats(self, caveat):
+        first_space = caveat.find(" ")
+        if first_space < 0:
+            return False
+        second_space = caveat.find(" ", first_space + 1)
+        if second_space < 0:
+            return False
+        return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
+
+    @defer.inlineCallbacks
+    def _look_up_user_by_access_token(self, token):
         ret = yield self.store.get_user_by_access_token(token)
         if not ret:
             raise AuthError(
@@ -406,7 +501,6 @@ class Auth(object):
             "user": UserID.from_string(ret.get("name")),
             "token_id": ret.get("token_id", None),
         }
-
         defer.returnValue(user_info)
 
     @defer.inlineCallbacks
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index fa542623b7..daca698d0c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -21,6 +21,7 @@ import logging.config
 import yaml
 from string import Template
 import os
+import signal
 
 
 DEFAULT_LOG_CONFIG = Template("""
@@ -142,6 +143,19 @@ class LoggingConfig(Config):
                 handler = logging.handlers.RotatingFileHandler(
                     self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
                 )
+
+                def sighup(signum, stack):
+                    logger.info("Closing log file due to SIGHUP")
+                    handler.doRollover()
+                    logger.info("Opened new log file due to SIGHUP")
+
+                # TODO(paul): obviously this is a terrible mechanism for
+                #   stealing SIGHUP, because it means no other part of synapse
+                #   can use it instead. If we want to catch SIGHUP anywhere
+                #   else as well, I'd suggest we find a nicer way to broadcast
+                #   it around.
+                if getattr(signal, "SIGHUP"):
+                    signal.signal(signal.SIGHUP, sighup)
             else:
                 handler = logging.StreamHandler()
             handler.setFormatter(formatter)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index e98a625fea..1b1b31c5c0 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -162,7 +162,9 @@ class Keyring(object):
         def remove_deferreds(res, server_name, group_id):
             server_to_gids[server_name].discard(group_id)
             if not server_to_gids[server_name]:
-                server_to_deferred.pop(server_name).callback(None)
+                d = server_to_deferred.pop(server_name, None)
+                if d:
+                    d.callback(None)
             return res
 
         for g_id, deferred in deferreds.items():
@@ -200,8 +202,15 @@ class Keyring(object):
             else:
                 break
 
-        for server_name, deferred in server_to_deferred:
-            self.key_downloads[server_name] = ObservableDeferred(deferred)
+        for server_name, deferred in server_to_deferred.items():
+            d = ObservableDeferred(deferred)
+            self.key_downloads[server_name] = d
+
+            def rm(r, server_name):
+                self.key_downloads.pop(server_name, None)
+                return r
+
+            d.addBoth(rm, server_name)
 
     def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
         """Takes a dict of KeyGroups and tries to find at least one key for
@@ -220,9 +229,8 @@ class Keyring(object):
             merged_results = {}
 
             missing_keys = {
-                group.server_name: key_id
+                group.server_name: set(group.key_ids)
                 for group in group_id_to_group.values()
-                for key_id in group.key_ids
             }
 
             for fn in key_fetch_fns:
@@ -279,16 +287,15 @@ class Keyring(object):
     def get_keys_from_store(self, server_name_and_key_ids):
         res = yield defer.gatherResults(
             [
-                self.store.get_server_verify_keys(server_name, key_ids)
+                self.store.get_server_verify_keys(
+                    server_name, key_ids
+                ).addCallback(lambda ks, server: (server, ks), server_name)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
         ).addErrback(unwrapFirstError)
 
-        defer.returnValue(dict(zip(
-            [server_name for server_name, _ in server_name_and_key_ids],
-            res
-        )))
+        defer.returnValue(dict(res))
 
     @defer.inlineCallbacks
     def get_keys_from_perspectives(self, server_name_and_key_ids):
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index d7bcad8a8a..943d637459 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -17,7 +17,7 @@
 from __future__ import absolute_import
 
 import logging
-from resource import getrusage, getpagesize, RUSAGE_SELF
+from resource import getrusage, RUSAGE_SELF
 import functools
 import os
 import stat
@@ -100,7 +100,6 @@ def render_all():
 # process resource usage
 
 rusage = None
-PAGE_SIZE = getpagesize()
 
 
 def update_resource_metrics():
@@ -113,8 +112,8 @@ resource_metrics = get_metrics_for("process.resource")
 resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
 resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
 
-# pages
-resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * PAGE_SIZE)
+# kilobytes
+resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
 
 TYPES = {
     stat.S_IFSOCK: "SOCK",
@@ -131,6 +130,10 @@ def _process_fds():
     counts = {(k,): 0 for k in TYPES.values()}
     counts[("other",)] = 0
 
+    # Not every OS will have a /proc/self/fd directory
+    if not os.path.exists("/proc/self/fd"):
+        return counts
+
     for fd in os.listdir("/proc/self/fd"):
         try:
             s = os.stat("/proc/self/fd/%s" % (fd))
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index b5edffdb60..4692ba413c 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -96,6 +96,7 @@ class ThreepidRestServlet(RestServlet):
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
+        self.auth_handler = hs.get_handlers().auth_handler
 
     @defer.inlineCallbacks
     def on_GET(self, request):
diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/schema/delta/23/drop_state_index.sql
new file mode 100644
index 0000000000..07d0ea5cb2
--- /dev/null
+++ b/synapse/storage/schema/delta/23/drop_state_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS state_groups_state_tuple;
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 22fc804331..c96273480d 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,17 +19,21 @@ from mock import Mock
 
 from synapse.api.auth import Auth
 from synapse.api.errors import AuthError
+from synapse.types import UserID
+from tests.utils import setup_test_homeserver
+
+import pymacaroons
 
 
 class AuthTestCase(unittest.TestCase):
 
+    @defer.inlineCallbacks
     def setUp(self):
         self.state_handler = Mock()
         self.store = Mock()
 
-        self.hs = Mock()
+        self.hs = yield setup_test_homeserver(handlers=None)
         self.hs.get_datastore = Mock(return_value=self.store)
-        self.hs.get_state_handler = Mock(return_value=self.state_handler)
         self.auth = Auth(self.hs)
 
         self.test_user = "@foo:bar"
@@ -133,3 +137,140 @@ class AuthTestCase(unittest.TestCase):
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user_id = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+        user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        user = user_info["user"]
+        self.assertEqual(UserID.from_string(user_id), user)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_user_db_mismatch(self):
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@percy:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("User mismatch", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_missing_caveat(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("No user caveat", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_wrong_key(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key + "wrong")
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("Invalid macaroon", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_unknown_caveat(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+        macaroon.add_first_party_caveat("cunning > fox")
+
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("Invalid macaroon", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_expired(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+        macaroon.add_first_party_caveat("time < 1") # ms
+
+        self.hs.clock.now = 5000 # seconds
+
+        yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        # TODO(daniel): Turn on the check that we validate expiration, when we
+        # validate expiration (and remove the above line, which will start
+        # throwing).
+        # with self.assertRaises(AuthError) as cm:
+        #     yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        # self.assertEqual(401, cm.exception.code)
+        # self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 91547bdd06..2ee3da0b34 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -76,7 +76,7 @@ class PresenceStateTestCase(unittest.TestCase):
                 "token_id": 1,
             }
 
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         room_member_handler = hs.handlers.room_member_handler = Mock(
             spec=[
@@ -169,7 +169,7 @@ class PresenceListTestCase(unittest.TestCase):
             ]
         )
 
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         presence.register_servlets(hs, self.mock_resource)
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 34ab47d02e..9fb2bfb315 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -59,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -444,7 +444,7 @@ class RoomsMemberListTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -522,7 +522,7 @@ class RoomsCreateTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
                 "token_id": 1,
             }
 
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -718,7 +718,7 @@ class RoomMemberStateTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -938,7 +938,7 @@ class RoomInitialSyncTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 1c4519406d..6395ce79db 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -67,7 +67,7 @@ class RoomTypingTestCase(RestTestCase):
                 "token_id": 1,
             }
 
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index c472d53043..85096a0326 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase):
         self.mock_resource = None
         self.auth_user_id = None
 
-    def mock_get_user_by_access_token(self, token=None):
-        return self.auth_user_id
-
     @defer.inlineCallbacks
     def create_room_as(self, room_creator, is_public=True, tok=None):
         temp_id = self.auth_user_id
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index ef972a53aa..f45570a1c0 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -48,7 +48,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
                 "user": UserID.from_string(self.USER_ID),
                 "token_id": 1,
             }
-        hs.get_auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
 
         for r in self.TO_REGISTER:
             r.register_servlets(hs, self.mock_resource)
diff --git a/tests/test_state.py b/tests/test_state.py
index 5845358754..55f37c521f 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -204,8 +204,8 @@ class StateTestCase(unittest.TestCase):
             nodes={
                 "START": DictObj(
                     type=EventTypes.Create,
-                    state_key="creator",
-                    content={"membership": "@user_id:example.com"},
+                    state_key="",
+                    content={"creator": "@user_id:example.com"},
                     depth=1,
                 ),
                 "A": DictObj(
@@ -259,8 +259,8 @@ class StateTestCase(unittest.TestCase):
             nodes={
                 "START": DictObj(
                     type=EventTypes.Create,
-                    state_key="creator",
-                    content={"membership": "@user_id:example.com"},
+                    state_key="",
+                    content={"creator": "@user_id:example.com"},
                     depth=1,
                 ),
                 "A": DictObj(
@@ -432,13 +432,19 @@ class StateTestCase(unittest.TestCase):
     def test_resolve_message_conflict(self):
         event = create_event(type="test_message", name="event")
 
+        creation = create_event(
+            type=EventTypes.Create, state_key=""
+        )
+
         old_state_1 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test1", state_key="2"),
             create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test3", state_key="2"),
             create_event(type="test4", state_key=""),
@@ -446,7 +452,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 5)
+        self.assertEqual(len(context.current_state), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -454,13 +460,19 @@ class StateTestCase(unittest.TestCase):
     def test_resolve_state_conflict(self):
         event = create_event(type="test4", state_key="", name="event")
 
+        creation = create_event(
+            type=EventTypes.Create, state_key=""
+        )
+
         old_state_1 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test1", state_key="2"),
             create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test3", state_key="2"),
             create_event(type="test4", state_key=""),
@@ -468,7 +480,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 5)
+        self.assertEqual(len(context.current_state), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -484,36 +496,45 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
+        creation = create_event(
+            type=EventTypes.Create, state_key="",
+            content={"creator": "@foo:bar"}
+        )
+
         old_state_1 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         old_state_2 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
 
         # Reverse the depth to make sure we are actually using the depths
         # during state resolution.
 
         old_state_1 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         old_state_2 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
 
     def _get_context(self, event, old_state_1, old_state_2):
         group_name_1 = "group_name_1"