summary refs log tree commit diff
path: root/tests/test_visibility.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_visibility.py46
1 files changed, 33 insertions, 13 deletions
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 7a9b01ef9d..f338af6c36 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         super(FilterEventsForServerTestCase, self).setUp()
         self.event_creation_handler = self.hs.get_event_creation_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
-        self.storage = self.hs.get_storage()
+        self._storage_controllers = self.hs.get_storage_controllers()
 
         self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
 
@@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             events_to_filter.append(evt)
 
         filtered = self.get_success(
-            filter_events_for_server(self.storage, "test_server", events_to_filter)
+            filter_events_for_server(
+                self._storage_controllers, "test_server", events_to_filter
+            )
         )
 
         # the result should be 5 redacted events, and 5 unredacted events.
@@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         outlier = self._inject_outlier()
         self.assertEqual(
             self.get_success(
-                filter_events_for_server(self.storage, "remote_hs", [outlier])
+                filter_events_for_server(
+                    self._storage_controllers, "remote_hs", [outlier]
+                )
             ),
             [outlier],
         )
@@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         evt = self._inject_message("@unerased:local_hs")
 
         filtered = self.get_success(
-            filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+            filter_events_for_server(
+                self._storage_controllers, "remote_hs", [outlier, evt]
+            )
         )
         self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
         self.assertEqual(filtered[0], outlier)
@@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # ... but other servers should only be able to see the outlier (the other should
         # be redacted)
         filtered = self.get_success(
-            filter_events_for_server(self.storage, "other_server", [outlier, evt])
+            filter_events_for_server(
+                self._storage_controllers, "other_server", [outlier, evt]
+            )
         )
         self.assertEqual(filtered[0], outlier)
         self.assertEqual(filtered[1].event_id, evt.event_id)
@@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         # ... and the filtering happens.
         filtered = self.get_success(
-            filter_events_for_server(self.storage, "test_server", events_to_filter)
+            filter_events_for_server(
+                self._storage_controllers, "test_server", events_to_filter
+            )
         )
 
         for i in range(0, len(events_to_filter)):
@@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         event, context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
-        self.get_success(self.storage.persistence.persist_event(event, context))
+        self.get_success(
+            self._storage_controllers.persistence.persist_event(event, context)
+        )
         return event
 
     def _inject_room_member(
@@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.storage.persistence.persist_event(event, context))
+        self.get_success(
+            self._storage_controllers.persistence.persist_event(event, context)
+        )
         return event
 
     def _inject_message(
@@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.storage.persistence.persist_event(event, context))
+        self.get_success(
+            self._storage_controllers.persistence.persist_event(event, context)
+        )
         return event
 
     def _inject_outlier(self) -> EventBase:
@@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
         event.internal_metadata.outlier = True
         self.get_success(
-            self.storage.persistence.persist_event(
-                event, EventContext.for_outlier(self.storage)
+            self._storage_controllers.persistence.persist_event(
+                event, EventContext.for_outlier(self._storage_controllers)
             )
         )
         return event
@@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
         self.assertEqual(
             self.get_success(
                 filter_events_for_client(
-                    self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+                    self.hs.get_storage_controllers(),
+                    "@user:test",
+                    [invite_event, reject_event],
                 )
             ),
             [invite_event, reject_event],
@@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
         self.assertEqual(
             self.get_success(
                 filter_events_for_client(
-                    self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+                    self.hs.get_storage_controllers(),
+                    "@other:test",
+                    [invite_event, reject_event],
                 )
             ),
             [],