summary refs log tree commit diff
path: root/tests/util/test_rwlock.py
blob: 12f821d684cd0f36a2be8e24b34181e5bf581bcc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

from typing import AsyncContextManager, Callable, Sequence, Tuple

from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred

from synapse.util.async_helpers import ReadWriteLock

from tests import unittest


class ReadWriteLockTestCase(unittest.TestCase):
    """Tests for `ReadWriteLock`.

    Readers and writers are driven by resolving `Deferred`s by hand. Straight after
    each step, the tests check which readers/writers have acquired the lock or
    completed.
    """

    def _start_reader_or_writer(
        self,
        read_or_write: Callable[[str], AsyncContextManager],
        key: str,
        return_value: str,
    ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
        """Starts a reader or writer which acquires the lock, blocks, then completes.

        Args:
            read_or_write: A function returning a context manager for a lock.
                Either a bound `ReadWriteLock.read` or `ReadWriteLock.write`.
            key: The key to read or write.
            return_value: A string that the reader or writer will resolve with when
                done.

        Returns:
            A tuple of three `Deferred`s:
             * A cancellable `Deferred` for the entire read or write operation that
               resolves with `return_value` on successful completion.
             * A `Deferred` that resolves once the reader or writer acquires the lock.
             * A `Deferred` that blocks the reader or writer. Must be resolved by the
               caller to allow the reader or writer to release the lock and complete.
        """
        acquired_d: "Deferred[None]" = Deferred()
        unblock_d: "Deferred[None]" = Deferred()

        async def reader_or_writer() -> str:
            async with read_or_write(key):
                acquired_d.callback(None)
                await unblock_d
            return return_value

        d = defer.ensureDeferred(reader_or_writer())
        return d, acquired_d, unblock_d

    def _start_blocking_reader(
        self, rwlock: ReadWriteLock, key: str, return_value: str
    ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
        """Starts a reader which acquires the lock, blocks, then releases the lock.

        See the docstring for `_start_reader_or_writer` for details about the arguments
        and return values.
        """
        return self._start_reader_or_writer(rwlock.read, key, return_value)

    def _start_blocking_writer(
        self, rwlock: ReadWriteLock, key: str, return_value: str
    ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
        """Starts a writer which acquires the lock, blocks, then releases the lock.

        See the docstring for `_start_reader_or_writer` for details about the arguments
        and return values.
        """
        return self._start_reader_or_writer(rwlock.write, key, return_value)

    def _start_nonblocking_reader(
        self, rwlock: ReadWriteLock, key: str, return_value: str
    ) -> Tuple["Deferred[str]", "Deferred[None]"]:
        """Starts a reader which acquires the lock, then releases it immediately.

        See the docstring for `_start_reader_or_writer` for details about the arguments.

        Returns:
            A tuple of two `Deferred`s:
             * A cancellable `Deferred` for the entire read operation that resolves with
               `return_value` on successful completion.
             * A `Deferred` that resolves once the reader acquires the lock.
        """
        d, acquired_d, unblock_d = self._start_reader_or_writer(
            rwlock.read, key, return_value
        )
        unblock_d.callback(None)
        return d, acquired_d

    def _start_nonblocking_writer(
        self, rwlock: ReadWriteLock, key: str, return_value: str
    ) -> Tuple["Deferred[str]", "Deferred[None]"]:
        """Starts a writer which acquires the lock, then releases it immediately.

        See the docstring for `_start_reader_or_writer` for details about the arguments.

        Returns:
            A tuple of two `Deferred`s:
             * A cancellable `Deferred` for the entire write operation that resolves
               with `return_value` on successful completion.
             * A `Deferred` that resolves once the writer acquires the lock.
        """
        d, acquired_d, unblock_d = self._start_reader_or_writer(
            rwlock.write, key, return_value
        )
        unblock_d.callback(None)
        return d, acquired_d

    def _assert_first_n_resolved(
        self, deferreds: Sequence["Deferred[None]"], n: int
    ) -> None:
        """Assert that exactly the first n `Deferred`s in the given list are resolved.

        Args:
            deferreds: The list of `Deferred`s to be checked.
            n: The number of `Deferred`s at the start of `deferreds` that should be
                resolved.
        """
        for i, d in enumerate(deferreds[:n]):
            self.assertTrue(d.called, msg="deferred %d was unexpectedly unresolved" % i)

        for i, d in enumerate(deferreds[n:]):
            # Offset by `n` so the message reports the index in the full list.
            self.assertFalse(
                d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
            )

    def test_rwlock(self) -> None:
        rwlock = ReadWriteLock()
        key = "key"

        ds = [
            self._start_blocking_reader(rwlock, key, "0"),
            self._start_blocking_reader(rwlock, key, "1"),
            self._start_blocking_writer(rwlock, key, "2"),
            self._start_blocking_writer(rwlock, key, "3"),
            self._start_blocking_reader(rwlock, key, "4"),
            self._start_blocking_reader(rwlock, key, "5"),
            self._start_blocking_writer(rwlock, key, "6"),
        ]
        # `Deferred`s for each entire read or write operation. Operation `i` resolves
        # with `str(i)` once it has released its lock.
        operation_ds = [operation_d for operation_d, _, _ in ds]
        # `Deferred`s that resolve when each reader or writer acquires the lock.
        acquired_ds = [acquired_d for _, acquired_d, _ in ds]
        # `Deferred`s that will trigger the release of locks when resolved.
        release_ds = [release_d for _, _, release_d in ds]

        # The first two readers should acquire their locks.
        # Each step below ends with an assertion that also serves as the precondition
        # for the following step.
        self._assert_first_n_resolved(acquired_ds, 2)

        # Release one of the read locks. The next writer should not acquire the lock,
        # because there is another reader holding the lock.
        release_ds[0].callback(None)
        self._assert_first_n_resolved(acquired_ds, 2)

        # Release the other read lock. The next writer should acquire the lock.
        release_ds[1].callback(None)
        self._assert_first_n_resolved(acquired_ds, 3)

        # Release the write lock. The next writer should acquire the lock.
        release_ds[2].callback(None)
        self._assert_first_n_resolved(acquired_ds, 4)

        # Release the write lock. The next two readers should acquire locks.
        release_ds[3].callback(None)
        self._assert_first_n_resolved(acquired_ds, 6)

        # Release one of the read locks. The next writer should not acquire the lock,
        # because there is another reader holding the lock.
        release_ds[5].callback(None)
        self._assert_first_n_resolved(acquired_ds, 6)

        # Release the other read lock. The next writer should acquire the lock.
        release_ds[4].callback(None)
        self._assert_first_n_resolved(acquired_ds, 7)

        # Release the write lock.
        release_ds[6].callback(None)

        # Every reader and writer should have run to completion and resolved with its
        # return value, not merely acquired the lock.
        for i, operation_d in enumerate(operation_ds):
            self.assertEqual(str(i), self.successResultOf(operation_d))

        # Acquire and release the write and read locks one last time for good measure.
        writer_d, acquired_d = self._start_nonblocking_writer(
            rwlock, key, "last writer"
        )
        self.assertTrue(acquired_d.called)
        self.assertEqual("last writer", self.successResultOf(writer_d))

        reader_d, acquired_d = self._start_nonblocking_reader(
            rwlock, key, "last reader"
        )
        self.assertTrue(acquired_d.called)
        self.assertEqual("last reader", self.successResultOf(reader_d))

    def test_lock_handoff_to_nonblocking_writer(self) -> None:
        """Test a writer handing the lock to another writer that completes instantly."""
        rwlock = ReadWriteLock()
        key = "key"

        d1, _, unblock = self._start_blocking_writer(rwlock, key, "write 1 completed")
        d2, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
        self.assertFalse(d1.called)
        self.assertFalse(d2.called)

        # Unblock the first writer. The second writer will complete without blocking.
        unblock.callback(None)
        self.assertTrue(d1.called)
        self.assertTrue(d2.called)

        # The `ReadWriteLock` should operate as normal.
        d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
        self.assertTrue(d3.called)

    def test_cancellation_while_holding_read_lock(self) -> None:
        """Test cancellation while holding a read lock.

        A waiting writer should be given the lock when the reader holding the lock is
        cancelled.
        """
        rwlock = ReadWriteLock()
        key = "key"

        # 1. A reader takes the lock and blocks.
        reader_d, _, _ = self._start_blocking_reader(rwlock, key, "read completed")

        # 2. A writer waits for the reader to complete.
        writer_d, _ = self._start_nonblocking_writer(rwlock, key, "write completed")
        self.assertFalse(writer_d.called)

        # 3. The reader is cancelled.
        reader_d.cancel()
        self.failureResultOf(reader_d, CancelledError)

        # 4. The writer should take the lock and complete.
        self.assertTrue(
            writer_d.called, "Writer is stuck waiting for a cancelled reader"
        )
        self.assertEqual("write completed", self.successResultOf(writer_d))

    def test_cancellation_while_holding_write_lock(self) -> None:
        """Test cancellation while holding a write lock.

        A waiting reader should be given the lock when the writer holding the lock is
        cancelled.
        """
        rwlock = ReadWriteLock()
        key = "key"

        # 1. A writer takes the lock and blocks.
        writer_d, _, _ = self._start_blocking_writer(rwlock, key, "write completed")

        # 2. A reader waits for the writer to complete.
        reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed")
        self.assertFalse(reader_d.called)

        # 3. The writer is cancelled.
        writer_d.cancel()
        self.failureResultOf(writer_d, CancelledError)

        # 4. The reader should take the lock and complete.
        self.assertTrue(
            reader_d.called, "Reader is stuck waiting for a cancelled writer"
        )
        self.assertEqual("read completed", self.successResultOf(reader_d))

    def test_cancellation_while_waiting_for_read_lock(self) -> None:
        """Test cancellation while waiting for a read lock.

        Tests that cancelling a waiting reader:
         * does not cancel the writer it is waiting on
         * does not cancel the next writer waiting on it
         * does not allow the next writer to acquire the lock before an earlier writer
           has finished
         * does not keep the next writer waiting indefinitely

        These correspond to the asserts with explicit messages.
        """
        rwlock = ReadWriteLock()
        key = "key"

        # 1. A writer takes the lock and blocks.
        writer1_d, _, unblock_writer1 = self._start_blocking_writer(
            rwlock, key, "write 1 completed"
        )

        # 2. A reader waits for the first writer to complete.
        #    This reader will be cancelled later.
        reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed")
        self.assertFalse(reader_d.called)

        # 3. A second writer waits for both the first writer and the reader to complete.
        writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
        self.assertFalse(writer2_d.called)

        # 4. The waiting reader is cancelled.
        #    Neither of the writers should be cancelled.
        #    The second writer should still be waiting, but only on the first writer.
        reader_d.cancel()
        self.failureResultOf(reader_d, CancelledError)
        self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
        self.assertFalse(
            writer2_d.called,
            "Second writer was unexpectedly cancelled or given the lock before the "
            "first writer finished",
        )

        # 5. Unblock the first writer, which should complete.
        unblock_writer1.callback(None)
        self.assertEqual("write 1 completed", self.successResultOf(writer1_d))

        # 6. The second writer should take the lock and complete.
        self.assertTrue(
            writer2_d.called, "Second writer is stuck waiting for a cancelled reader"
        )
        self.assertEqual("write 2 completed", self.successResultOf(writer2_d))

    def test_cancellation_while_waiting_for_write_lock(self) -> None:
        """Test cancellation while waiting for a write lock.

        Tests that cancelling a waiting writer:
         * does not cancel the reader or writer it is waiting on
         * does not cancel the next writer waiting on it
         * does not allow the next writer to acquire the lock before an earlier reader
           and writer have finished
         * does not keep the next writer waiting indefinitely

        These correspond to the asserts with explicit messages.
        """
        rwlock = ReadWriteLock()
        key = "key"

        # 1. A reader takes the lock and blocks.
        reader_d, _, unblock_reader = self._start_blocking_reader(
            rwlock, key, "read completed"
        )

        # 2. A writer waits for the reader to complete.
        writer1_d, _, unblock_writer1 = self._start_blocking_writer(
            rwlock, key, "write 1 completed"
        )

        # 3. A second writer waits for both the reader and first writer to complete.
        #    This writer will be cancelled later.
        writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
        self.assertFalse(writer2_d.called)

        # 4. A third writer waits for the second writer to complete.
        writer3_d, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
        self.assertFalse(writer3_d.called)

        # 5. The second writer is cancelled, but continues waiting for the lock.
        #    The reader, first writer and third writer should not be cancelled.
        #    The first writer should still be waiting on the reader.
        #    The third writer should still be waiting on the second writer.
        writer2_d.cancel()
        self.assertNoResult(writer2_d)
        self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled")
        self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
        self.assertFalse(
            writer3_d.called,
            "Third writer was unexpectedly cancelled or given the lock before the first "
            "writer finished",
        )

        # 6. Unblock the reader, which should complete.
        #    The first writer should be given the lock and block.
        #    The third writer should still be waiting on the second writer.
        unblock_reader.callback(None)
        self.assertEqual("read completed", self.successResultOf(reader_d))
        self.assertNoResult(writer2_d)
        self.assertFalse(
            writer3_d.called,
            "Third writer was unexpectedly given the lock before the first writer "
            "finished",
        )

        # 7. Unblock the first writer, which should complete.
        unblock_writer1.callback(None)
        self.assertEqual("write 1 completed", self.successResultOf(writer1_d))

        # 8. The second writer should take the lock and release it immediately, since it
        #    has been cancelled.
        self.failureResultOf(writer2_d, CancelledError)

        # 9. The third writer should take the lock and complete.
        self.assertTrue(
            writer3_d.called, "Third writer is stuck waiting for a cancelled writer"
        )
        self.assertEqual("write 3 completed", self.successResultOf(writer3_d))