summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-28 16:47:11 -0400
committerGitHub <noreply@github.com>2020-08-28 16:47:11 -0400
commitd2ac767de2ce895a965fb2fcdcb883636f19a5c5 (patch)
tree376538b8b1b257e2e7b8e05ecf656bf0079eeb73
parentFix incorrect return signature (diff)
downloadsynapse-d2ac767de2ce895a965fb2fcdcb883636f19a5c5.tar.xz
Convert ReadWriteLock to async/await. (#8202)
-rw-r--r--changelog.d/8202.misc1
-rw-r--r--synapse/handlers/pagination.py49
-rw-r--r--synapse/util/async_helpers.py16
-rw-r--r--tests/util/test_rwlock.py6
4 files changed, 39 insertions, 33 deletions
diff --git a/changelog.d/8202.misc b/changelog.d/8202.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8202.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index ac3418d69d..5a1aa7d830 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -14,15 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Any, Dict, Optional
 
 from twisted.python.failure import Failure
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import SynapseError
+from synapse.api.filtering import Filter
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
-from synapse.types import RoomStreamToken
+from synapse.streams.config import PaginationConfig
+from synapse.types import Requester, RoomStreamToken
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
@@ -247,15 +250,16 @@ class PaginationHandler(object):
         )
         return purge_id
 
-    async def _purge_history(self, purge_id, room_id, token, delete_local_events):
+    async def _purge_history(
+        self, purge_id: str, room_id: str, token: str, delete_local_events: bool
+    ) -> None:
         """Carry out a history purge on a room.
 
         Args:
-            purge_id (str): The id for this purge
-            room_id (str): The room to purge from
-            token (str): topological token to delete events before
-            delete_local_events (bool): True to delete local events as well as
-                remote ones
+            purge_id: The id for this purge
+            room_id: The room to purge from
+            token: topological token to delete events before
+            delete_local_events: True to delete local events as well as remote ones
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
@@ -291,9 +295,9 @@ class PaginationHandler(object):
         """
         return self._purges_by_id.get(purge_id)
 
-    async def purge_room(self, room_id):
+    async def purge_room(self, room_id: str) -> None:
         """Purge the given room from the database"""
-        with (await self.pagination_lock.write(room_id)):
+        with await self.pagination_lock.write(room_id):
             # check we know about the room
             await self.store.get_room_version_id(room_id)
 
@@ -307,23 +311,22 @@ class PaginationHandler(object):
 
     async def get_messages(
         self,
-        requester,
-        room_id=None,
-        pagin_config=None,
-        as_client_event=True,
-        event_filter=None,
-    ):
+        requester: Requester,
+        room_id: Optional[str] = None,
+        pagin_config: Optional[PaginationConfig] = None,
+        as_client_event: bool = True,
+        event_filter: Optional[Filter] = None,
+    ) -> Dict[str, Any]:
         """Get messages in a room.
 
         Args:
-            requester (Requester): The user requesting messages.
-            room_id (str): The room they want messages from.
-            pagin_config (synapse.api.streams.PaginationConfig): The pagination
-                config rules to apply, if any.
-            as_client_event (bool): True to get events in client-server format.
-            event_filter (Filter): Filter to apply to results or None
+            requester: The user requesting messages.
+            room_id: The room they want messages from.
+            pagin_config: The pagination config rules to apply, if any.
+            as_client_event: True to get events in client-server format.
+            event_filter: Filter to apply to results or None
         Returns:
-            dict: Pagination API results
+            Pagination API results
         """
         user_id = requester.user.to_string()
 
@@ -343,7 +346,7 @@ class PaginationHandler(object):
 
         source_config = pagin_config.get_source_config("room")
 
-        with (await self.pagination_lock.read(room_id)):
+        with await self.pagination_lock.read(room_id):
             (
                 membership,
                 member_event_id,
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f562770922..dfefbd996d 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -20,6 +20,7 @@ from contextlib import contextmanager
 from typing import Dict, Sequence, Set, Union
 
 import attr
+from typing_extensions import ContextManager
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
@@ -338,11 +339,11 @@ class Linearizer(object):
 
 
 class ReadWriteLock(object):
-    """A deferred style read write lock.
+    """An async read write lock.
 
     Example:
 
-        with (yield read_write_lock.read("test_key")):
+        with await read_write_lock.read("test_key"):
             # do some work
     """
 
@@ -365,8 +366,7 @@ class ReadWriteLock(object):
         # Latest writer queued
         self.key_to_current_writer = {}  # type: Dict[str, defer.Deferred]
 
-    @defer.inlineCallbacks
-    def read(self, key):
+    async def read(self, key: str) -> ContextManager:
         new_defer = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.setdefault(key, set())
@@ -376,7 +376,8 @@ class ReadWriteLock(object):
 
         # We wait for the latest writer to finish writing. We can safely ignore
         # any existing readers... as they're readers.
-        yield make_deferred_yieldable(curr_writer)
+        if curr_writer:
+            await make_deferred_yieldable(curr_writer)
 
         @contextmanager
         def _ctx_manager():
@@ -388,8 +389,7 @@ class ReadWriteLock(object):
 
         return _ctx_manager()
 
-    @defer.inlineCallbacks
-    def write(self, key):
+    async def write(self, key: str) -> ContextManager:
         new_defer = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.get(key, set())
@@ -405,7 +405,7 @@ class ReadWriteLock(object):
         curr_readers.clear()
         self.key_to_current_writer[key] = new_defer
 
-        yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
+        await make_deferred_yieldable(defer.gatherResults(to_wait_on))
 
         @contextmanager
         def _ctx_manager():
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index bd32e2cee7..d3dea3b52a 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
 
 from synapse.util.async_helpers import ReadWriteLock
 
@@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
             rwlock.read(key),  # 5
             rwlock.write(key),  # 6
         ]
+        ds = [defer.ensureDeferred(d) for d in ds]
 
         self._assert_called_before_not_after(ds, 2)
 
@@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase):
         with ds[6].result:
             pass
 
-        d = rwlock.write(key)
+        d = defer.ensureDeferred(rwlock.write(key))
         self.assertTrue(d.called)
         with d.result:
             pass
 
-        d = rwlock.read(key)
+        d = defer.ensureDeferred(rwlock.read(key))
         self.assertTrue(d.called)
         with d.result:
             pass