diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index ac3418d69d..5a1aa7d830 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -14,15 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Any, Dict, Optional
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
+from synapse.api.filtering import Filter
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
-from synapse.types import RoomStreamToken
+from synapse.streams.config import PaginationConfig
+from synapse.types import Requester, RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -247,15 +250,16 @@ class PaginationHandler(object):
)
return purge_id
- async def _purge_history(self, purge_id, room_id, token, delete_local_events):
+ async def _purge_history(
+ self, purge_id: str, room_id: str, token: str, delete_local_events: bool
+ ) -> None:
"""Carry out a history purge on a room.
Args:
- purge_id (str): The id for this purge
- room_id (str): The room to purge from
- token (str): topological token to delete events before
- delete_local_events (bool): True to delete local events as well as
- remote ones
+ purge_id: The id for this purge
+ room_id: The room to purge from
+ token: topological token to delete events before
+ delete_local_events: True to delete local events as well as remote ones
"""
self._purges_in_progress_by_room.add(room_id)
try:
@@ -291,9 +295,9 @@ class PaginationHandler(object):
"""
return self._purges_by_id.get(purge_id)
- async def purge_room(self, room_id):
+ async def purge_room(self, room_id: str) -> None:
"""Purge the given room from the database"""
- with (await self.pagination_lock.write(room_id)):
+ with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
@@ -307,23 +311,22 @@ class PaginationHandler(object):
async def get_messages(
self,
- requester,
- room_id=None,
- pagin_config=None,
- as_client_event=True,
- event_filter=None,
- ):
+ requester: Requester,
+ room_id: Optional[str] = None,
+ pagin_config: Optional[PaginationConfig] = None,
+ as_client_event: bool = True,
+ event_filter: Optional[Filter] = None,
+ ) -> Dict[str, Any]:
"""Get messages in a room.
Args:
- requester (Requester): The user requesting messages.
- room_id (str): The room they want messages from.
- pagin_config (synapse.api.streams.PaginationConfig): The pagination
- config rules to apply, if any.
- as_client_event (bool): True to get events in client-server format.
- event_filter (Filter): Filter to apply to results or None
+ requester: The user requesting messages.
+ room_id: The room they want messages from.
+ pagin_config: The pagination config rules to apply, if any.
+ as_client_event: True to get events in client-server format.
+ event_filter: Filter to apply to results or None
Returns:
- dict: Pagination API results
+ Pagination API results
"""
user_id = requester.user.to_string()
@@ -343,7 +346,7 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
- with (await self.pagination_lock.read(room_id)):
+ with await self.pagination_lock.read(room_id):
(
membership,
member_event_id,
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index f562770922..dfefbd996d 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -20,6 +20,7 @@ from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
import attr
+from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@@ -338,11 +339,11 @@ class Linearizer(object):
class ReadWriteLock(object):
- """A deferred style read write lock.
+ """An async read write lock.
Example:
- with (yield read_write_lock.read("test_key")):
+ with await read_write_lock.read("test_key"):
# do some work
"""
@@ -365,8 +366,7 @@ class ReadWriteLock(object):
# Latest writer queued
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
- @defer.inlineCallbacks
- def read(self, key):
+ async def read(self, key: str) -> ContextManager:
new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.setdefault(key, set())
@@ -376,7 +376,8 @@ class ReadWriteLock(object):
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
- yield make_deferred_yieldable(curr_writer)
+ if curr_writer:
+ await make_deferred_yieldable(curr_writer)
@contextmanager
def _ctx_manager():
@@ -388,8 +389,7 @@ class ReadWriteLock(object):
return _ctx_manager()
- @defer.inlineCallbacks
- def write(self, key):
+ async def write(self, key: str) -> ContextManager:
new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.get(key, set())
@@ -405,7 +405,7 @@ class ReadWriteLock(object):
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
- yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
+ await make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager():
|