summary refs log tree commit diff
path: root/tests/test_visibility.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_visibility.py')
-rw-r--r--tests/test_visibility.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 18f1a0035d..f7381b2885 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 import logging
 
+from mock import Mock
+
 from twisted.internet import defer
 from twisted.internet.defer import succeed
 
@@ -36,6 +38,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         self.event_creation_handler = self.hs.get_event_creation_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
         self.store = self.hs.get_datastore()
+        self.storage = self.hs.get_storage()
 
         yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
 
@@ -62,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             events_to_filter.append(evt)
 
         filtered = yield filter_events_for_server(
-            self.store, "test_server", events_to_filter
+            self.storage, "test_server", events_to_filter
         )
 
         # the result should be 5 redacted events, and 5 unredacted events.
@@ -100,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 
         # ... and the filtering happens.
         filtered = yield filter_events_for_server(
-            self.store, "test_server", events_to_filter
+            self.storage, "test_server", events_to_filter
         )
 
         for i in range(0, len(events_to_filter)):
@@ -137,7 +140,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
-        yield self.hs.get_datastore().persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
@@ -159,7 +162,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             builder
         )
 
-        yield self.hs.get_datastore().persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
@@ -180,7 +183,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             builder
         )
 
-        yield self.hs.get_datastore().persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
@@ -257,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 
         logger.info("Starting filtering")
         start = time.time()
+
+        storage = Mock()
+        storage.main = test_store
+        storage.state = test_store
+
         filtered = yield filter_events_for_server(
             test_store, "test_server", events_to_filter
         )