summary refs log tree commit diff
path: root/tests/replication/slave
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/slave')
-rw-r--r--tests/replication/slave/storage/_base.py95
-rw-r--r--tests/replication/slave/storage/test_account_data.py20
-rw-r--r--tests/replication/slave/storage/test_events.py110
-rw-r--r--tests/replication/slave/storage/test_receipts.py15
4 files changed, 116 insertions, 124 deletions
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 65df116efc..089cecfbee 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -1,4 +1,5 @@
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,89 +12,91 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import tempfile
 
 from mock import Mock, NonCallableMock
 
-from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+import attr
 
 from synapse.replication.tcp.client import (
     ReplicationClientFactory,
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
 
 from tests import unittest
-from tests.utils import setup_test_homeserver
 
 
-class TestReplicationClientHandler(ReplicationClientHandler):
-    """Overrides on_rdata so that we can wait for it to happen"""
+class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
 
-    def __init__(self, store):
-        super(TestReplicationClientHandler, self).__init__(store)
-        self._rdata_awaiters = []
-
-    def await_replication(self):
-        d = Deferred()
-        self._rdata_awaiters.append(d)
-        return make_deferred_yieldable(d)
-
-    def on_rdata(self, stream_name, token, rows):
-        awaiters = self._rdata_awaiters
-        self._rdata_awaiters = []
-        super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
-        with PreserveLoggingContext():
-            for a in awaiters:
-                a.callback(None)
-
-
-class BaseSlavedStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(
-            self.addCleanup,
+        hs = self.setup_test_homeserver(
             "blue",
-            http_client=None,
             federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
-        self.hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+        hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+        return hs
+
+    def prepare(self, reactor, clock, hs):
 
         self.master_store = self.hs.get_datastore()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
-        # XXX: mktemp is unsafe and should never be used. but we're just a test.
-        path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
-        listener = reactor.listenUNIX(path, server_factory)
-        self.addCleanup(listener.stopListening)
         self.streamer = server_factory.streamer
 
-        self.replication_handler = TestReplicationClientHandler(self.slaved_store)
+        self.replication_handler = ReplicationClientHandler(self.slaved_store)
         client_factory = ReplicationClientFactory(
             self.hs, "client_name", self.replication_handler
         )
-        client_connector = reactor.connectUNIX(path, client_factory)
-        self.addCleanup(client_factory.stopTrying)
-        self.addCleanup(client_connector.disconnect)
+
+        server = server_factory.buildProtocol(None)
+        client = client_factory.buildProtocol(None)
+
+        @attr.s
+        class FakeTransport(object):
+
+            other = attr.ib()
+            disconnecting = False
+            buffer = attr.ib(default=b'')
+
+            def registerProducer(self, producer, streaming):
+
+                self.producer = producer
+
+                def _produce():
+                    self.producer.resumeProducing()
+                    reactor.callLater(0.1, _produce)
+
+                reactor.callLater(0.0, _produce)
+
+            def write(self, byt):
+                self.buffer = self.buffer + byt
+
+                if getattr(self.other, "transport") is not None:
+                    self.other.dataReceived(self.buffer)
+                    self.buffer = b""
+
+            def writeSequence(self, seq):
+                for x in seq:
+                    self.write(x)
+
+        client.makeConnection(FakeTransport(server))
+        server.makeConnection(FakeTransport(client))
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
-        # xxx: should we be more specific in what we wait for?
-        d = self.replication_handler.await_replication()
         self.streamer.on_notifier_poke()
-        return d
+        self.pump(0.1)
 
-    @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):
-        master_result = yield getattr(self.master_store, method)(*args)
-        slaved_result = yield getattr(self.slaved_store, method)(*args)
+        master_result = self.get_success(getattr(self.master_store, method)(*args))
+        slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
         if expected_result is not None:
             self.assertEqual(master_result, expected_result)
             self.assertEqual(slaved_result, expected_result)
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
index 87cc2b2fba..43e3248703 100644
--- a/tests/replication/slave/storage/test_account_data.py
+++ b/tests/replication/slave/storage/test_account_data.py
@@ -12,9 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 
 from ._base import BaseSlavedStoreTestCase
@@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
 
     STORE_TYPE = SlavedAccountDataStore
 
-    @defer.inlineCallbacks
     def test_user_account_data(self):
-        yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
-        yield self.replicate()
-        yield self.check(
+        self.get_success(
+            self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
+        )
+        self.replicate()
+        self.check(
             "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
         )
 
-        yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
-        yield self.replicate()
-        yield self.check(
+        self.get_success(
+            self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
+        )
+        self.replicate()
+        self.check(
             "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
         )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 2ba80ccdcf..db44d33c68 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 from synapse.events import FrozenEvent, _EventInternalMetadata
 from synapse.events.snapshot import EventContext
 from synapse.replication.slave.storage.events import SlavedEventStore
@@ -55,70 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
     def tearDown(self):
         [unpatch() for unpatch in self.unpatches]
 
-    @defer.inlineCallbacks
     def test_get_latest_event_ids_in_room(self):
-        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.replicate()
-        yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
+        create = self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.replicate()
+        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
 
-        join = yield self.persist(
+        join = self.persist(
             type="m.room.member",
             key=USER_ID,
             membership="join",
             prev_events=[(create.event_id, {})],
         )
-        yield self.replicate()
-        yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
+        self.replicate()
+        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
 
-    @defer.inlineCallbacks
     def test_redactions(self):
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.member", key=USER_ID, membership="join")
 
-        msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
-        yield self.replicate()
-        yield self.check("get_event", [msg.event_id], msg)
+        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+        self.replicate()
+        self.check("get_event", [msg.event_id], msg)
 
-        redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
-        yield self.replicate()
+        redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
+        self.replicate()
 
         msg_dict = msg.get_dict()
         msg_dict["content"] = {}
         msg_dict["unsigned"]["redacted_by"] = redaction.event_id
         msg_dict["unsigned"]["redacted_because"] = redaction
         redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
-        yield self.check("get_event", [msg.event_id], redacted)
+        self.check("get_event", [msg.event_id], redacted)
 
-    @defer.inlineCallbacks
     def test_backfilled_redactions(self):
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.member", key=USER_ID, membership="join")
 
-        msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
-        yield self.replicate()
-        yield self.check("get_event", [msg.event_id], msg)
+        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+        self.replicate()
+        self.check("get_event", [msg.event_id], msg)
 
-        redaction = yield self.persist(
+        redaction = self.persist(
             type="m.room.redaction", redacts=msg.event_id, backfill=True
         )
-        yield self.replicate()
+        self.replicate()
 
         msg_dict = msg.get_dict()
         msg_dict["content"] = {}
         msg_dict["unsigned"]["redacted_by"] = redaction.event_id
         msg_dict["unsigned"]["redacted_because"] = redaction
         redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
-        yield self.check("get_event", [msg.event_id], redacted)
+        self.check("get_event", [msg.event_id], redacted)
 
-    @defer.inlineCallbacks
     def test_invites(self):
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
-        event = yield self.persist(
-            type="m.room.member", key=USER_ID_2, membership="invite"
-        )
-        yield self.replicate()
-        yield self.check(
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+        event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
+
+        self.replicate()
+
+        self.check(
             "get_invited_rooms_for_user",
             [USER_ID_2],
             [
@@ -132,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             ],
         )
 
-    @defer.inlineCallbacks
     def test_push_actions_for_user(self):
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.persist(type="m.room.join", key=USER_ID, membership="join")
-        yield self.persist(
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.join", key=USER_ID, membership="join")
+        self.persist(
             type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
         )
-        event1 = yield self.persist(
-            type="m.room.message", msgtype="m.text", body="hello"
-        )
-        yield self.replicate()
-        yield self.check(
+        event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
+        self.replicate()
+        self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
             {"highlight_count": 0, "notify_count": 0},
         )
 
-        yield self.persist(
+        self.persist(
             type="m.room.message",
             msgtype="m.text",
             body="world",
             push_actions=[(USER_ID_2, ["notify"])],
         )
-        yield self.replicate()
-        yield self.check(
+        self.replicate()
+        self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
             {"highlight_count": 0, "notify_count": 1},
         )
 
-        yield self.persist(
+        self.persist(
             type="m.room.message",
             msgtype="m.text",
             body="world",
@@ -170,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                 (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
             ],
         )
-        yield self.replicate()
-        yield self.check(
+        self.replicate()
+        self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
             {"highlight_count": 1, "notify_count": 2},
@@ -179,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
     event_id = 0
 
-    @defer.inlineCallbacks
     def persist(
         self,
         sender=USER_ID,
@@ -206,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             depth = self.event_id
 
         if not prev_events:
-            latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
-                room_id
+            latest_event_ids = self.get_success(
+                self.master_store.get_latest_event_ids_in_room(room_id)
             )
             prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
 
@@ -240,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             )
         else:
             state_handler = self.hs.get_state_handler()
-            context = yield state_handler.compute_event_context(event)
+            context = self.get_success(state_handler.compute_event_context(event))
 
-        yield self.master_store.add_push_actions_to_staging(
+        self.master_store.add_push_actions_to_staging(
             event.event_id, {user_id: actions for user_id, actions in push_actions}
         )
 
         ordering = None
         if backfill:
-            yield self.master_store.persist_events([(event, context)], backfilled=True)
+            self.get_success(
+                self.master_store.persist_events([(event, context)], backfilled=True)
+            )
         else:
-            ordering, _ = yield self.master_store.persist_event(event, context)
+            ordering, _ = self.get_success(
+                self.master_store.persist_event(event, context)
+            )
 
         if ordering:
             event.internal_metadata.stream_ordering = ordering
 
-        defer.returnValue(event)
+        return event
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index ae1adeded1..f47d94f690 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 
 from ._base import BaseSlavedStoreTestCase
@@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
 
     STORE_TYPE = SlavedReceiptsStore
 
-    @defer.inlineCallbacks
     def test_receipt(self):
-        yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
-        yield self.master_store.insert_receipt(
-            ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
-        )
-        yield self.replicate()
-        yield self.check(
-            "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
+        self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+        self.get_success(
+            self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
         )
+        self.replicate()
+        self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})