summary refs log tree commit diff
path: root/tests/test_state.py
diff options
context:
space:
mode:
authorAmber Brown <hawkowl@atleastfornow.net>2018-09-03 21:08:35 +1000
committerGitHub <noreply@github.com>2018-09-03 21:08:35 +1000
commit4fc4b881c58fd638db5f4dac0863721111b67af0 (patch)
treecc1604f5e3b4e0a263e0e11a55b62ef4006a64a1 /tests/test_state.py
parentThe project `matrix-synapse-auto-deploy` does not seem to be maintained anymore. (diff)
parentMerge pull request #3777 from matrix-org/neilj/fix_register_user_registration (diff)
downloadsynapse-4fc4b881c58fd638db5f4dac0863721111b67af0.tar.xz
Merge branch 'develop' into develop
Diffstat (limited to 'tests/test_state.py')
-rw-r--r--tests/test_state.py293
1 files changed, 137 insertions, 156 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index a5c5e55951..452a123c3a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -13,24 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from tests import unittest
+from mock import Mock
+
 from twisted.internet import defer
 
-from synapse.events import FrozenEvent
 from synapse.api.auth import Auth
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomVersions
+from synapse.events import FrozenEvent
 from synapse.state import StateHandler, StateResolutionHandler
 
-from .utils import MockClock
-
-from mock import Mock
+from tests import unittest
 
+from .utils import MockClock
 
 _next_event_id = 1000
 
 
-def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
-                 prev_events=[], **kwargs):
+def create_event(
+    name=None,
+    type=None,
+    state_key=None,
+    depth=2,
+    event_id=None,
+    prev_events=[],
+    **kwargs
+):
     global _next_event_id
 
     if not event_id:
@@ -39,9 +46,9 @@ def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
 
     if not name:
         if state_key is not None:
-            name = "<%s-%s, %s>" % (type, state_key, event_id,)
+            name = "<%s-%s, %s>" % (type, state_key, event_id)
         else:
-            name = "<%s, %s>" % (type, event_id,)
+            name = "<%s, %s>" % (type, event_id)
 
     d = {
         "event_id": event_id,
@@ -80,8 +87,9 @@ class StateGroupStore(object):
 
         return defer.succeed(groups)
 
-    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
-                          current_state_ids):
+    def store_state_group(
+        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+    ):
         state_group = self._next_group
         self._next_group += 1
 
@@ -91,7 +99,8 @@ class StateGroupStore(object):
 
     def get_events(self, event_ids, **kwargs):
         return {
-            e_id: self._event_id_to_event[e_id] for e_id in event_ids
+            e_id: self._event_id_to_event[e_id]
+            for e_id in event_ids
             if e_id in self._event_id_to_event
         }
 
@@ -108,6 +117,9 @@ class StateGroupStore(object):
     def register_event_id_state_group(self, event_id, state_group):
         self._event_to_state_group[event_id] = state_group
 
+    def get_room_version(self, room_id):
+        return RoomVersions.V1
+
 
 class DictObj(dict):
     def __init__(self, **kwargs):
@@ -129,9 +141,7 @@ class Graph(object):
                 prev_events = []
 
             events[event_id] = create_event(
-                event_id=event_id,
-                prev_events=prev_events,
-                **fields
+                event_id=event_id, prev_events=prev_events, **fields
             )
 
         self._leaves = clobbered
@@ -147,10 +157,15 @@ class Graph(object):
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = StateGroupStore()
-        hs = Mock(spec_set=[
-            "get_datastore", "get_auth", "get_state_handler", "get_clock",
-            "get_state_resolution_handler",
-        ])
+        hs = Mock(
+            spec_set=[
+                "get_datastore",
+                "get_auth",
+                "get_state_handler",
+                "get_clock",
+                "get_state_resolution_handler",
+            ]
+        )
         hs.get_datastore.return_value = self.store
         hs.get_state_handler.return_value = None
         hs.get_clock.return_value = MockClock()
@@ -165,34 +180,14 @@ class StateTestCase(unittest.TestCase):
         graph = Graph(
             nodes={
                 "START": DictObj(
-                    type=EventTypes.Create,
-                    state_key="",
-                    depth=1,
-                ),
-                "A": DictObj(
-                    type=EventTypes.Message,
-                    depth=2,
-                ),
-                "B": DictObj(
-                    type=EventTypes.Message,
-                    depth=3,
-                ),
-                "C": DictObj(
-                    type=EventTypes.Name,
-                    state_key="",
-                    depth=3,
-                ),
-                "D": DictObj(
-                    type=EventTypes.Message,
-                    depth=4,
+                    type=EventTypes.Create, state_key="", content={}, depth=1,
                 ),
+                "A": DictObj(type=EventTypes.Message, depth=2),
+                "B": DictObj(type=EventTypes.Message, depth=3),
+                "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
+                "D": DictObj(type=EventTypes.Message, depth=4),
             },
-            edges={
-                "A": ["START"],
-                "B": ["A"],
-                "C": ["A"],
-                "D": ["B", "C"]
-            }
+            edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )
 
         self.store.register_events(graph.walk())
@@ -204,7 +199,8 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        self.assertEqual(2, len(context_store["D"].prev_state_ids))
+        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        self.assertEqual(2, len(prev_state_ids))
 
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
@@ -223,27 +219,11 @@ class StateTestCase(unittest.TestCase):
                     membership=Membership.JOIN,
                     depth=2,
                 ),
-                "B": DictObj(
-                    type=EventTypes.Name,
-                    state_key="",
-                    depth=3,
-                ),
-                "C": DictObj(
-                    type=EventTypes.Name,
-                    state_key="",
-                    depth=4,
-                ),
-                "D": DictObj(
-                    type=EventTypes.Message,
-                    depth=5,
-                ),
+                "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
+                "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
+                "D": DictObj(type=EventTypes.Message, depth=5),
             },
-            edges={
-                "A": ["START"],
-                "B": ["A"],
-                "C": ["A"],
-                "D": ["B", "C"]
-            }
+            edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )
 
         self.store.register_events(graph.walk())
@@ -255,9 +235,10 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
+        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
         self.assertSetEqual(
-            {"START", "A", "C"},
-            {e_id for e_id in context_store["D"].prev_state_ids.values()}
+            {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -277,11 +258,7 @@ class StateTestCase(unittest.TestCase):
                     membership=Membership.JOIN,
                     depth=2,
                 ),
-                "B": DictObj(
-                    type=EventTypes.Name,
-                    state_key="",
-                    depth=3,
-                ),
+                "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
                 "C": DictObj(
                     type=EventTypes.Member,
                     state_key="@user_id_2:example.com",
@@ -295,18 +272,9 @@ class StateTestCase(unittest.TestCase):
                     depth=4,
                     sender="@user_id_2:example.com",
                 ),
-                "E": DictObj(
-                    type=EventTypes.Message,
-                    depth=5,
-                ),
+                "E": DictObj(type=EventTypes.Message, depth=5),
             },
-            edges={
-                "A": ["START"],
-                "B": ["A"],
-                "C": ["B"],
-                "D": ["B"],
-                "E": ["C", "D"]
-            }
+            edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
         )
 
         self.store.register_events(graph.walk())
@@ -318,9 +286,10 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
+        prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+
         self.assertSetEqual(
-            {"START", "A", "B", "C"},
-            {e for e in context_store["E"].prev_state_ids.values()}
+            {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -352,30 +321,17 @@ class StateTestCase(unittest.TestCase):
                 state_key="",
                 content={
                     "events": {"m.room.name": 50},
-                    "users": {userid1: 100,
-                              userid2: 60},
+                    "users": {userid1: 100, userid2: 60},
                 },
             ),
-            "A5": DictObj(
-                type=EventTypes.Name,
-                state_key="",
-            ),
+            "A5": DictObj(type=EventTypes.Name, state_key=""),
             "B": DictObj(
                 type=EventTypes.PowerLevels,
                 state_key="",
-                content={
-                    "events": {"m.room.name": 50},
-                    "users": {userid2: 30},
-                },
-            ),
-            "C": DictObj(
-                type=EventTypes.Name,
-                state_key="",
-                sender=userid2,
-            ),
-            "D": DictObj(
-                type=EventTypes.Message,
+                content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
             ),
+            "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
+            "D": DictObj(type=EventTypes.Message),
         }
         edges = {
             "A2": ["A1"],
@@ -384,7 +340,7 @@ class StateTestCase(unittest.TestCase):
             "A5": ["A4"],
             "B": ["A5"],
             "C": ["A5"],
-            "D": ["B", "C"]
+            "D": ["B", "C"],
         }
         self._add_depths(nodes, edges)
         graph = Graph(nodes, edges)
@@ -398,9 +354,10 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
+        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
         self.assertSetEqual(
-            {"A1", "A2", "A3", "A5", "B"},
-            {e for e in context_store["D"].prev_state_ids.values()}
+            {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
         )
 
     def _add_depths(self, nodes, edges):
@@ -425,12 +382,12 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(
-            event, old_state=old_state
-        )
+        context = yield self.state.compute_event_context(event, old_state=old_state)
+
+        current_state_ids = yield context.get_current_state_ids(self.store)
 
         self.assertEqual(
-            set(e.event_id for e in old_state), set(context.current_state_ids.values())
+            set(e.event_id for e in old_state), set(current_state_ids.values())
         )
 
         self.assertIsNotNone(context.state_group)
@@ -445,20 +402,19 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(
-            event, old_state=old_state
-        )
+        context = yield self.state.compute_event_context(event, old_state=old_state)
+
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
 
         self.assertEqual(
-            set(e.event_id for e in old_state), set(context.prev_state_ids.values())
+            set(e.event_id for e in old_state), set(prev_state_ids.values())
         )
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
         prev_event_id = "prev_event_id"
         event = create_event(
-            type="test_message", name="event2",
-            prev_events=[(prev_event_id, {})],
+            type="test_message", name="event2", prev_events=[(prev_event_id, {})]
         )
 
         old_state = [
@@ -468,16 +424,20 @@ class StateTestCase(unittest.TestCase):
         ]
 
         group_name = self.store.store_state_group(
-            prev_event_id, event.room_id, None, None,
+            prev_event_id,
+            event.room_id,
+            None,
+            None,
             {(e.type, e.state_key): e.event_id for e in old_state},
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield self.state.compute_event_context(event)
 
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
         self.assertEqual(
-            set([e.event_id for e in old_state]),
-            set(context.current_state_ids.values())
+            set([e.event_id for e in old_state]), set(current_state_ids.values())
         )
 
         self.assertEqual(group_name, context.state_group)
@@ -486,8 +446,7 @@ class StateTestCase(unittest.TestCase):
     def test_trivial_annotate_state(self):
         prev_event_id = "prev_event_id"
         event = create_event(
-            type="state", state_key="", name="event2",
-            prev_events=[(prev_event_id, {})],
+            type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
         )
 
         old_state = [
@@ -497,16 +456,20 @@ class StateTestCase(unittest.TestCase):
         ]
 
         group_name = self.store.store_state_group(
-            prev_event_id, event.room_id, None, None,
+            prev_event_id,
+            event.room_id,
+            None,
+            None,
             {(e.type, e.state_key): e.event_id for e in old_state},
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield self.state.compute_event_context(event)
 
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+
         self.assertEqual(
-            set([e.event_id for e in old_state]),
-            set(context.prev_state_ids.values())
+            set([e.event_id for e in old_state]), set(prev_state_ids.values())
         )
 
         self.assertIsNotNone(context.state_group)
@@ -516,13 +479,12 @@ class StateTestCase(unittest.TestCase):
         prev_event_id1 = "event_id1"
         prev_event_id2 = "event_id2"
         event = create_event(
-            type="test_message", name="event3",
+            type="test_message",
+            name="event3",
             prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
         )
 
-        creation = create_event(
-            type=EventTypes.Create, state_key=""
-        )
+        creation = create_event(type=EventTypes.Create, state_key="")
 
         old_state_1 = [
             creation,
@@ -542,10 +504,12 @@ class StateTestCase(unittest.TestCase):
         self.store.register_events(old_state_2)
 
         context = yield self._get_context(
-            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        self.assertEqual(len(context.current_state_ids), 6)
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
+        self.assertEqual(len(current_state_ids), 6)
 
         self.assertIsNotNone(context.state_group)
 
@@ -554,13 +518,13 @@ class StateTestCase(unittest.TestCase):
         prev_event_id1 = "event_id1"
         prev_event_id2 = "event_id2"
         event = create_event(
-            type="test4", state_key="", name="event",
+            type="test4",
+            state_key="",
+            name="event",
             prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
         )
 
-        creation = create_event(
-            type=EventTypes.Create, state_key=""
-        )
+        creation = create_event(type=EventTypes.Create, state_key="")
 
         old_state_1 = [
             creation,
@@ -582,10 +546,12 @@ class StateTestCase(unittest.TestCase):
         self.store.get_events = store.get_events
 
         context = yield self._get_context(
-            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        self.assertEqual(len(context.current_state_ids), 6)
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
+        self.assertEqual(len(current_state_ids), 6)
 
         self.assertIsNotNone(context.state_group)
 
@@ -594,31 +560,37 @@ class StateTestCase(unittest.TestCase):
         prev_event_id1 = "event_id1"
         prev_event_id2 = "event_id2"
         event = create_event(
-            type="test4", name="event",
+            type="test4",
+            name="event",
             prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
         )
 
         member_event = create_event(
             type=EventTypes.Member,
             state_key="@user_id:example.com",
-            content={
-                "membership": Membership.JOIN,
-            }
+            content={"membership": Membership.JOIN},
+        )
+
+        power_levels = create_event(
+            type=EventTypes.PowerLevels,
+            state_key="",
+            content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
         )
 
         creation = create_event(
-            type=EventTypes.Create, state_key="",
-            content={"creator": "@foo:bar"}
+            type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
         )
 
         old_state_1 = [
             creation,
+            power_levels,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         old_state_2 = [
             creation,
+            power_levels,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
@@ -629,24 +601,26 @@ class StateTestCase(unittest.TestCase):
         self.store.get_events = store.get_events
 
         context = yield self._get_context(
-            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        self.assertEqual(
-            old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
-        )
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
+        self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
 
         # Reverse the depth to make sure we are actually using the depths
         # during state resolution.
 
         old_state_1 = [
             creation,
+            power_levels,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         old_state_2 = [
             creation,
+            power_levels,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
@@ -655,23 +629,30 @@ class StateTestCase(unittest.TestCase):
         store.register_events(old_state_2)
 
         context = yield self._get_context(
-            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        self.assertEqual(
-            old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
-        )
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
+        self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
 
-    def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
-                     old_state_2):
+    def _get_context(
+        self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
+    ):
         sg1 = self.store.store_state_group(
-            prev_event_id_1, event.room_id, None, None,
+            prev_event_id_1,
+            event.room_id,
+            None,
+            None,
             {(e.type, e.state_key): e.event_id for e in old_state_1},
         )
         self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
         sg2 = self.store.store_state_group(
-            prev_event_id_2, event.room_id, None, None,
+            prev_event_id_2,
+            event.room_id,
+            None,
+            None,
             {(e.type, e.state_key): e.event_id for e in old_state_2},
         )
         self.store.register_event_id_state_group(prev_event_id_2, sg2)