From ae6ad4cf4190dbd2dadd72fc8131e4c139f8eb4a Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Thu, 27 Sep 2018 11:22:25 +0100
Subject: docstrings and unittests for storage.state (#3958)

I spent ages trying to figure out how I was going mad...
---
 tests/storage/test_state.py | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

(limited to 'tests')

diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index b910965932..b9c5b39d59 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -74,6 +74,45 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.assertEqual(s1[t].event_id, s2[t].event_id)
         self.assertEqual(len(s1), len(s2))
 
+    @defer.inlineCallbacks
+    def test_get_state_groups_ids(self):
+        e1 = yield self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Create, '', {}
+        )
+        e2 = yield self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
+        )
+
+        state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id])
+        self.assertEqual(len(state_group_map), 1)
+        state_map = list(state_group_map.values())[0]
+        self.assertDictEqual(
+            state_map,
+            {
+                (EventTypes.Create, ''): e1.event_id,
+                (EventTypes.Name, ''): e2.event_id,
+            },
+        )
+
+    @defer.inlineCallbacks
+    def test_get_state_groups(self):
+        e1 = yield self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Create, '', {}
+        )
+        e2 = yield self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
+        )
+
+        state_group_map = yield self.store.get_state_groups(
+            self.room, [e2.event_id])
+        self.assertEqual(len(state_group_map), 1)
+        state_list = list(state_group_map.values())[0]
+
+        self.assertEqual(
+            {ev.event_id for ev in state_list},
+            {e1.event_id, e2.event_id},
+        )
+
     @defer.inlineCallbacks
     def test_get_state_for_event(self):
 
-- 
cgit 1.5.1