diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 524af4f8d1..1f72a2a04f 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -56,7 +56,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
client = client_factory.buildProtocol(None)
client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(client, reactor))
+
+ self.server_to_client_transport = FakeTransport(client, reactor)
+ server.makeConnection(self.server_to_client_transport)
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -69,6 +71,24 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
- self.assertEqual(master_result, expected_result)
- self.assertEqual(slaved_result, expected_result)
- self.assertEqual(master_result, slaved_result)
+ self.assertEqual(
+ master_result,
+ expected_result,
+ "Expected master result to be %r but was %r" % (
+ expected_result, master_result
+ ),
+ )
+ self.assertEqual(
+ slaved_result,
+ expected_result,
+ "Expected slave result to be %r but was %r" % (
+ expected_result, slaved_result
+ ),
+ )
+ self.assertEqual(
+ master_result,
+ slaved_result,
+ "Slave result %r does not match master result %r" % (
+ slaved_result, master_result
+ ),
+ )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1688a741d1..65ecff3bd6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from canonicaljson import encode_canonical_json
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
+from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
@@ -26,6 +28,8 @@ USER_ID_2 = "@bright:blue"
OUTLIER = {"outlier": True}
ROOM_ID = "!room:blue"
+logger = logging.getLogger(__name__)
+
def dict_equals(self, other):
me = encode_canonical_json(self.get_pdu_json())
@@ -172,18 +176,142 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
{"highlight_count": 1, "notify_count": 2},
)
+ def test_get_rooms_for_user_with_stream_ordering(self):
+ """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
+ by rows in the events stream
+ """
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.replicate()
+ self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+ j2 = self.persist(
+ type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+ )
+ self.replicate()
+ self.check(
+ "get_rooms_for_user_with_stream_ordering",
+ (USER_ID_2,),
+ {(ROOM_ID, j2.internal_metadata.stream_ordering)},
+ )
+
+ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+ """Check that current_state invalidation happens correctly with multiple events
+ in the persistence batch.
+
+ This test attempts to reproduce a race condition between the event persistence
+ loop and a worker-based Sync handler.
+
+ The problem occurred when the master persisted several events in one batch. It
+ only updates the current_state at the end of each batch, so the obvious thing
+ to do is then to issue a current_state_delta stream update corresponding to the
+ last stream_id in the batch.
+
+ However, that raises the possibility that a worker will see the replication
+ notification for a join event before the current_state caches are invalidated.
+
+ The test involves:
+ * creating a join and a message event for a user, and persisting them in the
+ same batch
+
+ * controlling the replication stream so that updates are sent gradually
+
+ * between each bunch of replication updates, check that we see a consistent
+ snapshot of the state.
+ """
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.replicate()
+ self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+ # limit the replication rate
+ repl_transport = self.server_to_client_transport
+ repl_transport.autoflush = False
+
+ # build the join and message events and persist them in the same batch.
+ logger.info("----- build test events ------")
+ j2, j2ctx = self.build_event(
+ type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+ )
+ msg, msgctx = self.build_event()
+ self.get_success(self.master_store.persist_events([
+ (j2, j2ctx),
+ (msg, msgctx),
+ ]))
+ self.replicate()
+
+ event_source = RoomEventSource(self.hs)
+ event_source.store = self.slaved_store
+ current_token = self.get_success(event_source.get_current_key())
+
+ # gradually stream out the replication
+ while repl_transport.buffer:
+ logger.info("------ flush ------")
+ repl_transport.flush(30)
+ self.pump(0)
+
+ prev_token = current_token
+ current_token = self.get_success(event_source.get_current_key())
+
+ # attempt to replicate the behaviour of the sync handler.
+ #
+ # First, we get a list of the rooms we are joined to
+ joined_rooms = self.get_success(
+ self.slaved_store.get_rooms_for_user_with_stream_ordering(
+ USER_ID_2,
+ ),
+ )
+
+ # Then, we get a list of the events since the last sync
+ membership_changes = self.get_success(
+ self.slaved_store.get_membership_changes_for_user(
+ USER_ID_2, prev_token, current_token,
+ )
+ )
+
+ logger.info(
+ "%s->%s: joined_rooms=%r membership_changes=%r",
+ prev_token,
+ current_token,
+ joined_rooms,
+ membership_changes,
+ )
+
+ # the membership change is only any use to us if the room is in the
+ # joined_rooms list.
+ if membership_changes:
+ self.assertEqual(
+ joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
+ )
+
event_id = 0
- def persist(
+ def persist(self, backfill=False, **kwargs):
+ """
+ Returns:
+ synapse.events.FrozenEvent: The event that was persisted.
+ """
+ event, context = self.build_event(**kwargs)
+
+ if backfill:
+ self.get_success(
+ self.master_store.persist_events([(event, context)], backfilled=True)
+ )
+ else:
+ self.get_success(
+ self.master_store.persist_event(event, context)
+ )
+
+ return event
+
+ def build_event(
self,
sender=USER_ID,
room_id=ROOM_ID,
- type={},
+ type="m.room.message",
key=None,
internal={},
state=None,
- reset_state=False,
- backfill=False,
depth=None,
prev_events=[],
auth_events=[],
@@ -192,10 +320,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
push_actions=[],
**content
):
- """
- Returns:
- synapse.events.FrozenEvent: The event that was persisted.
- """
+
if depth is None:
depth = self.event_id
@@ -234,23 +359,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
else:
state_handler = self.hs.get_state_handler()
- context = self.get_success(state_handler.compute_event_context(event))
+ context = self.get_success(state_handler.compute_event_context(
+ event
+ ))
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
)
-
- ordering = None
- if backfill:
- self.get_success(
- self.master_store.persist_events([(event, context)], backfilled=True)
- )
- else:
- ordering, _ = self.get_success(
- self.master_store.persist_event(event, context)
- )
-
- if ordering:
- event.internal_metadata.stream_ordering = ordering
-
- return event
+ return event, context
|