summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-09-25 11:38:28 +0100
committerErik Johnston <erik@matrix.org>2015-09-25 11:38:28 +0100
commita14665bde7730872627e956860be0d727a53e26d (patch)
treeeeb44967874a737bd2492b6591966a389de1aeda /tests
parentBundle in some room state in the unsigned bit of the invite when sending to i... (diff)
parentMerge pull request #289 from matrix-org/markjh/fix_sql (diff)
downloadsynapse-a14665bde7730872627e956860be0d727a53e26d.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/invite_state
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py145
-rw-r--r--tests/rest/client/v1/test_presence.py4
-rw-r--r--tests/rest/client/v1/test_rooms.py28
-rw-r--r--tests/rest/client/v1/test_typing.py2
-rw-r--r--tests/rest/client/v1/utils.py3
-rw-r--r--tests/rest/client/v2_alpha/__init__.py2
-rw-r--r--tests/storage/event_injector.py81
-rw-r--r--tests/storage/test_events.py116
-rw-r--r--tests/storage/test_room.py2
-rw-r--r--tests/storage/test_stream.py68
-rw-r--r--tests/test_state.py37
11 files changed, 403 insertions, 85 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 22fc804331..c96273480d 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,17 +19,21 @@ from mock import Mock
 
 from synapse.api.auth import Auth
 from synapse.api.errors import AuthError
+from synapse.types import UserID
+from tests.utils import setup_test_homeserver
+
+import pymacaroons
 
 
 class AuthTestCase(unittest.TestCase):
 
+    @defer.inlineCallbacks
     def setUp(self):
         self.state_handler = Mock()
         self.store = Mock()
 
-        self.hs = Mock()
+        self.hs = yield setup_test_homeserver(handlers=None)
         self.hs.get_datastore = Mock(return_value=self.store)
-        self.hs.get_state_handler = Mock(return_value=self.state_handler)
         self.auth = Auth(self.hs)
 
         self.test_user = "@foo:bar"
@@ -133,3 +137,140 @@ class AuthTestCase(unittest.TestCase):
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
         d = self.auth.get_user_by_req(request)
         self.failureResultOf(d, AuthError)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user_id = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+        user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        user = user_info["user"]
+        self.assertEqual(UserID.from_string(user_id), user)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_user_db_mismatch(self):
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@percy:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("User mismatch", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_missing_caveat(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("No user caveat", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_wrong_key(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key + "wrong")
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("Invalid macaroon", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_unknown_caveat(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+        macaroon.add_first_party_caveat("cunning > fox")
+
+        with self.assertRaises(AuthError) as cm:
+            yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        self.assertEqual(401, cm.exception.code)
+        self.assertIn("Invalid macaroon", cm.exception.msg)
+
+    @defer.inlineCallbacks
+    def test_get_user_from_macaroon_expired(self):
+        # TODO(danielwh): Remove this mock when we remove the
+        # get_user_by_access_token fallback.
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        self.store.get_user_by_access_token = Mock(
+            return_value={"name": "@baldrick:matrix.org"}
+        )
+
+        user = "@baldrick:matrix.org"
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = access")
+        macaroon.add_first_party_caveat("user_id = %s" % (user,))
+        macaroon.add_first_party_caveat("time < 1") # ms
+
+        self.hs.clock.now = 5000 # seconds
+
+        yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        # TODO(daniel): Turn on the check that we validate expiration, when we
+        # validate expiration (and remove the above line, which will start
+        # throwing).
+        # with self.assertRaises(AuthError) as cm:
+        #     yield self.auth._get_user_from_macaroon(macaroon.serialize())
+        # self.assertEqual(401, cm.exception.code)
+        # self.assertIn("Invalid macaroon", cm.exception.msg)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 91547bdd06..2ee3da0b34 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -76,7 +76,7 @@ class PresenceStateTestCase(unittest.TestCase):
                 "token_id": 1,
             }
 
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         room_member_handler = hs.handlers.room_member_handler = Mock(
             spec=[
@@ -169,7 +169,7 @@ class PresenceListTestCase(unittest.TestCase):
             ]
         )
 
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         presence.register_servlets(hs, self.mock_resource)
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 34ab47d02e..a2123be81b 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -59,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -239,7 +239,7 @@ class RoomPermissionsTestCase(RestTestCase):
                            "PUT", topic_path, topic_content)
         self.assertEquals(403, code, msg=str(response))
         (code, response) = yield self.mock_resource.trigger_get(topic_path)
-        self.assertEquals(403, code, msg=str(response))
+        self.assertEquals(200, code, msg=str(response))
 
         # get topic in PUBLIC room, not joined, expect 403
         (code, response) = yield self.mock_resource.trigger_get(
@@ -301,11 +301,11 @@ class RoomPermissionsTestCase(RestTestCase):
             room=room, expect_code=200)
 
         # get membership of self, get membership of other, private room + left
-        # expect all 403s
+        # expect all 200s
         yield self.leave(room=room, user=self.user_id)
         yield self._test_get_membership(
             members=[self.user_id, self.rmcreator_id],
-            room=room, expect_code=403)
+            room=room, expect_code=200)
 
     @defer.inlineCallbacks
     def test_membership_public_room_perms(self):
@@ -326,11 +326,11 @@ class RoomPermissionsTestCase(RestTestCase):
             room=room, expect_code=200)
 
         # get membership of self, get membership of other, public room + left
-        # expect 403.
+        # expect 200.
         yield self.leave(room=room, user=self.user_id)
         yield self._test_get_membership(
             members=[self.user_id, self.rmcreator_id],
-            room=room, expect_code=403)
+            room=room, expect_code=200)
 
     @defer.inlineCallbacks
     def test_invited_permissions(self):
@@ -444,7 +444,7 @@ class RoomsMemberListTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -492,9 +492,9 @@ class RoomsMemberListTestCase(RestTestCase):
         self.assertEquals(200, code, msg=str(response))
 
         yield self.leave(room=room_id, user=self.user_id)
-        # can no longer see list, you've left.
+        # can see old list once left
         (code, response) = yield self.mock_resource.trigger_get(room_path)
-        self.assertEquals(403, code, msg=str(response))
+        self.assertEquals(200, code, msg=str(response))
 
 
 class RoomsCreateTestCase(RestTestCase):
@@ -522,7 +522,7 @@ class RoomsCreateTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
                 "token_id": 1,
             }
 
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -718,7 +718,7 @@ class RoomMemberStateTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -938,7 +938,7 @@ class RoomInitialSyncTestCase(RestTestCase):
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 1c4519406d..6395ce79db 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -67,7 +67,7 @@ class RoomTypingTestCase(RestTestCase):
                 "token_id": 1,
             }
 
-        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index c472d53043..85096a0326 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase):
         self.mock_resource = None
         self.auth_user_id = None
 
-    def mock_get_user_by_access_token(self, token=None):
-        return self.auth_user_id
-
     @defer.inlineCallbacks
     def create_room_as(self, room_creator, is_public=True, tok=None):
         temp_id = self.auth_user_id
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index ef972a53aa..f45570a1c0 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -48,7 +48,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
                 "user": UserID.from_string(self.USER_ID),
                 "token_id": 1,
             }
-        hs.get_auth().get_user_by_access_token = _get_user_by_access_token
+        hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
 
         for r in self.TO_REGISTER:
             r.register_servlets(hs, self.mock_resource)
diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py
new file mode 100644
index 0000000000..42bd8928bd
--- /dev/null
+++ b/tests/storage/event_injector.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.types import UserID, RoomID
+
+from tests.utils import setup_test_homeserver
+
+from mock import Mock
+
+
+class EventInjector:
+    def __init__(self, hs):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.message_handler = hs.get_handlers().message_handler
+        self.event_builder_factory = hs.get_event_builder_factory()
+
+    @defer.inlineCallbacks
+    def create_room(self, room):
+        builder = self.event_builder_factory.new({
+            "type": EventTypes.Create,
+            "room_id": room.to_string(),
+            "content": {},
+        })
+
+        event, context = yield self.message_handler._create_new_client_event(
+            builder
+        )
+
+        yield self.store.persist_event(event, context)
+
+    @defer.inlineCallbacks
+    def inject_room_member(self, room, user, membership):
+        builder = self.event_builder_factory.new({
+            "type": EventTypes.Member,
+            "sender": user.to_string(),
+            "state_key": user.to_string(),
+            "room_id": room.to_string(),
+            "content": {"membership": membership},
+        })
+
+        event, context = yield self.message_handler._create_new_client_event(
+            builder
+        )
+
+        yield self.store.persist_event(event, context)
+
+        defer.returnValue(event)
+
+    @defer.inlineCallbacks
+    def inject_message(self, room, user, body):
+        builder = self.event_builder_factory.new({
+            "type": EventTypes.Message,
+            "sender": user.to_string(),
+            "state_key": user.to_string(),
+            "room_id": room.to_string(),
+            "content": {"body": body, "msgtype": u"message"},
+        })
+
+        event, context = yield self.message_handler._create_new_client_event(
+            builder
+        )
+
+        yield self.store.persist_event(event, context)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
new file mode 100644
index 0000000000..313013009e
--- /dev/null
+++ b/tests/storage/test_events.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import uuid
+from mock.mock import Mock
+from synapse.types import RoomID, UserID
+
+from tests import unittest
+from twisted.internet import defer
+from tests.storage.event_injector import EventInjector
+
+from tests.utils import setup_test_homeserver
+
+
+class EventsStoreTestCase(unittest.TestCase):
+
+    @defer.inlineCallbacks
+    def setUp(self):
+        self.hs = yield setup_test_homeserver(
+            resource_for_federation=Mock(),
+            http_client=None,
+        )
+        self.store = self.hs.get_datastore()
+        self.db_pool = self.hs.get_db_pool()
+        self.message_handler = self.hs.get_handlers().message_handler
+        self.event_injector = EventInjector(self.hs)
+
+    @defer.inlineCallbacks
+    def test_count_daily_messages(self):
+        self.db_pool.runQuery("DELETE FROM stats_reporting")
+
+        self.hs.clock.now = 100
+
+        # Never reported before, and nothing which could be reported
+        count = yield self.store.count_daily_messages()
+        self.assertIsNone(count)
+        count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting")
+        self.assertEqual([(0,)], count)
+
+        # Create something to report
+        room = RoomID.from_string("!abc123:test")
+        user = UserID.from_string("@raccoonlover:test")
+        yield self.event_injector.create_room(room)
+
+        self.base_event = yield self._get_last_stream_token()
+
+        yield self.event_injector.inject_message(room, user, "Raccoons are really cute")
+
+        # Never reported before, something could be reported, but isn't because
+        # it isn't old enough.
+        count = yield self.store.count_daily_messages()
+        self.assertIsNone(count)
+        self._assert_stats_reporting(1, self.hs.clock.now)
+
+        # Already reported yesterday, two new events from today.
+        yield self.event_injector.inject_message(room, user, "Yeah they are!")
+        yield self.event_injector.inject_message(room, user, "Incredibly!")
+        self.hs.clock.now += 60 * 60 * 24
+        count = yield self.store.count_daily_messages()
+        self.assertEqual(2, count)  # 2 since yesterday
+        self._assert_stats_reporting(3, self.hs.clock.now)  # 3 ever
+
+        # Last reported too recently.
+        yield self.event_injector.inject_message(room, user, "Who could disagree?")
+        self.hs.clock.now += 60 * 60 * 22
+        count = yield self.store.count_daily_messages()
+        self.assertIsNone(count)
+        self._assert_stats_reporting(4, self.hs.clock.now)
+
+        # Last reported too long ago
+        yield self.event_injector.inject_message(room, user, "No one.")
+        self.hs.clock.now += 60 * 60 * 26
+        count = yield self.store.count_daily_messages()
+        self.assertIsNone(count)
+        self._assert_stats_reporting(5, self.hs.clock.now)
+
+        # And now let's actually report something
+        yield self.event_injector.inject_message(room, user, "Indeed.")
+        yield self.event_injector.inject_message(room, user, "Indeed.")
+        yield self.event_injector.inject_message(room, user, "Indeed.")
+        # A little over 24 hours is fine :)
+        self.hs.clock.now += (60 * 60 * 24) + 50
+        count = yield self.store.count_daily_messages()
+        self.assertEqual(3, count)
+        self._assert_stats_reporting(8, self.hs.clock.now)
+
+    @defer.inlineCallbacks
+    def _get_last_stream_token(self):
+        rows = yield self.db_pool.runQuery(
+            "SELECT stream_ordering"
+            " FROM events"
+            " ORDER BY stream_ordering DESC"
+            " LIMIT 1"
+        )
+        if not rows:
+            defer.returnValue(0)
+        else:
+            defer.returnValue(rows[0][0])
+
+    @defer.inlineCallbacks
+    def _assert_stats_reporting(self, messages, time):
+        rows = yield self.db_pool.runQuery(
+            "SELECT reported_stream_token, reported_time FROM stats_reporting"
+        )
+        self.assertEqual([(self.base_event + messages, time,)], rows)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index ab7625a3ca..caffce64e3 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastore()
-        self.event_factory = hs.get_event_factory();
+        self.event_factory = hs.get_event_factory()
 
         self.room = RoomID.from_string("!abcde:test")
 
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 0c9b89d765..a658a789aa 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.types import UserID, RoomID
+from tests.storage.event_injector import EventInjector
 
 from tests.utils import setup_test_homeserver
 
@@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase):
 
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
+        self.event_injector = EventInjector(hs)
         self.handlers = hs.get_handlers()
         self.message_handler = self.handlers.message_handler
 
@@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase):
         self.room1 = RoomID.from_string("!abc123:test")
         self.room2 = RoomID.from_string("!xyx987:test")
 
-        self.depth = 1
-
-    @defer.inlineCallbacks
-    def inject_room_member(self, room, user, membership):
-        self.depth += 1
-
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Member,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"membership": membership},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
-
-        defer.returnValue(event)
-
-    @defer.inlineCallbacks
-    def inject_message(self, room, user, body):
-        self.depth += 1
-
-        builder = self.event_builder_factory.new({
-            "type": EventTypes.Message,
-            "sender": user.to_string(),
-            "state_key": user.to_string(),
-            "room_id": room.to_string(),
-            "content": {"body": body, "msgtype": u"message"},
-        })
-
-        event, context = yield self.message_handler._create_new_client_event(
-            builder
-        )
-
-        yield self.store.persist_event(event, context)
-
     @defer.inlineCallbacks
     def test_event_stream_get_other(self):
         # Both bob and alice joins the room
-        yield self.inject_room_member(
+        yield self.event_injector.inject_room_member(
             self.room1, self.u_alice, Membership.JOIN
         )
-        yield self.inject_room_member(
+        yield self.event_injector.inject_room_member(
             self.room1, self.u_bob, Membership.JOIN
         )
 
         # Initial stream key:
         start = yield self.store.get_room_events_max_id()
 
-        yield self.inject_message(self.room1, self.u_alice, u"test")
+        yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
 
         end = yield self.store.get_room_events_max_id()
 
@@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_event_stream_get_own(self):
         # Both bob and alice joins the room
-        yield self.inject_room_member(
+        yield self.event_injector.inject_room_member(
             self.room1, self.u_alice, Membership.JOIN
         )
-        yield self.inject_room_member(
+        yield self.event_injector.inject_room_member(
             self.room1, self.u_bob, Membership.JOIN
         )
 
         # Initial stream key:
         start = yield self.store.get_room_events_max_id()
 
-        yield self.inject_message(self.room1, self.u_alice, u"test")
+        yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
 
         end = yield self.store.get_room_events_max_id()
 
@@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_event_stream_join_leave(self):
         # Both bob and alice joins the room
-        yield self.inject_room_member(
+        yield self.event_injector.inject_room_member(
             self.room1, self.u_alice, Membership.JOIN
         )
-        yield self.inject_room_member(
+        yield self.event_injector.inject_room_member(
             self.room1, self.u_bob, Membership.JOIN
         )
 
         # Then bob leaves again.
-        yield self.inject_room_member(
+        yield self.event_injector.inject_room_member(
             self.room1, self.u_bob, Membership.LEAVE
         )
 
         # Initial stream key:
         start = yield self.store.get_room_events_max_id()
 
-        yield self.inject_message(self.room1, self.u_alice, u"test")
+        yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
 
         end = yield self.store.get_room_events_max_id()
 
@@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_event_stream_prev_content(self):
-        yield self.inject_room_member(
+        yield self.event_injector.inject_room_member(
             self.room1, self.u_bob, Membership.JOIN
         )
 
-        event1 = yield self.inject_room_member(
+        event1 = yield self.event_injector.inject_room_member(
             self.room1, self.u_alice, Membership.JOIN
         )
 
         start = yield self.store.get_room_events_max_id()
 
-        event2 = yield self.inject_room_member(
+        event2 = yield self.event_injector.inject_room_member(
             self.room1, self.u_alice, Membership.JOIN,
         )
 
diff --git a/tests/test_state.py b/tests/test_state.py
index 5845358754..55f37c521f 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -204,8 +204,8 @@ class StateTestCase(unittest.TestCase):
             nodes={
                 "START": DictObj(
                     type=EventTypes.Create,
-                    state_key="creator",
-                    content={"membership": "@user_id:example.com"},
+                    state_key="",
+                    content={"creator": "@user_id:example.com"},
                     depth=1,
                 ),
                 "A": DictObj(
@@ -259,8 +259,8 @@ class StateTestCase(unittest.TestCase):
             nodes={
                 "START": DictObj(
                     type=EventTypes.Create,
-                    state_key="creator",
-                    content={"membership": "@user_id:example.com"},
+                    state_key="",
+                    content={"creator": "@user_id:example.com"},
                     depth=1,
                 ),
                 "A": DictObj(
@@ -432,13 +432,19 @@ class StateTestCase(unittest.TestCase):
     def test_resolve_message_conflict(self):
         event = create_event(type="test_message", name="event")
 
+        creation = create_event(
+            type=EventTypes.Create, state_key=""
+        )
+
         old_state_1 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test1", state_key="2"),
             create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test3", state_key="2"),
             create_event(type="test4", state_key=""),
@@ -446,7 +452,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 5)
+        self.assertEqual(len(context.current_state), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -454,13 +460,19 @@ class StateTestCase(unittest.TestCase):
     def test_resolve_state_conflict(self):
         event = create_event(type="test4", state_key="", name="event")
 
+        creation = create_event(
+            type=EventTypes.Create, state_key=""
+        )
+
         old_state_1 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test1", state_key="2"),
             create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test3", state_key="2"),
             create_event(type="test4", state_key=""),
@@ -468,7 +480,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 5)
+        self.assertEqual(len(context.current_state), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -484,36 +496,45 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
+        creation = create_event(
+            type=EventTypes.Create, state_key="",
+            content={"creator": "@foo:bar"}
+        )
+
         old_state_1 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         old_state_2 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
 
         # Reverse the depth to make sure we are actually using the depths
         # during state resolution.
 
         old_state_1 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         old_state_2 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
 
     def _get_context(self, event, old_state_1, old_state_2):
         group_name_1 = "group_name_1"