summary refs log tree commit diff
path: root/tests/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_state.py')
-rw-r--r--tests/test_state.py59
1 files changed, 32 insertions, 27 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index 90800421fb..e4baa69137 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import Collection, Dict, List, Optional
 from unittest.mock import Mock
 
 from twisted.internet import defer
@@ -70,7 +70,7 @@ def create_event(
     return event
 
 
-class StateGroupStore:
+class _DummyStore:
     def __init__(self):
         self._event_to_state_group = {}
         self._group_to_state = {}
@@ -105,6 +105,11 @@ class StateGroupStore:
             if e_id in self._event_id_to_event
         }
 
+    async def get_partial_state_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, bool]:
+        return {e: False for e in event_ids}
+
     async def get_state_group_delta(self, name):
         return None, None
 
@@ -157,8 +162,8 @@ class Graph:
 
 class StateTestCase(unittest.TestCase):
     def setUp(self):
-        self.store = StateGroupStore()
-        storage = Mock(main=self.store, state=self.store)
+        self.dummy_store = _DummyStore()
+        storage = Mock(main=self.dummy_store, state=self.dummy_store)
         hs = Mock(
             spec_set=[
                 "config",
@@ -173,7 +178,7 @@ class StateTestCase(unittest.TestCase):
             ]
         )
         hs.config = default_config("tesths", True)
-        hs.get_datastores.return_value = Mock(main=self.store)
+        hs.get_datastores.return_value = Mock(main=self.dummy_store)
         hs.get_state_handler.return_value = None
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
@@ -198,7 +203,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store: dict[str, EventContext] = {}
 
@@ -206,7 +211,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         ctx_c = context_store["C"]
@@ -242,7 +247,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store = {}
 
@@ -250,7 +255,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         # C ends up winning the resolution between B and C
@@ -300,7 +305,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
         )
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store = {}
 
@@ -308,7 +313,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         # C ends up winning the resolution between C and D because bans win over other
@@ -375,7 +380,7 @@ class StateTestCase(unittest.TestCase):
         self._add_depths(nodes, edges)
         graph = Graph(nodes, edges)
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store = {}
 
@@ -383,7 +388,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         # B ends up winning the resolution between B and C because power levels
@@ -476,7 +481,7 @@ class StateTestCase(unittest.TestCase):
         ]
 
         group_name = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id,
                 event.room_id,
                 None,
@@ -484,7 +489,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id, group_name)
+        self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
@@ -510,7 +515,7 @@ class StateTestCase(unittest.TestCase):
         ]
 
         group_name = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id,
                 event.room_id,
                 None,
@@ -518,7 +523,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id, group_name)
+        self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
@@ -554,8 +559,8 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
-        self.store.register_events(old_state_1)
-        self.store.register_events(old_state_2)
+        self.dummy_store.register_events(old_state_1)
+        self.dummy_store.register_events(old_state_2)
 
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -594,10 +599,10 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
-        store = StateGroupStore()
+        store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events
 
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -649,10 +654,10 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=2),
         ]
 
-        store = StateGroupStore()
+        store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events
 
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -695,7 +700,7 @@ class StateTestCase(unittest.TestCase):
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
         sg1 = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id_1,
                 event.room_id,
                 None,
@@ -703,10 +708,10 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state_1},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id_1, sg1)
+        self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1)
 
         sg2 = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id_2,
                 event.room_id,
                 None,
@@ -714,7 +719,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state_2},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id_2, sg2)
+        self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2)
 
         result = yield defer.ensureDeferred(self.state.compute_event_context(event))
         return result