From 38e1fac8861f12b707609da06008695a05aaf21c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Jul 2020 09:52:58 -0400 Subject: Fix some spelling mistakes / typos. (#7811) --- synapse/notifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/notifier.py') diff --git a/synapse/notifier.py b/synapse/notifier.py index 87c120a59c..bd41f77852 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -83,7 +83,7 @@ class _NotifierUserStream(object): self.current_token = current_token # The last token for which we should wake up any streams that have a - # token that comes before it. This gets updated everytime we get poked. + # token that comes before it. This gets updated every time we get poked. # We start it at the current token since if we get any streams # that have a token from before we have no idea whether they should be # woken up or not, so lets just wake them up. -- cgit 1.5.1 From e19de43eb5903c3b6ccca82334971ebc57fc38de Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Aug 2020 07:21:47 -0400 Subject: Convert streams to async. (#8014) --- changelog.d/8014.misc | 1 + synapse/handlers/initial_sync.py | 4 ++-- synapse/handlers/pagination.py | 2 +- synapse/handlers/room.py | 10 +++++----- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/notifier.py | 4 ++-- synapse/storage/data_stores/main/stream.py | 8 ++++---- synapse/streams/events.py | 22 +++++++++------------- .../test_resource_limits_server_notices.py | 2 +- 10 files changed, 27 insertions(+), 30 deletions(-) create mode 100644 changelog.d/8014.misc (limited to 'synapse/notifier.py') diff --git a/changelog.d/8014.misc b/changelog.d/8014.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8014.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index f88bad5f25..ae6bd1d352 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -109,7 +109,7 @@ class InitialSyncHandler(BaseHandler): rooms_ret = [] - now_token = await self.hs.get_event_sources().get_current_token() + now_token = self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] pagination_config = PaginationConfig(from_token=now_token) @@ -360,7 +360,7 @@ class InitialSyncHandler(BaseHandler): current_state.values(), time_now ) - now_token = await self.hs.get_event_sources().get_current_token() + now_token = self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None if limit is None: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index da06582d4b..487420bb5d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -309,7 +309,7 @@ class PaginationHandler(object): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - await self.hs.get_event_sources().get_current_token_for_pagination() + self.hs.get_event_sources().get_current_token_for_pagination() ) room_token = pagin_config.from_token.room_key diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0c5b99234d..a8545255b1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,7 +22,7 @@ import logging import math import string from collections import OrderedDict -from typing import Optional, Tuple +from typing import Awaitable, Optional, Tuple from synapse.api.constants import ( EventTypes, @@ -1041,7 +1041,7 @@ class RoomEventSource(object): ): # We just ignore the key for now. - to_key = await self.get_current_key() + to_key = self.get_current_key() from_token = RoomStreamToken.parse(from_key) if from_token.topological: @@ -1081,10 +1081,10 @@ class RoomEventSource(object): return (events, end_key) - def get_current_key(self): - return self.store.get_room_events_max_id() + def get_current_key(self) -> str: + return "s%d" % (self.store.get_room_max_stream_ordering(),) - def get_current_key_for_room(self, room_id): + def get_current_key_for_room(self, room_id: str) -> Awaitable[str]: return self.store.get_room_events_max_id(room_id) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9b312a1558..d58f9788c5 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -340,7 +340,7 @@ class SearchHandler(BaseHandler): # If client has asked for "context" for each event (i.e. some surrounding # events and state), fetch that if event_context is not None: - now_token = await self.hs.get_event_sources().get_current_token() + now_token = self.hs.get_event_sources().get_current_token() contexts = {} for event in allowed_events: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index eaa4eeadf7..5a19bac929 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -961,7 +961,7 @@ class SyncHandler(object): # this is due to some of the underlying streams not supporting the ability # to query up to a given point. # Always use the `now_token` in `SyncResultBuilder` - now_token = await self.event_sources.get_current_token() + now_token = self.event_sources.get_current_token() logger.debug( "Calculating sync response for %r between %s and %s", diff --git a/synapse/notifier.py b/synapse/notifier.py index bd41f77852..22ab4a9da5 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -320,7 +320,7 @@ class Notifier(object): """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - current_token = await self.event_sources.get_current_token() + current_token = self.event_sources.get_current_token() if room_ids is None: room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( @@ -397,7 +397,7 @@ class Notifier(object): """ from_token = pagination_config.from_token if not from_token: - from_token = await self.event_sources.get_current_token() + from_token = self.event_sources.get_current_token() limit = pagination_config.limit diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 10d39b3699..f1334a6efc 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -39,6 +39,7 @@ what sort order was used: import abc import logging from collections import namedtuple +from typing import Optional from twisted.internet import defer @@ -557,19 +558,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return self.db.runInteraction("get_room_event_before_stream_ordering", _f) - @defer.inlineCallbacks - def get_room_events_max_id(self, room_id=None): + async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: """Returns the current token for rooms stream. By default, it returns the current global stream token. Specifying a `room_id` causes it to return the current room specific topological token. """ - token = yield self.get_room_max_stream_ordering() + token = self.get_room_max_stream_ordering() if room_id is None: return "s%d" % (token,) else: - topo = yield self.db.runInteraction( + topo = await self.db.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) return "t%d-%d" % (topo, token) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 5d3eddcfdc..393e34b9fb 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -15,8 +15,6 @@ from typing import Any, Dict -from twisted.internet import defer - from synapse.handlers.account_data import AccountDataEventSource from synapse.handlers.presence import PresenceEventSource from synapse.handlers.receipts import ReceiptEventSource @@ -40,19 +38,18 @@ class EventSources(object): } # type: Dict[str, Any] self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_current_token(self): + def get_current_token(self) -> StreamToken: push_rules_key, _ = self.store.get_push_rules_stream_token() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() groups_key = self.store.get_group_stream_token() token = StreamToken( - room_key=(yield self.sources["room"].get_current_key()), - presence_key=(yield self.sources["presence"].get_current_key()), - typing_key=(yield self.sources["typing"].get_current_key()), - receipt_key=(yield self.sources["receipt"].get_current_key()), - account_data_key=(yield self.sources["account_data"].get_current_key()), + room_key=self.sources["room"].get_current_key(), + presence_key=self.sources["presence"].get_current_key(), + typing_key=self.sources["typing"].get_current_key(), + receipt_key=self.sources["receipt"].get_current_key(), + account_data_key=self.sources["account_data"].get_current_key(), push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, @@ -60,8 +57,7 @@ class EventSources(object): ) return token - @defer.inlineCallbacks - def get_current_token_for_pagination(self): + def get_current_token_for_pagination(self) -> StreamToken: """Get the current token for a given room to be used to paginate events. @@ -69,10 +65,10 @@ class EventSources(object): than `room`, since they are not used during pagination. Returns: - Deferred[StreamToken] + The current token for pagination. """ token = StreamToken( - room_key=(yield self.sources["room"].get_current_key()), + room_key=self.sources["room"].get_current_key(), presence_key=0, typing_key=0, receipt_key=0, diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 99908edba3..7f70353b0d 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -275,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id) ) - token = self.get_success(self.event_source.get_current_token()) + token = self.event_source.get_current_token() events, _ = self.get_success( self.store.get_recent_events_for_room( room_id, limit=100, end_token=token.room_key -- cgit 1.5.1 From a1e9bb9eae7f90e93e0ba02e84e8216e3c2f447a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2020 19:40:02 +0100 Subject: Add typing info to Notifier (#8058) --- changelog.d/8058.misc | 1 + synapse/handlers/events.py | 4 -- synapse/notifier.py | 131 ++++++++++++++++++++++++++++----------------- synapse/server.pyi | 6 +++ tox.ini | 1 + 5 files changed, 91 insertions(+), 52 deletions(-) create mode 100644 changelog.d/8058.misc (limited to 'synapse/notifier.py') diff --git a/changelog.d/8058.misc b/changelog.d/8058.misc new file mode 100644 index 0000000000..41a27e5d72 --- /dev/null +++ b/changelog.d/8058.misc @@ -0,0 +1 @@ +Add type hints to `Notifier`. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 71a89f09c7..1924636c4d 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -57,13 +57,10 @@ class EventStreamHandler(BaseHandler): timeout=0, as_client_event=True, affect_presence=True, - only_keys=None, room_id=None, is_guest=False, ): """Fetches the events stream for a given user. - - If `only_keys` is not None, events from keys will be sent down. """ if room_id: @@ -93,7 +90,6 @@ class EventStreamHandler(BaseHandler): auth_user, pagin_config, timeout, - only_keys=only_keys, is_guest=is_guest, explicit_room_id=room_id, ) diff --git a/synapse/notifier.py b/synapse/notifier.py index 22ab4a9da5..694efe7116 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,7 +15,17 @@ import logging from collections import namedtuple -from typing import Callable, Iterable, List, TypeVar +from typing import ( + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, +) from prometheus_client import Counter @@ -24,12 +34,14 @@ from twisted.internet import defer import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError +from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import PreserveLoggingContext from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import StreamToken +from synapse.streams.config import PaginationConfig +from synapse.types import Collection, StreamToken, UserID from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -77,7 +89,13 @@ class _NotifierUserStream(object): so that it can remove itself from the indexes in the Notifier class. """ - def __init__(self, user_id, rooms, current_token, time_now_ms): + def __init__( + self, + user_id: str, + rooms: Collection[str], + current_token: StreamToken, + time_now_ms: int, + ): self.user_id = user_id self.rooms = set(rooms) self.current_token = current_token @@ -93,13 +111,13 @@ class _NotifierUserStream(object): with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - def notify(self, stream_key, stream_id, time_now_ms): + def notify(self, stream_key: str, stream_id: int, time_now_ms: int): """Notify any listeners for this user of a new event from an event source. Args: - stream_key(str): The stream the event came from. - stream_id(str): The new id for the stream the event came from. - time_now_ms(int): The current time in milliseconds. + stream_key: The stream the event came from. + stream_id: The new id for the stream the event came from. + time_now_ms: The current time in milliseconds. """ self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token @@ -112,7 +130,7 @@ class _NotifierUserStream(object): self.notify_deferred = ObservableDeferred(defer.Deferred()) noify_deferred.callback(self.current_token) - def remove(self, notifier): + def remove(self, notifier: "Notifier"): """ Remove this listener from all the indexes in the Notifier it knows about. """ @@ -123,10 +141,10 @@ class _NotifierUserStream(object): notifier.user_to_user_stream.pop(self.user_id) - def count_listeners(self): + def count_listeners(self) -> int: return len(self.notify_deferred.observers()) - def new_listener(self, token): + def new_listener(self, token: StreamToken) -> _NotificationListener: """Returns a deferred that is resolved when there is a new token greater than the given token. @@ -159,14 +177,16 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 def __init__(self, hs: "synapse.server.HomeServer"): - self.user_to_user_stream = {} - self.room_to_user_streams = {} + self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream] + self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]] self.hs = hs self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = [] + self.pending_new_room_events = ( + [] + ) # type: List[Tuple[int, EventBase, Collection[str]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -178,10 +198,9 @@ class Notifier(object): self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() + self.federation_sender = None if hs.should_send_federation(): self.federation_sender = hs.get_federation_sender() - else: - self.federation_sender = None self.state_handler = hs.get_state_handler() @@ -193,12 +212,12 @@ class Notifier(object): # when rendering the metrics page, which is likely once per minute at # most when scraping it. def count_listeners(): - all_user_streams = set() + all_user_streams = set() # type: Set[_NotifierUserStream] - for x in list(self.room_to_user_streams.values()): - all_user_streams |= x - for x in list(self.user_to_user_stream.values()): - all_user_streams.add(x) + for streams in list(self.room_to_user_streams.values()): + all_user_streams |= streams + for stream in list(self.user_to_user_stream.values()): + all_user_streams.add(stream) return sum(stream.count_listeners() for stream in all_user_streams) @@ -223,7 +242,11 @@ class Notifier(object): self.replication_callbacks.append(cb) def on_new_room_event( - self, event, room_stream_id, max_room_stream_id, extra_users=[] + self, + event: EventBase, + room_stream_id: int, + max_room_stream_id: int, + extra_users: Collection[str] = [], ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -241,11 +264,11 @@ class Notifier(object): self.notify_replication() - def _notify_pending_new_room_events(self, max_room_stream_id): + def _notify_pending_new_room_events(self, max_room_stream_id: int): """Notify for the room events that were queued waiting for a previous event to be persisted. Args: - max_room_stream_id(int): The highest stream_id below which all + max_room_stream_id: The highest stream_id below which all events have been persisted. """ pending = self.pending_new_room_events @@ -258,7 +281,9 @@ class Notifier(object): else: self._on_new_room_event(event, room_stream_id, extra_users) - def _on_new_room_event(self, event, room_stream_id, extra_users=[]): + def _on_new_room_event( + self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = [] + ): """Notify any user streams that are interested in this room event""" # poke any interested application service. run_as_background_process( @@ -275,13 +300,19 @@ class Notifier(object): "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) - async def _notify_app_services(self, room_stream_id): + async def _notify_app_services(self, room_stream_id: int): try: await self.appservice_handler.notify_interested_services(room_stream_id) except Exception: logger.exception("Error notifying application services of event") - def on_new_event(self, stream_key, new_token, users=[], rooms=[]): + def on_new_event( + self, + stream_key: str, + new_token: int, + users: Collection[str] = [], + rooms: Collection[str] = [], + ): """ Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. @@ -307,14 +338,19 @@ class Notifier(object): self.notify_replication() - def on_new_replication_data(self): + def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" self.notify_replication() async def wait_for_events( - self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START - ): + self, + user_id: str, + timeout: int, + callback: Callable[[StreamToken, StreamToken], Awaitable[T]], + room_ids=None, + from_token=StreamToken.START, + ) -> T: """Wait until the callback returns a non empty response or the timeout fires. """ @@ -377,19 +413,16 @@ class Notifier(object): async def get_events_for( self, - user, - pagination_config, - timeout, - only_keys=None, - is_guest=False, - explicit_room_id=None, - ): + user: UserID, + pagination_config: PaginationConfig, + timeout: int, + is_guest: bool = False, + explicit_room_id: str = None, + ) -> EventStreamResult: """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. - If `only_keys` is not None, events from keys will be sent down. - If explicit_room_id is not set, the user's joined rooms will be polled for events. If explicit_room_id is set, that room will be polled for events only if @@ -404,11 +437,13 @@ class Notifier(object): room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - async def check_for_updates(before_token, after_token): + async def check_for_updates( + before_token: StreamToken, after_token: StreamToken + ) -> EventStreamResult: if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) - events = [] + events = [] # type: List[EventBase] end_token = from_token for name, source in self.event_sources.sources.items(): @@ -417,8 +452,6 @@ class Notifier(object): after_id = getattr(after_token, keyname) if before_id == after_id: continue - if only_keys and name not in only_keys: - continue new_events, new_key = await source.get_new_events( user=user, @@ -476,7 +509,9 @@ class Notifier(object): return result - async def _get_room_ids(self, user, explicit_room_id): + async def _get_room_ids( + self, user: UserID, explicit_room_id: Optional[str] + ) -> Tuple[Collection[str], bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: @@ -486,7 +521,7 @@ class Notifier(object): raise AuthError(403, "Non-joined access not allowed") return joined_room_ids, True - async def _is_world_readable(self, room_id): + async def _is_world_readable(self, room_id: str) -> bool: state = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) @@ -496,7 +531,7 @@ class Notifier(object): return False @log_function - def remove_expired_streams(self): + def remove_expired_streams(self) -> None: time_now_ms = self.clock.time_msec() expired_streams = [] expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS @@ -510,21 +545,21 @@ class Notifier(object): expired_stream.remove(self) @log_function - def _register_with_keys(self, user_stream): + def _register_with_keys(self, user_stream: _NotifierUserStream): self.user_to_user_stream[user_stream.user_id] = user_stream for room in user_stream.rooms: s = self.room_to_user_streams.setdefault(room, set()) s.add(user_stream) - def _user_joined_room(self, user_id, room_id): + def _user_joined_room(self, user_id: str, room_id: str): new_user_stream = self.user_to_user_stream.get(user_id) if new_user_stream is not None: room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams.add(new_user_stream) new_user_stream.rooms.add(room_id) - def notify_replication(self): + def notify_replication(self) -> None: """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() diff --git a/synapse/server.pyi b/synapse/server.pyi index 1aba408c21..6a9bbb4713 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -31,8 +31,10 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage from synapse.events.builder import EventBuilderFactory +from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.typing import FollowerTypingHandler from synapse.replication.tcp.streams import Stream +from synapse.streams.events import EventSources class HomeServer(object): @property @@ -153,3 +155,7 @@ class HomeServer(object): pass def get_typing_handler(self) -> FollowerTypingHandler: pass + def get_event_sources(self) -> EventSources: + pass + def get_application_service_handler(self): + return ApplicationServicesHandler(self) diff --git a/tox.ini b/tox.ini index 9a052c1e33..54f28990ce 100644 --- a/tox.ini +++ b/tox.ini @@ -198,6 +198,7 @@ commands = mypy \ synapse/logging/ \ synapse/metrics \ synapse/module_api \ + synapse/notifier.py \ synapse/push/pusherpool.py \ synapse/push/push_rule_evaluator.py \ synapse/replication \ -- cgit 1.5.1 From 9d1e4942ab728ebfe09ff9a63c66708ceaaf7591 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 12 Aug 2020 14:03:08 +0100 Subject: Fix typing for notifier (#8064) --- changelog.d/8064.misc | 1 + synapse/federation/sender/transaction_manager.py | 7 +++++-- synapse/notifier.py | 12 ++++++++---- synapse/types.py | 23 ++++++++++++++++------- synapse/util/metrics.py | 9 ++++++--- tox.ini | 2 ++ 6 files changed, 38 insertions(+), 16 deletions(-) create mode 100644 changelog.d/8064.misc (limited to 'synapse/notifier.py') diff --git a/changelog.d/8064.misc b/changelog.d/8064.misc new file mode 100644 index 0000000000..41a27e5d72 --- /dev/null +++ b/changelog.d/8064.misc @@ -0,0 +1 @@ +Add type hints to `Notifier`. diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 8280f8b900..c7f6cb3d73 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Tuple from canonicaljson import json @@ -54,7 +54,10 @@ class TransactionManager(object): @measure_func("_send_new_transaction") async def send_new_transaction( - self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] + self, + destination: str, + pending_pdus: List[Tuple[EventBase, int]], + pending_edus: List[Edu], ): # Make a transaction-sending opentracing span. This span follows on from diff --git a/synapse/notifier.py b/synapse/notifier.py index 694efe7116..dfb096e589 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -25,6 +25,7 @@ from typing import ( Set, Tuple, TypeVar, + Union, ) from prometheus_client import Counter @@ -186,7 +187,7 @@ class Notifier(object): self.store = hs.get_datastore() self.pending_new_room_events = ( [] - ) # type: List[Tuple[int, EventBase, Collection[str]]] + ) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -246,7 +247,7 @@ class Notifier(object): event: EventBase, room_stream_id: int, max_room_stream_id: int, - extra_users: Collection[str] = [], + extra_users: Collection[Union[str, UserID]] = [], ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -282,7 +283,10 @@ class Notifier(object): self._on_new_room_event(event, room_stream_id, extra_users) def _on_new_room_event( - self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = [] + self, + event: EventBase, + room_stream_id: int, + extra_users: Collection[Union[str, UserID]] = [], ): """Notify any user streams that are interested in this room event""" # poke any interested application service. @@ -310,7 +314,7 @@ class Notifier(object): self, stream_key: str, new_token: int, - users: Collection[str] = [], + users: Collection[Union[str, UserID]] = [], rooms: Collection[str] = [], ): """ Used to inform listeners that something has happened event wise. diff --git a/synapse/types.py b/synapse/types.py index 238b938064..9e580f4295 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -13,11 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import re import string import sys from collections import namedtuple -from typing import Any, Dict, Tuple, TypeVar +from typing import Any, Dict, Tuple, Type, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -33,7 +34,7 @@ else: T_co = TypeVar("T_co", covariant=True) - class Collection(Iterable[T_co], Container[T_co], Sized): + class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore __slots__ = () @@ -141,6 +142,9 @@ def get_localpart_from_id(string): return string[1:idx] +DS = TypeVar("DS", bound="DomainSpecificString") + + class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))): """Common base class among ID/name strings that have a local part and a domain name, prefixed with a sigil. @@ -151,6 +155,10 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 'domain' : The domain part of the name """ + __metaclass__ = abc.ABCMeta + + SIGIL = abc.abstractproperty() # type: str # type: ignore + # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) @@ -166,7 +174,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom return self @classmethod - def from_string(cls, s: str): + def from_string(cls: Type[DS], s: str) -> DS: """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( @@ -190,12 +198,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom # names on one HS return cls(localpart=parts[0], domain=domain) - def to_string(self): + def to_string(self) -> str: """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) @classmethod - def is_valid(cls, s): + def is_valid(cls: Type[DS], s: str) -> bool: try: cls.from_string(s) return True @@ -235,8 +243,9 @@ class GroupID(DomainSpecificString): SIGIL = "+" @classmethod - def from_string(cls, s): - group_id = super(GroupID, cls).from_string(s) + def from_string(cls: Type[DS], s: str) -> DS: + group_id = super().from_string(s) # type: DS # type: ignore + if not group_id.localpart: raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index a805f51df1..13775b43f9 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -15,6 +15,7 @@ import logging from functools import wraps +from typing import Any, Callable, Optional, TypeVar, cast from prometheus_client import Counter @@ -57,8 +58,10 @@ in_flight = InFlightGauge( sub_metrics=["real_time_max", "real_time_sum"], ) +T = TypeVar("T", bound=Callable[..., Any]) -def measure_func(name=None): + +def measure_func(name: Optional[str] = None) -> Callable[[T], T]: """ Used to decorate an async function with a `Measure` context manager. @@ -76,7 +79,7 @@ def measure_func(name=None): """ - def wrapper(func): + def wrapper(func: T) -> T: block_name = func.__name__ if name is None else name @wraps(func) @@ -85,7 +88,7 @@ def measure_func(name=None): r = await func(self, *args, **kwargs) return r - return measured_func + return cast(T, measured_func) return wrapper diff --git a/tox.ini b/tox.ini index 217590edef..45e129580f 100644 --- a/tox.ini +++ b/tox.ini @@ -212,7 +212,9 @@ commands = mypy \ synapse/storage/state.py \ synapse/storage/util \ synapse/streams \ + synapse/types.py \ synapse/util/caches/stream_change_cache.py \ + synapse/util/metrics.py \ tests/replication \ tests/test_utils \ tests/rest/client/v2_alpha/test_auth.py \ -- cgit 1.5.1