summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/slave/storage/_base.py28
-rw-r--r--tests/replication/slave/storage/test_events.py161
-rw-r--r--tests/server.py56
3 files changed, 196 insertions, 49 deletions
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 524af4f8d1..1f72a2a04f 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -56,7 +56,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         client = client_factory.buildProtocol(None)
 
         client.makeConnection(FakeTransport(server, reactor))
-        server.makeConnection(FakeTransport(client, reactor))
+
+        self.server_to_client_transport = FakeTransport(client, reactor)
+        server.makeConnection(self.server_to_client_transport)
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
@@ -69,6 +71,24 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         master_result = self.get_success(getattr(self.master_store, method)(*args))
         slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
         if expected_result is not None:
-            self.assertEqual(master_result, expected_result)
-            self.assertEqual(slaved_result, expected_result)
-        self.assertEqual(master_result, slaved_result)
+            self.assertEqual(
+                master_result,
+                expected_result,
+                "Expected master result to be %r but was %r" % (
+                    expected_result, master_result
+                ),
+            )
+            self.assertEqual(
+                slaved_result,
+                expected_result,
+                "Expected slave result to be %r but was %r" % (
+                    expected_result, slaved_result
+                ),
+            )
+        self.assertEqual(
+            master_result,
+            slaved_result,
+            "Slave result %r does not match master result %r" % (
+                slaved_result, master_result
+            ),
+        )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1688a741d1..65ecff3bd6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
 from canonicaljson import encode_canonical_json
 
 from synapse.events import FrozenEvent, _EventInternalMetadata
 from synapse.events.snapshot import EventContext
+from synapse.handlers.room import RoomEventSource
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.storage.roommember import RoomsForUser
 
@@ -26,6 +28,8 @@ USER_ID_2 = "@bright:blue"
 OUTLIER = {"outlier": True}
 ROOM_ID = "!room:blue"
 
+logger = logging.getLogger(__name__)
+
 
 def dict_equals(self, other):
     me = encode_canonical_json(self.get_pdu_json())
@@ -172,18 +176,142 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             {"highlight_count": 1, "notify_count": 2},
         )
 
+    def test_get_rooms_for_user_with_stream_ordering(self):
+        """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
+        by rows in the events stream
+        """
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.member", key=USER_ID, membership="join")
+        self.replicate()
+        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+        j2 = self.persist(
+            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+        )
+        self.replicate()
+        self.check(
+            "get_rooms_for_user_with_stream_ordering",
+            (USER_ID_2,),
+            {(ROOM_ID, j2.internal_metadata.stream_ordering)},
+        )
+
+    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+        """Check that current_state invalidation happens correctly with multiple events
+        in the persistence batch.
+
+        This test attempts to reproduce a race condition between the event persistence
+        loop and a worker-based Sync handler.
+
+        The problem occurred when the master persisted several events in one batch. It
+        only updates the current_state at the end of each batch, so the obvious thing
+        to do is then to issue a current_state_delta stream update corresponding to the
+        last stream_id in the batch.
+
+        However, that raises the possibility that a worker will see the replication
+        notification for a join event before the current_state caches are invalidated.
+
+        The test involves:
+         * creating a join and a message event for a user, and persisting them in the
+           same batch
+
+         * controlling the replication stream so that updates are sent gradually
+
+         * between each bunch of replication updates, check that we see a consistent
+           snapshot of the state.
+        """
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.member", key=USER_ID, membership="join")
+        self.replicate()
+        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+        # limit the replication rate
+        repl_transport = self.server_to_client_transport
+        repl_transport.autoflush = False
+
+        # build the join and message events and persist them in the same batch.
+        logger.info("----- build test events ------")
+        j2, j2ctx = self.build_event(
+            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+        )
+        msg, msgctx = self.build_event()
+        self.get_success(self.master_store.persist_events([
+            (j2, j2ctx),
+            (msg, msgctx),
+        ]))
+        self.replicate()
+
+        event_source = RoomEventSource(self.hs)
+        event_source.store = self.slaved_store
+        current_token = self.get_success(event_source.get_current_key())
+
+        # gradually stream out the replication
+        while repl_transport.buffer:
+            logger.info("------ flush ------")
+            repl_transport.flush(30)
+            self.pump(0)
+
+            prev_token = current_token
+            current_token = self.get_success(event_source.get_current_key())
+
+            # attempt to replicate the behaviour of the sync handler.
+            #
+            # First, we get a list of the rooms we are joined to
+            joined_rooms = self.get_success(
+                self.slaved_store.get_rooms_for_user_with_stream_ordering(
+                    USER_ID_2,
+                ),
+            )
+
+            # Then, we get a list of the events since the last sync
+            membership_changes = self.get_success(
+                self.slaved_store.get_membership_changes_for_user(
+                    USER_ID_2, prev_token, current_token,
+                )
+            )
+
+            logger.info(
+                "%s->%s: joined_rooms=%r membership_changes=%r",
+                prev_token,
+                current_token,
+                joined_rooms,
+                membership_changes,
+            )
+
+            # the membership change is only any use to us if the room is in the
+            # joined_rooms list.
+            if membership_changes:
+                self.assertEqual(
+                    joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
+                )
+
     event_id = 0
 
-    def persist(
+    def persist(self, backfill=False, **kwargs):
+        """
+        Returns:
+            synapse.events.FrozenEvent: The event that was persisted.
+        """
+        event, context = self.build_event(**kwargs)
+
+        if backfill:
+            self.get_success(
+                self.master_store.persist_events([(event, context)], backfilled=True)
+            )
+        else:
+            self.get_success(
+                self.master_store.persist_event(event, context)
+            )
+
+        return event
+
+    def build_event(
         self,
         sender=USER_ID,
         room_id=ROOM_ID,
-        type={},
+        type="m.room.message",
         key=None,
         internal={},
         state=None,
-        reset_state=False,
-        backfill=False,
         depth=None,
         prev_events=[],
         auth_events=[],
@@ -192,10 +320,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         push_actions=[],
         **content
     ):
-        """
-        Returns:
-            synapse.events.FrozenEvent: The event that was persisted.
-        """
+
         if depth is None:
             depth = self.event_id
 
@@ -234,23 +359,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             )
         else:
             state_handler = self.hs.get_state_handler()
-            context = self.get_success(state_handler.compute_event_context(event))
+            context = self.get_success(state_handler.compute_event_context(
+                event
+            ))
 
         self.master_store.add_push_actions_to_staging(
             event.event_id, {user_id: actions for user_id, actions in push_actions}
         )
-
-        ordering = None
-        if backfill:
-            self.get_success(
-                self.master_store.persist_events([(event, context)], backfilled=True)
-            )
-        else:
-            ordering, _ = self.get_success(
-                self.master_store.persist_event(event, context)
-            )
-
-        if ordering:
-            event.internal_metadata.stream_ordering = ordering
-
-        return event
+        return event, context
diff --git a/tests/server.py b/tests/server.py
index ea26dea623..8f89f4a83d 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -365,6 +365,7 @@ class FakeTransport(object):
     disconnected = False
     buffer = attr.ib(default=b'')
     producer = attr.ib(default=None)
+    autoflush = attr.ib(default=True)
 
     def getPeer(self):
         return None
@@ -415,31 +416,44 @@ class FakeTransport(object):
     def write(self, byt):
         self.buffer = self.buffer + byt
 
-        def _write():
-            if not self.buffer:
-                # nothing to do. Don't write empty buffers: it upsets the
-                # TLSMemoryBIOProtocol
-                return
-
-            if self.disconnected:
-                return
-            logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
-
-            if getattr(self.other, "transport") is not None:
-                try:
-                    self.other.dataReceived(self.buffer)
-                    self.buffer = b""
-                except Exception as e:
-                    logger.warning("Exception writing to protocol: %s", e)
-                return
-
-            self._reactor.callLater(0.0, _write)
-
         # always actually do the write asynchronously. Some protocols (notably the
         # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
         # still doing a write. Doing a callLater here breaks the cycle.
-        self._reactor.callLater(0.0, _write)
+        if self.autoflush:
+            self._reactor.callLater(0.0, self.flush)
 
     def writeSequence(self, seq):
         for x in seq:
             self.write(x)
+
+    def flush(self, maxbytes=None):
+        if not self.buffer:
+            # nothing to do. Don't write empty buffers: it upsets the
+            # TLSMemoryBIOProtocol
+            return
+
+        if self.disconnected:
+            return
+
+        if getattr(self.other, "transport") is None:
+            # the other has no transport yet; reschedule
+            if self.autoflush:
+                self._reactor.callLater(0.0, self.flush)
+            return
+
+        if maxbytes is not None:
+            to_write = self.buffer[:maxbytes]
+        else:
+            to_write = self.buffer
+
+        logger.info("%s->%s: %s", self._protocol, self.other, to_write)
+
+        try:
+            self.other.dataReceived(to_write)
+        except Exception as e:
+            logger.warning("Exception writing to protocol: %s", e)
+            return
+
+        self.buffer = self.buffer[len(to_write):]
+        if self.buffer and self.autoflush:
+            self._reactor.callLater(0.0, self.flush)