diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 8708c8a196..a103e7be80 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -11,23 +11,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import tempfile
from mock import Mock, NonCallableMock
from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
from synapse.replication.tcp.client import (
ReplicationClientFactory,
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
from tests import unittest
from tests.utils import setup_test_homeserver
+class TestReplicationClientHandler(ReplicationClientHandler):
+ """Overrides on_rdata so that we can wait for it to happen"""
+ def __init__(self, store):
+ super(TestReplicationClientHandler, self).__init__(store)
+ self._rdata_awaiters = []
+
+ def await_replication(self):
+ d = Deferred()
+ self._rdata_awaiters.append(d)
+ return make_deferred_yieldable(d)
+
+ def on_rdata(self, stream_name, token, rows):
+ awaiters = self._rdata_awaiters
+ self._rdata_awaiters = []
+ super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
+ with PreserveLoggingContext():
+ for a in awaiters:
+ a.callback(None)
+
+
class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
@@ -52,7 +73,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
- self.replication_handler = ReplicationClientHandler(self.slaved_store)
+ self.replication_handler = TestReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
@@ -60,12 +81,14 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)
- @defer.inlineCallbacks
def replicate(self):
- yield self.streamer.on_notifier_poke()
- d = self.replication_handler.await_sync("replication_test")
- self.streamer.send_sync_to_all_connections("replication_test")
- yield d
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ # xxx: should we be more specific in what we wait for?
+ d = self.replication_handler.await_replication()
+ self.streamer.on_notifier_poke()
+ return d
@defer.inlineCallbacks
def check(self, method, args, expected_result=None):
|