summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/4956.bugfix1
-rw-r--r--synapse/replication/slave/storage/events.py31
-rw-r--r--synapse/storage/_base.py5
-rw-r--r--synapse/storage/events.py24
-rw-r--r--tests/replication/slave/storage/_base.py28
-rw-r--r--tests/replication/slave/storage/test_events.py161
-rw-r--r--tests/server.py56
7 files changed, 237 insertions, 69 deletions
diff --git a/changelog.d/4956.bugfix b/changelog.d/4956.bugfix
new file mode 100644
index 0000000000..e50e67383d
--- /dev/null
+++ b/changelog.d/4956.bugfix
@@ -0,0 +1 @@
+Fix sync bug which made accepting invites unreliable in worker-mode synapses.
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index c57385d92f..b457c5563f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -16,7 +16,10 @@
 import logging
 
 from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import EventsStreamEventRow
+from synapse.replication.tcp.streams.events import (
+    EventsStreamCurrentStateRow,
+    EventsStreamEventRow,
+)
 from synapse.storage.event_federation import EventFederationWorkerStore
 from synapse.storage.event_push_actions import EventPushActionsWorkerStore
 from synapse.storage.events_worker import EventsWorkerStore
@@ -80,14 +83,7 @@ class SlavedEventStore(EventFederationWorkerStore,
         if stream_name == "events":
             self._stream_id_gen.advance(token)
             for row in rows:
-                if row.type != EventsStreamEventRow.TypeId:
-                    continue
-                data = row.data
-                self.invalidate_caches_for_event(
-                    token, data.event_id, data.room_id, data.type, data.state_key,
-                    data.redacts,
-                    backfilled=False,
-                )
+                self._process_event_stream_row(token, row)
         elif stream_name == "backfill":
             self._backfill_id_gen.advance(-token)
             for row in rows:
@@ -100,6 +96,23 @@ class SlavedEventStore(EventFederationWorkerStore,
             stream_name, token, rows
         )
 
+    def _process_event_stream_row(self, token, row):
+        data = row.data
+
+        if row.type == EventsStreamEventRow.TypeId:
+            self.invalidate_caches_for_event(
+                token, data.event_id, data.room_id, data.type, data.state_key,
+                data.redacts,
+                backfilled=False,
+            )
+        elif row.type == EventsStreamCurrentStateRow.TypeId:
+            if data.type == EventTypes.Member:
+                self.get_rooms_for_user_with_stream_ordering.invalidate(
+                    (data.state_key, ),
+                )
+        else:
+            raise Exception("Unknown events stream row type %s" % (row.type, ))
+
     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
                                     etype, state_key, redacts, backfilled):
         self._invalidate_get_event_cache(event_id)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7e3903859b..653b32fbf5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -1355,11 +1355,6 @@ class SQLBaseStore(object):
             members_changed (iterable[str]): The user_ids of members that have
                 changed
         """
-        for member in members_changed:
-            self._attempt_to_invalidate_cache(
-                "get_rooms_for_user_with_stream_ordering", (member,),
-            )
-
         for host in set(get_domain_from_id(u) for u in members_changed):
             self._attempt_to_invalidate_cache(
                 "is_host_joined", (room_id, host,),
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index d0668e39c4..dfda39bbe0 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -79,7 +79,7 @@ def encode_json(json_object):
     """
     out = frozendict_json_encoder.encode(json_object)
     if isinstance(out, bytes):
-        out = out.decode('utf8')
+        out = out.decode("utf8")
     return out
 
 
@@ -813,9 +813,10 @@ class EventsStore(
         """
         all_events_and_contexts = events_and_contexts
 
+        min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
-        self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
+        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
 
         self._update_forward_extremities_txn(
             txn,
@@ -890,7 +891,7 @@ class EventsStore(
             backfilled=backfilled,
         )
 
-    def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
+    def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
         for room_id, current_state_tuple in iteritems(state_delta_by_room):
             to_delete, to_insert = current_state_tuple
 
@@ -899,6 +900,12 @@ class EventsStore(
             # that we can use it to calculate the `prev_event_id`. (This
             # allows us to not have to pull out the existing state
             # unnecessarily).
+            #
+            # The stream_id for the update is chosen to be the minimum of the stream_ids
+            # for the batch of the events that we are persisting; that means we do not
+            # end up in a situation where workers see events before the
+            # current_state_delta updates.
+            #
             sql = """
                 INSERT INTO current_state_delta_stream
                 (stream_id, room_id, type, state_key, event_id, prev_event_id)
@@ -911,7 +918,7 @@ class EventsStore(
                 sql,
                 (
                     (
-                        max_stream_order,
+                        stream_id,
                         room_id,
                         etype,
                         state_key,
@@ -929,7 +936,7 @@ class EventsStore(
                 sql,
                 (
                     (
-                        max_stream_order,
+                        stream_id,
                         room_id,
                         etype,
                         state_key,
@@ -970,7 +977,7 @@ class EventsStore(
             txn.call_after(
                 self._curr_state_delta_stream_cache.entity_has_changed,
                 room_id,
-                max_stream_order,
+                stream_id,
             )
 
             # Invalidate the various caches
@@ -986,6 +993,11 @@ class EventsStore(
                 if ev_type == EventTypes.Member
             )
 
+            for member in members_changed:
+                txn.call_after(
+                    self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
+                )
+
             self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
 
     def _update_forward_extremities_txn(
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 524af4f8d1..1f72a2a04f 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -56,7 +56,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         client = client_factory.buildProtocol(None)
 
         client.makeConnection(FakeTransport(server, reactor))
-        server.makeConnection(FakeTransport(client, reactor))
+
+        self.server_to_client_transport = FakeTransport(client, reactor)
+        server.makeConnection(self.server_to_client_transport)
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
@@ -69,6 +71,24 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         master_result = self.get_success(getattr(self.master_store, method)(*args))
         slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
         if expected_result is not None:
-            self.assertEqual(master_result, expected_result)
-            self.assertEqual(slaved_result, expected_result)
-        self.assertEqual(master_result, slaved_result)
+            self.assertEqual(
+                master_result,
+                expected_result,
+                "Expected master result to be %r but was %r" % (
+                    expected_result, master_result
+                ),
+            )
+            self.assertEqual(
+                slaved_result,
+                expected_result,
+                "Expected slave result to be %r but was %r" % (
+                    expected_result, slaved_result
+                ),
+            )
+        self.assertEqual(
+            master_result,
+            slaved_result,
+            "Slave result %r does not match master result %r" % (
+                slaved_result, master_result
+            ),
+        )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1688a741d1..65ecff3bd6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
 from canonicaljson import encode_canonical_json
 
 from synapse.events import FrozenEvent, _EventInternalMetadata
 from synapse.events.snapshot import EventContext
+from synapse.handlers.room import RoomEventSource
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.storage.roommember import RoomsForUser
 
@@ -26,6 +28,8 @@ USER_ID_2 = "@bright:blue"
 OUTLIER = {"outlier": True}
 ROOM_ID = "!room:blue"
 
+logger = logging.getLogger(__name__)
+
 
 def dict_equals(self, other):
     me = encode_canonical_json(self.get_pdu_json())
@@ -172,18 +176,142 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             {"highlight_count": 1, "notify_count": 2},
         )
 
+    def test_get_rooms_for_user_with_stream_ordering(self):
+        """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
+        by rows in the events stream
+        """
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.member", key=USER_ID, membership="join")
+        self.replicate()
+        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+        j2 = self.persist(
+            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+        )
+        self.replicate()
+        self.check(
+            "get_rooms_for_user_with_stream_ordering",
+            (USER_ID_2,),
+            {(ROOM_ID, j2.internal_metadata.stream_ordering)},
+        )
+
+    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+        """Check that current_state invalidation happens correctly with multiple events
+        in the persistence batch.
+
+        This test attempts to reproduce a race condition between the event persistence
+        loop and a worker-based Sync handler.
+
+        The problem occurred when the master persisted several events in one batch. It
+        only updates the current_state at the end of each batch, so the obvious thing
+        to do is then to issue a current_state_delta stream update corresponding to the
+        last stream_id in the batch.
+
+        However, that raises the possibility that a worker will see the replication
+        notification for a join event before the current_state caches are invalidated.
+
+        The test involves:
+         * creating a join and a message event for a user, and persisting them in the
+           same batch
+
+         * controlling the replication stream so that updates are sent gradually
+
+         * between each bunch of replication updates, check that we see a consistent
+           snapshot of the state.
+        """
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.member", key=USER_ID, membership="join")
+        self.replicate()
+        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+        # limit the replication rate
+        repl_transport = self.server_to_client_transport
+        repl_transport.autoflush = False
+
+        # build the join and message events and persist them in the same batch.
+        logger.info("----- build test events ------")
+        j2, j2ctx = self.build_event(
+            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+        )
+        msg, msgctx = self.build_event()
+        self.get_success(self.master_store.persist_events([
+            (j2, j2ctx),
+            (msg, msgctx),
+        ]))
+        self.replicate()
+
+        event_source = RoomEventSource(self.hs)
+        event_source.store = self.slaved_store
+        current_token = self.get_success(event_source.get_current_key())
+
+        # gradually stream out the replication
+        while repl_transport.buffer:
+            logger.info("------ flush ------")
+            repl_transport.flush(30)
+            self.pump(0)
+
+            prev_token = current_token
+            current_token = self.get_success(event_source.get_current_key())
+
+            # attempt to replicate the behaviour of the sync handler.
+            #
+            # First, we get a list of the rooms we are joined to
+            joined_rooms = self.get_success(
+                self.slaved_store.get_rooms_for_user_with_stream_ordering(
+                    USER_ID_2,
+                ),
+            )
+
+            # Then, we get a list of the events since the last sync
+            membership_changes = self.get_success(
+                self.slaved_store.get_membership_changes_for_user(
+                    USER_ID_2, prev_token, current_token,
+                )
+            )
+
+            logger.info(
+                "%s->%s: joined_rooms=%r membership_changes=%r",
+                prev_token,
+                current_token,
+                joined_rooms,
+                membership_changes,
+            )
+
+            # the membership change is only any use to us if the room is in the
+            # joined_rooms list.
+            if membership_changes:
+                self.assertEqual(
+                    joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
+                )
+
     event_id = 0
 
-    def persist(
+    def persist(self, backfill=False, **kwargs):
+        """
+        Returns:
+            synapse.events.FrozenEvent: The event that was persisted.
+        """
+        event, context = self.build_event(**kwargs)
+
+        if backfill:
+            self.get_success(
+                self.master_store.persist_events([(event, context)], backfilled=True)
+            )
+        else:
+            self.get_success(
+                self.master_store.persist_event(event, context)
+            )
+
+        return event
+
+    def build_event(
         self,
         sender=USER_ID,
         room_id=ROOM_ID,
-        type={},
+        type="m.room.message",
         key=None,
         internal={},
         state=None,
-        reset_state=False,
-        backfill=False,
         depth=None,
         prev_events=[],
         auth_events=[],
@@ -192,10 +320,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         push_actions=[],
         **content
     ):
-        """
-        Returns:
-            synapse.events.FrozenEvent: The event that was persisted.
-        """
+
         if depth is None:
             depth = self.event_id
 
@@ -234,23 +359,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             )
         else:
             state_handler = self.hs.get_state_handler()
-            context = self.get_success(state_handler.compute_event_context(event))
+            context = self.get_success(state_handler.compute_event_context(
+                event
+            ))
 
         self.master_store.add_push_actions_to_staging(
             event.event_id, {user_id: actions for user_id, actions in push_actions}
         )
-
-        ordering = None
-        if backfill:
-            self.get_success(
-                self.master_store.persist_events([(event, context)], backfilled=True)
-            )
-        else:
-            ordering, _ = self.get_success(
-                self.master_store.persist_event(event, context)
-            )
-
-        if ordering:
-            event.internal_metadata.stream_ordering = ordering
-
-        return event
+        return event, context
diff --git a/tests/server.py b/tests/server.py
index ea26dea623..8f89f4a83d 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -365,6 +365,7 @@ class FakeTransport(object):
     disconnected = False
     buffer = attr.ib(default=b'')
     producer = attr.ib(default=None)
+    autoflush = attr.ib(default=True)
 
     def getPeer(self):
         return None
@@ -415,31 +416,44 @@ class FakeTransport(object):
     def write(self, byt):
         self.buffer = self.buffer + byt
 
-        def _write():
-            if not self.buffer:
-                # nothing to do. Don't write empty buffers: it upsets the
-                # TLSMemoryBIOProtocol
-                return
-
-            if self.disconnected:
-                return
-            logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
-
-            if getattr(self.other, "transport") is not None:
-                try:
-                    self.other.dataReceived(self.buffer)
-                    self.buffer = b""
-                except Exception as e:
-                    logger.warning("Exception writing to protocol: %s", e)
-                return
-
-            self._reactor.callLater(0.0, _write)
-
         # always actually do the write asynchronously. Some protocols (notably the
         # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
         # still doing a write. Doing a callLater here breaks the cycle.
-        self._reactor.callLater(0.0, _write)
+        if self.autoflush:
+            self._reactor.callLater(0.0, self.flush)
 
     def writeSequence(self, seq):
         for x in seq:
             self.write(x)
+
+    def flush(self, maxbytes=None):
+        if not self.buffer:
+            # nothing to do. Don't write empty buffers: it upsets the
+            # TLSMemoryBIOProtocol
+            return
+
+        if self.disconnected:
+            return
+
+        if getattr(self.other, "transport") is None:
+            # the other has no transport yet; reschedule
+            if self.autoflush:
+                self._reactor.callLater(0.0, self.flush)
+            return
+
+        if maxbytes is not None:
+            to_write = self.buffer[:maxbytes]
+        else:
+            to_write = self.buffer
+
+        logger.info("%s->%s: %s", self._protocol, self.other, to_write)
+
+        try:
+            self.other.dataReceived(to_write)
+        except Exception as e:
+            logger.warning("Exception writing to protocol: %s", e)
+            return
+
+        self.buffer = self.buffer[len(to_write):]
+        if self.buffer and self.autoflush:
+            self._reactor.callLater(0.0, self.flush)