summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erikj@jki.re>2018-07-23 15:14:39 +0100
committerGitHub <noreply@github.com>2018-07-23 15:14:39 +0100
commit9c294ea864e630d9accc1009be8817e94d1be4d8 (patch)
treee53ae887887d0e1aecea4a867714ae57144c1929 /tests
parentMerge pull request #3577 from matrix-org/erikj/cleanup_context (diff)
parentUpdate docstrings to make sense (diff)
downloadsynapse-9c294ea864e630d9accc1009be8817e94d1be4d8.tar.xz
Merge pull request #3579 from matrix-org/erikj/stateless_contexts_4
Add concept of StatelessContext, take 4.
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/slave/storage/test_events.py8
-rw-r--r--tests/test_state.py47
2 files changed, 40 insertions, 15 deletions
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index cea01d93eb..f5b47f5ec0 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -222,9 +222,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             state_ids = {
                 key: e.event_id for key, e in state.items()
             }
-            context = EventContext()
-            context.current_state_ids = state_ids
-            context.prev_state_ids = state_ids
+            context = EventContext.with_state(
+                state_group=None,
+                current_state_ids=state_ids,
+                prev_state_ids=state_ids
+            )
         else:
             state_handler = self.hs.get_state_handler()
             context = yield state_handler.compute_event_context(event)
diff --git a/tests/test_state.py b/tests/test_state.py
index c0f2d1152d..429a18cbf7 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        self.assertEqual(2, len(context_store["D"].prev_state_ids))
+        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        self.assertEqual(2, len(prev_state_ids))
 
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
@@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
+        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
         self.assertSetEqual(
             {"START", "A", "C"},
-            {e_id for e_id in context_store["D"].prev_state_ids.values()}
+            {e_id for e_id in prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
+        prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+
         self.assertSetEqual(
             {"START", "A", "B", "C"},
-            {e for e in context_store["E"].prev_state_ids.values()}
+            {e for e in prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
+        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"},
-            {e for e in context_store["D"].prev_state_ids.values()}
+            {e for e in prev_state_ids.values()}
         )
 
     def _add_depths(self, nodes, edges):
@@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
         self.assertEqual(
-            set(e.event_id for e in old_state), set(context.current_state_ids.values())
+            set(e.event_id for e in old_state), set(current_state_ids.values())
         )
 
         self.assertIsNotNone(context.state_group)
@@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+
         self.assertEqual(
-            set(e.event_id for e in old_state), set(context.prev_state_ids.values())
+            set(e.event_id for e in old_state), set(prev_state_ids.values())
         )
 
     @defer.inlineCallbacks
@@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set(context.current_state_ids.values())
+            set(current_state_ids.values())
         )
 
         self.assertEqual(group_name, context.state_group)
@@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set(context.prev_state_ids.values())
+            set(prev_state_ids.values())
         )
 
         self.assertIsNotNone(context.state_group)
@@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
         )
 
-        self.assertEqual(len(context.current_state_ids), 6)
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
+        self.assertEqual(len(current_state_ids), 6)
 
         self.assertIsNotNone(context.state_group)
 
@@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
         )
 
-        self.assertEqual(len(context.current_state_ids), 6)
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
+        self.assertEqual(len(current_state_ids), 6)
 
         self.assertIsNotNone(context.state_group)
 
@@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
         )
 
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
         self.assertEqual(
-            old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
+            old_state_2[3].event_id, current_state_ids[("test1", "1")]
         )
 
         # Reverse the depth to make sure we are actually using the depths
@@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
         )
 
+        current_state_ids = yield context.get_current_state_ids(self.store)
+
         self.assertEqual(
-            old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
+            old_state_1[3].event_id, current_state_ids[("test1", "1")]
         )
 
     def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,