diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 6a180ddc32..f7381b2885 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -14,6 +14,8 @@
# limitations under the License.
import logging
+from mock import Mock
+
from twisted.internet import defer
from twisted.internet.defer import succeed
@@ -36,6 +38,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
+ self.storage = self.hs.get_storage()
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
@@ -62,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -74,7 +77,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["a"], "b")
- @tests.unittest.DEBUG
@defer.inlineCallbacks
def test_erased_user(self):
# 4 message events, from erased and unerased users, with a membership
@@ -101,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# ... and the filtering happens.
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
for i in range(0, len(events_to_filter)):
@@ -138,8 +140,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
- defer.returnValue(event)
+ yield self.storage.persistence.persist_event(event, context)
+ return event
@defer.inlineCallbacks
def inject_room_member(self, user_id, membership="join", extra_content={}):
@@ -160,8 +162,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
- defer.returnValue(event)
+ yield self.storage.persistence.persist_event(event, context)
+ return event
@defer.inlineCallbacks
def inject_message(self, user_id, content=None):
@@ -181,8 +183,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
- defer.returnValue(event)
+ yield self.storage.persistence.persist_event(event, context)
+ return event
@defer.inlineCallbacks
def test_large_room(self):
@@ -258,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
logger.info("Starting filtering")
start = time.time()
+
+ storage = Mock()
+ storage.main = test_store
+ storage.state = test_store
+
filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter
)
@@ -265,7 +272,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
pr.disable()
with open("filter_events_for_server.profile", "w+") as f:
- ps = pstats.Stats(pr, stream=f).sort_stats('cumulative')
+ ps = pstats.Stats(pr, stream=f).sort_stats("cumulative")
ps.print_stats()
# the result should be 5 redacted events, and 5 unredacted events.
|