summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_state.py37
-rw-r--r--tests/utils.py13
2 files changed, 42 insertions, 8 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index 5845358754..55f37c521f 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -204,8 +204,8 @@ class StateTestCase(unittest.TestCase):
             nodes={
                 "START": DictObj(
                     type=EventTypes.Create,
-                    state_key="creator",
-                    content={"membership": "@user_id:example.com"},
+                    state_key="",
+                    content={"creator": "@user_id:example.com"},
                     depth=1,
                 ),
                 "A": DictObj(
@@ -259,8 +259,8 @@ class StateTestCase(unittest.TestCase):
             nodes={
                 "START": DictObj(
                     type=EventTypes.Create,
-                    state_key="creator",
-                    content={"membership": "@user_id:example.com"},
+                    state_key="",
+                    content={"creator": "@user_id:example.com"},
                     depth=1,
                 ),
                 "A": DictObj(
@@ -432,13 +432,19 @@ class StateTestCase(unittest.TestCase):
     def test_resolve_message_conflict(self):
         event = create_event(type="test_message", name="event")
 
+        creation = create_event(
+            type=EventTypes.Create, state_key=""
+        )
+
         old_state_1 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test1", state_key="2"),
             create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test3", state_key="2"),
             create_event(type="test4", state_key=""),
@@ -446,7 +452,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 5)
+        self.assertEqual(len(context.current_state), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -454,13 +460,19 @@ class StateTestCase(unittest.TestCase):
     def test_resolve_state_conflict(self):
         event = create_event(type="test4", state_key="", name="event")
 
+        creation = create_event(
+            type=EventTypes.Create, state_key=""
+        )
+
         old_state_1 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test1", state_key="2"),
             create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test3", state_key="2"),
             create_event(type="test4", state_key=""),
@@ -468,7 +480,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 5)
+        self.assertEqual(len(context.current_state), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -484,36 +496,45 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
+        creation = create_event(
+            type=EventTypes.Create, state_key="",
+            content={"creator": "@foo:bar"}
+        )
+
         old_state_1 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         old_state_2 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
 
         # Reverse the depth to make sure we are actually using the depths
         # during state resolution.
 
         old_state_1 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         old_state_2 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
 
     def _get_context(self, event, old_state_1, old_state_2):
         group_name_1 = "group_name_1"
diff --git a/tests/utils.py b/tests/utils.py
index 3766a994f2..dd19a16fc7 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -27,6 +27,7 @@ from twisted.enterprise.adbapi import ConnectionPool
 
 from collections import namedtuple
 from mock import patch, Mock
+import hashlib
 import urllib
 import urlparse
 
@@ -67,6 +68,18 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
             **kargs
         )
 
+    # bcrypt is far too slow to be doing in unit tests
+    def swap_out_hash_for_testing(old_build_handlers):
+        def build_handlers():
+            handlers = old_build_handlers()
+            auth_handler = handlers.auth_handler
+            auth_handler.hash = lambda p: hashlib.md5(p).hexdigest()
+            auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
+            return handlers
+        return build_handlers
+
+    hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
+
     defer.returnValue(hs)