summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/slave/storage/_base.py37
1 files changed, 30 insertions, 7 deletions
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 8708c8a196..a103e7be80 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -11,23 +11,44 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import tempfile
 
 from mock import Mock, NonCallableMock
 
 from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
 
 from synapse.replication.tcp.client import (
     ReplicationClientFactory,
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
 
 
+class TestReplicationClientHandler(ReplicationClientHandler):
+    """Overrides on_rdata so that we can wait for it to happen"""
+    def __init__(self, store):
+        super(TestReplicationClientHandler, self).__init__(store)
+        self._rdata_awaiters = []
+
+    def await_replication(self):
+        d = Deferred()
+        self._rdata_awaiters.append(d)
+        return make_deferred_yieldable(d)
+
+    def on_rdata(self, stream_name, token, rows):
+        awaiters = self._rdata_awaiters
+        self._rdata_awaiters = []
+        super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
+        with PreserveLoggingContext():
+            for a in awaiters:
+                a.callback(None)
+
+
 class BaseSlavedStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
@@ -52,7 +73,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.addCleanup(listener.stopListening)
         self.streamer = server_factory.streamer
 
-        self.replication_handler = ReplicationClientHandler(self.slaved_store)
+        self.replication_handler = TestReplicationClientHandler(self.slaved_store)
         client_factory = ReplicationClientFactory(
             self.hs, "client_name", self.replication_handler
         )
@@ -60,12 +81,14 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
         self.addCleanup(client_factory.stopTrying)
         self.addCleanup(client_connector.disconnect)
 
-    @defer.inlineCallbacks
     def replicate(self):
-        yield self.streamer.on_notifier_poke()
-        d = self.replication_handler.await_sync("replication_test")
-        self.streamer.send_sync_to_all_connections("replication_test")
-        yield d
+        """Tell the master side of replication that something has happened, and then
+        wait for the replication to occur.
+        """
+        # xxx: should we be more specific in what we wait for?
+        d = self.replication_handler.await_replication()
+        self.streamer.on_notifier_poke()
+        return d
 
     @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):