summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/federation.py3
-rw-r--r--synapse/handlers/message.py3
-rw-r--r--synapse/server.py9
-rw-r--r--tests/replication/slave/storage/_base.py1
-rw-r--r--tests/replication/slave/storage/test_events.py10
-rw-r--r--tests/storage/test_redaction.py11
-rw-r--r--tests/storage/test_room.py3
-rw-r--r--tests/storage/test_roommember.py3
-rw-r--r--tests/storage/test_state.py3
-rw-r--r--tests/test_visibility.py7
-rw-r--r--tests/utils.py10
11 files changed, 44 insertions, 19 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4b4c6c15f9..7e9c9aee13 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -109,6 +109,7 @@ class FederationHandler(BaseHandler):
         self.hs = hs
 
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -2648,7 +2649,7 @@ class FederationHandler(BaseHandler):
                 backfilled=backfilled,
             )
         else:
-            max_stream_id = yield self.store.persist_events(
+            max_stream_id = yield self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0f8cce8ffe..7908a2d52c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -234,6 +234,7 @@ class EventCreationHandler(object):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
@@ -868,7 +869,7 @@ class EventCreationHandler(object):
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
-        (event_stream_id, max_stream_id) = yield self.store.persist_event(
+        event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
             event, context=context
         )
 
diff --git a/synapse/server.py b/synapse/server.py
index 1fcc7375d3..54a7f4aa5f 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -95,6 +95,7 @@ from synapse.server_notices.worker_server_notices_sender import (
     WorkerServerNoticesSender,
 )
 from synapse.state import StateHandler, StateResolutionHandler
+from synapse.storage import DataStores, Storage
 from synapse.streams.events import EventSources
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
@@ -196,6 +197,7 @@ class HomeServer(object):
         "account_validity_handler",
         "saml_handler",
         "event_client_serializer",
+        "storage",
     ]
 
     REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -225,6 +227,7 @@ class HomeServer(object):
         self.registration_ratelimiter = Ratelimiter()
 
         self.datastore = None
+        self.datastores = None
 
         # Other kwargs are explicit dependencies
         for depname in kwargs:
@@ -234,6 +237,7 @@ class HomeServer(object):
         logger.info("Setting up.")
         with self.get_db_conn() as conn:
             self.datastore = self.DATASTORE_CLASS(conn, self)
+            self.datastores = DataStores(self.datastore, conn, self)
             conn.commit()
         logger.info("Finished setting up.")
 
@@ -266,7 +270,7 @@ class HomeServer(object):
         return self.clock
 
     def get_datastore(self):
-        return self.datastore
+        return self.datastores.main
 
     def get_config(self):
         return self.config
@@ -537,6 +541,9 @@ class HomeServer(object):
     def build_event_client_serializer(self):
         return EventClientSerializer(self)
 
+    def build_storage(self) -> Storage:
+        return Storage(self, self.datastores)
+
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 104349cdbd..4f924ce451 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -41,6 +41,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
 
         self.master_store = self.hs.get_datastore()
+        self.storage = hs.get_storage()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index a368117b43..b68e9fe082 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -234,7 +234,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
         )
         msg, msgctx = self.build_event()
-        self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)]))
+        self.get_success(
+            self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+        )
         self.replicate()
 
         event_source = RoomEventSource(self.hs)
@@ -290,10 +292,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         if backfill:
             self.get_success(
-                self.master_store.persist_events([(event, context)], backfilled=True)
+                self.storage.persistence.persist_events(
+                    [(event, context)], backfilled=True
+                )
             )
         else:
-            self.get_success(self.master_store.persist_event(event, context))
+            self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 427d3c49c5..4561c3e383 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -116,7 +117,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -263,7 +264,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.get_success(self.store.persist_event(event_1, context_1))
+        self.get_success(self.storage.persistence.persist_event(event_1, context_1))
 
         event_2, context_2 = self.get_success(
             self.event_creation_handler.create_new_client_event(
@@ -282,7 +283,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 )
             )
         )
-        self.get_success(self.store.persist_event(event_2, context_2))
+        self.get_success(self.storage.persistence.persist_event(event_2, context_2))
 
         # fetch one of the redactions
         fetched = self.get_success(self.store.get_event(redaction_event_id1))
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1bee45706f..3ddaa151fe 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -62,6 +62,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.event_factory = hs.get_event_factory()
 
         self.room = RoomID.from_string("!abcde:test")
@@ -72,7 +73,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
-        yield self.store.persist_event(
+        yield self.storage.persistence.persist_event(
             self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
         )
 
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 447a3c6ffb..9ddd17f73d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -44,6 +44,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # We can't test the RoomMemberStore on its own without the other event
         # storage logic
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -70,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5c2cf3c2db..d573a3e07b 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -34,6 +34,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
 
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -63,7 +64,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             builder
         )
 
-        yield self.store.persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
 
         return event
 
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 18f1a0035d..6ae1ea9b04 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -36,6 +36,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         self.event_creation_handler = self.hs.get_event_creation_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
         self.store = self.hs.get_datastore()
+        self.storage = self.hs.get_storage()
 
         yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
 
@@ -137,7 +138,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
-        yield self.hs.get_datastore().persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
@@ -159,7 +160,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             builder
         )
 
-        yield self.hs.get_datastore().persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
@@ -180,7 +181,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             builder
         )
 
-        yield self.hs.get_datastore().persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
diff --git a/tests/utils.py b/tests/utils.py
index 0a64f75d04..2808eea933 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -326,10 +326,16 @@ def setup_test_homeserver(
         if homeserverToUse.__name__ == "TestHomeServer":
             hs.setup_master()
     else:
+        # If we have been given an explicit datastore we probably want to mock
+        # out the DataStores somehow too. This all feels a bit wrong, but then
+        # mocking the stores feels wrong too.
+        datastores = Mock(datastore=datastore)
+
         hs = homeserverToUse(
             name,
             db_pool=None,
             datastore=datastore,
+            datastores=datastores,
             config=config,
             version_string="Synapse/tests",
             database_engine=db_engine,
@@ -647,7 +653,7 @@ def create_room(hs, room_id, creator_id):
         creator_id (str)
     """
 
-    store = hs.get_datastore()
+    persistence_store = hs.get_storage().persistence
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()
 
@@ -664,4 +670,4 @@ def create_room(hs, room_id, creator_id):
 
     event, context = yield event_creation_handler.create_new_client_event(builder)
 
-    yield store.persist_event(event, context)
+    yield persistence_store.persist_event(event, context)