diff options
51 files changed, 507 insertions, 332 deletions
diff --git a/changelog.d/14866.bugfix b/changelog.d/14866.bugfix new file mode 100644 index 0000000000..540f918cbd --- /dev/null +++ b/changelog.d/14866.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.53.0 where `next_batch` tokens from `/sync` could not be used with the `/relations` endpoint. diff --git a/changelog.d/14879.misc b/changelog.d/14879.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14879.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/changelog.d/14880.bugfix b/changelog.d/14880.bugfix new file mode 100644 index 0000000000..e56c567082 --- /dev/null +++ b/changelog.d/14880.bugfix @@ -0,0 +1 @@ +Fix a bug when using the `send_local_online_presence_to` module API. diff --git a/changelog.d/14886.misc b/changelog.d/14886.misc new file mode 100644 index 0000000000..9f5384e60e --- /dev/null +++ b/changelog.d/14886.misc @@ -0,0 +1 @@ +Add missing type hints. \ No newline at end of file diff --git a/changelog.d/14887.misc b/changelog.d/14887.misc new file mode 100644 index 0000000000..9f5384e60e --- /dev/null +++ b/changelog.d/14887.misc @@ -0,0 +1 @@ +Add missing type hints. \ No newline at end of file diff --git a/changelog.d/14904.misc b/changelog.d/14904.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14904.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/changelog.d/14916.misc b/changelog.d/14916.misc new file mode 100644 index 0000000000..59914d4b8a --- /dev/null +++ b/changelog.d/14916.misc @@ -0,0 +1 @@ +Document how to handle Dependabot pull requests. diff --git a/changelog.d/14920.misc b/changelog.d/14920.misc new file mode 100644 index 0000000000..7988897d39 --- /dev/null +++ b/changelog.d/14920.misc @@ -0,0 +1 @@ +Fix typo in release script. diff --git a/changelog.d/14922.misc b/changelog.d/14922.misc new file mode 100644 index 0000000000..2cc3614dfd --- /dev/null +++ b/changelog.d/14922.misc @@ -0,0 +1 @@ +Use `StrCollection` to avoid potential bugs with `Collection[str]`. diff --git a/changelog.d/14927.misc b/changelog.d/14927.misc new file mode 100644 index 0000000000..9f5384e60e --- /dev/null +++ b/changelog.d/14927.misc @@ -0,0 +1 @@ +Add missing type hints. \ No newline at end of file diff --git a/docs/development/dependencies.md b/docs/development/dependencies.md index b734cc5826..c4449c51f7 100644 --- a/docs/development/dependencies.md +++ b/docs/development/dependencies.md @@ -258,6 +258,20 @@ because [`build`](https://github.com/pypa/build) is a standardish tool which doesn't require poetry. (It's what we use in CI too). However, you could try `poetry build` too. +## ...handle a Dependabot pull request? + +Synapse uses Dependabot to keep the `poetry.lock` file up-to-date. When it +creates a pull request a GitHub Action will run to automatically create a changelog +file. Ensure that: + +* the lockfile changes look reasonable; +* the upstream changelog file (linked in the description) doesn't include any + breaking changes; +* continuous integration passes (due to permissions, the GitHub Actions run on + the changelog commit will fail, look at the initial commit of the pull request); + +In particular, any updates to the type hints (usually packages which start with `types-`) +should be safe to merge if linting passes. # Troubleshooting diff --git a/mypy.ini b/mypy.ini index 63366dad99..978d92940b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -35,19 +35,12 @@ exclude = (?x) |tests/api/test_auth.py |tests/app/test_openid_listener.py |tests/appservice/test_scheduler.py - |tests/events/test_presence_router.py - |tests/events/test_utils.py |tests/federation/test_federation_catch_up.py |tests/federation/test_federation_sender.py - |tests/federation/transport/test_knocking.py - |tests/handlers/test_typing.py |tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_srv_resolver.py |tests/http/test_proxyagent.py - |tests/logging/__init__.py - |tests/logging/test_terse_json.py |tests/module_api/test_api.py - |tests/rest/client/test_transactions.py |tests/rest/media/v1/test_media_storage.py |tests/server.py |tests/test_state.py @@ -87,12 +80,18 @@ disallow_untyped_defs = True [mypy-tests.crypto.*] disallow_untyped_defs = True +[mypy-tests.events.*] +disallow_untyped_defs = True + [mypy-tests.federation.transport.test_client] disallow_untyped_defs = True [mypy-tests.handlers.*] disallow_untyped_defs = True +[mypy-tests.logging.*] +disallow_untyped_defs = True + [mypy-tests.metrics.*] disallow_untyped_defs = True diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 6974fd7895..008a5bd965 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -438,7 +438,7 @@ def _upload(gh_token: Optional[str]) -> None: repo = get_repo_and_check_clean_checkout() tag = repo.tag(f"refs/tags/{tag_name}") if repo.head.commit != tag.commit: - click.echo("Tag {tag_name} (tag.commit) is not currently checked out!") + click.echo(f"Tag {tag_name} ({tag.commit}) is not currently checked out!") click.get_current_context().abort() # Query all the assets corresponding to this release. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6432d32d83..6f9239d21c 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,6 +17,8 @@ """Contains constants from the specification.""" +import enum + from typing_extensions import Final # the max size of a (canonical-json-encoded) event @@ -290,3 +292,8 @@ class ApprovalNoticeMedium: NONE = "org.matrix.msc3866.none" EMAIL = "org.matrix.msc3866.email" + + +class Direction(enum.Enum): + BACKWARDS = "b" + FORWARDS = "f" diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ae57a4df5e..52e4b467e8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -605,10 +605,11 @@ class EventClientSerializer: _PowerLevel = Union[str, int] +PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] def copy_and_fixup_power_levels_contents( - old_power_levels: Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] + old_power_levels: PowerLevelsContent, ) -> Dict[str, Union[int, Dict[str, int]]]: """Copy the content of a power_levels event, unfreezing frozendicts along the way. diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 834006356a..d500b21809 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple from synapse.api.constants import AccountDataTypes from synapse.replication.http.account_data import ( @@ -26,7 +26,7 @@ from synapse.replication.http.account_data import ( ReplicationRemoveUserAccountDataRestServlet, ) from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -322,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): user: UserID, from_key: int, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[JsonDict], int]: diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5bf8e86387..c81ea34758 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -16,7 +16,7 @@ import abc import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set -from synapse.api.constants import Membership +from synapse.api.constants import Direction, Membership from synapse.events import EventBase from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client @@ -197,7 +197,7 @@ class AdminHandler: # efficient method perhaps but it does guarantee we get everything. while True: events, _ = await self.store.paginate_room_events( - room_id, from_key, to_key, limit=100, direction="f" + room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS ) if not events: break diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 58180ae2fa..5c06073901 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -18,7 +18,6 @@ from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, - Collection, Dict, Iterable, List, @@ -45,6 +44,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.types import ( JsonDict, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -146,7 +146,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( - self, user_id: str, room_ids: Collection[str], from_token: StreamToken + self, user_id: str, room_ids: StrCollection, from_token: StreamToken ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. @@ -551,7 +551,7 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") async def notify_device_update( - self, user_id: str, device_ids: Collection[str] + self, user_id: str, device_ids: StrCollection ) -> None: """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index f91dbbecb7..a23a8ce2a1 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( @@ -29,7 +29,7 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.types import StateMap, get_domain_from_id +from synapse.types import StateMap, StrCollection, get_domain_from_id if TYPE_CHECKING: from synapse.server import HomeServer @@ -290,7 +290,7 @@ class EventAuthHandler: async def get_rooms_that_allow_join( self, state_ids: StateMap[str] - ) -> Collection[str]: + ) -> StrCollection: """ Generate a list of rooms in which membership allows access to a room. @@ -331,7 +331,7 @@ class EventAuthHandler: return result - async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool: + async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool: """ Check whether a user is a member of any of the provided rooms. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 233f8c113d..dc1cbf5c3d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -20,17 +20,7 @@ import itertools import logging from enum import Enum from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union import attr from prometheus_client import Histogram @@ -70,7 +60,7 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -179,7 +169,7 @@ class FederationHandler: # A dictionary mapping room IDs to (initial destination, other destinations) # tuples. self._partial_state_syncs_maybe_needing_restart: Dict[ - str, Tuple[Optional[str], Collection[str]] + str, Tuple[Optional[str], StrCollection] ] = {} # A lock guarding the partial state flag for rooms. # When the lock is held for a given room, no other concurrent code may @@ -437,7 +427,7 @@ class FederationHandler: ) ) - async def try_backfill(domains: Collection[str]) -> bool: + async def try_backfill(domains: StrCollection) -> bool: # TODO: Should we try multiple of these at a time? # Number of contacted remote homeservers that have denied our backfill @@ -1730,7 +1720,7 @@ class FederationHandler: def _start_partial_state_room_sync( self, initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, ) -> None: """Starts the background process to resync the state of a partial state room, @@ -1812,7 +1802,7 @@ class FederationHandler: async def _sync_partial_state_room( self, initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, ) -> None: """Background process to resync the state of a partial-state room @@ -1949,9 +1939,9 @@ class FederationHandler: def _prioritise_destinations_for_partial_state_resync( initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, -) -> Collection[str]: +) -> StrCollection: """Work out the order in which we should ask servers to resync events. If an `initial_destination` is given, it takes top priority. Otherwise diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 904a721483..e037acbca2 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -80,6 +80,7 @@ from synapse.types import ( PersistedEventPosition, RoomStreamToken, StateMap, + StrCollection, UserID, get_domain_from_id, ) @@ -615,7 +616,7 @@ class FederationEventHandler: @trace async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Collection[str] + self, dest: str, room_id: str, limit: int, extremities: StrCollection ) -> None: """Trigger a backfill request to `dest` for the given `room_id` @@ -1565,7 +1566,7 @@ class FederationEventHandler: @trace @tag_args async def _get_events_and_persist( - self, destination: str, room_id: str, event_ids: Collection[str] + self, destination: str, room_id: str, event_ids: StrCollection ) -> None: """Fetch the given events from a server, and persist them as outliers. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 8c2260ad7d..191529bd8e 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -15,7 +15,13 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, cast -from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + Direction, + EduTypes, + EventTypes, + Membership, +) from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig @@ -57,7 +63,13 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool + str, + Optional[StreamToken], + Optional[StreamToken], + Direction, + int, + bool, + bool, ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8c8ff18a1a..ceefa16b49 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set import attr from twisted.python.failure import Failure -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import SerializeEventConfig @@ -28,7 +28,7 @@ from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamKeyType +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string @@ -391,7 +391,7 @@ class PaginationHandler: """ return self._delete_by_id.get(delete_id) - def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]: + def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]: """Get all active delete ids by room Args: @@ -448,7 +448,7 @@ class PaginationHandler: if pagin_config.from_token: from_token = pagin_config.from_token - elif pagin_config.direction == "f": + elif pagin_config.direction == Direction.FORWARDS: from_token = ( await self.hs.get_event_sources().get_start_token_for_pagination( room_id @@ -476,7 +476,7 @@ class PaginationHandler: room_id, requester, allow_departed_users=True ) - if pagin_config.direction == "b": + if pagin_config.direction == Direction.BACKWARDS: # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 43e4e7b1b4..87af31aa27 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -64,7 +64,13 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.storage.databases.main import DataStore from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + StrCollection, + StreamKeyType, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -320,7 +326,7 @@ class BasePresenceHandler(abc.ABC): for destination, host_states in hosts_to_states.items(): self._federation.send_presence_to_destinations(host_states, [destination]) - async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None: + async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: """ Adds to the list of users who should receive a full snapshot of presence upon their next sync. Note that this only works for local users. @@ -1601,7 +1607,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # Having a default limit doesn't match the EventSource API, but some # callers do not provide it. It is unused in this class. limit: int = 0, - room_ids: Optional[Collection[str]] = None, + room_ids: Optional[StrCollection] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, include_offline: bool = True, @@ -1688,7 +1694,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # The set of users that we're interested in and that have had a presence update. # We'll actually pull the presence updates for these users at the end. - interested_and_updated_users: Collection[str] + interested_and_updated_users: StrCollection if from_key is not None: # First get all users that have had a presence update @@ -2120,7 +2126,7 @@ class PresenceFederationQueue: # stream_id, destinations, user_ids)`. We don't store the full states # for efficiency, and remote workers will already have the full states # cached. - self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = [] + self._queue: List[Tuple[int, int, StrCollection, Set[str]]] = [] self._next_id = 1 @@ -2142,7 +2148,7 @@ class PresenceFederationQueue: self._queue = self._queue[index:] def send_presence_to_destinations( - self, states: Collection[UserPresenceState], destinations: Collection[str] + self, states: Collection[UserPresenceState], destinations: StrCollection ) -> None: """Send the presence states to the given destinations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index e96f9999a8..0fb15391e0 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O import attr -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -413,7 +413,11 @@ class RelationsHandler: # Attempt to find another event to use as the latest event. potential_events, _ = await self._main_store.get_relations_for_event( - event_id, event, room_id, RelationTypes.THREAD, direction="f" + event_id, + event, + room_id, + RelationTypes.THREAD, + direction=Direction.FORWARDS, ) # Filter out ignored users. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 572c7b4db3..60a6d9cf3c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,16 +20,7 @@ import random import string from collections import OrderedDict from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Collection, - Dict, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple import attr from typing_extensions import TypedDict @@ -72,6 +63,7 @@ from synapse.types import ( RoomID, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1644,7 +1636,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): user: UserID, from_key: RoomStreamToken, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index c6b869c6f4..4472019fbc 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -36,7 +36,7 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase -from synapse.types import JsonDict, Requester +from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -870,7 +870,7 @@ class _RoomQueueEntry: # The room ID of this entry. room_id: str # The server to query if the room is not known locally. - via: Sequence[str] + via: StrCollection # The minimum number of hops necessary to get to this room (compared to the # originally requested room). depth: int = 0 diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 40f4635c4e..9bbf83047d 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -14,7 +14,7 @@ import itertools import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr from unpaddedbase64 import decode_base64, encode_base64 @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client @@ -418,7 +418,7 @@ class SearchHandler: async def _search_by_rank( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, @@ -491,7 +491,7 @@ class SearchHandler: async def _search_by_recent( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 44e70fc4b8..4a27c0f051 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -20,7 +20,6 @@ from typing import ( Any, Awaitable, Callable, - Collection, Dict, Iterable, List, @@ -47,6 +46,7 @@ from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest from synapse.types import ( JsonDict, + StrCollection, UserID, contains_invalid_mxid_characters, create_requester, @@ -141,7 +141,8 @@ class UserAttributes: confirm_localpart: bool = False display_name: Optional[str] = None picture: Optional[str] = None - emails: Collection[str] = attr.Factory(list) + # mypy thinks these are incompatible for some reason. + emails: StrCollection = attr.Factory(list) # type: ignore[assignment] @attr.s(slots=True, auto_attribs=True) @@ -159,7 +160,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name: Optional[str] - emails: Collection[str] + emails: StrCollection # An optional dictionary of extra attributes to be provided to the client in the # login response. @@ -174,7 +175,7 @@ class UsernameMappingSession: # choices made by the user chosen_localpart: Optional[str] = None use_display_name: bool = True - emails_to_use: Collection[str] = () + emails_to_use: StrCollection = () terms_accepted_version: Optional[str] = None diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5ebd3ea855..5235e29460 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, AbstractSet, Any, - Collection, Dict, FrozenSet, List, @@ -62,6 +61,7 @@ from synapse.types import ( Requester, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1179,7 +1179,7 @@ class SyncHandler: async def _find_missing_partial_state_memberships( self, room_id: str, - members_to_fetch: Collection[str], + members_to_fetch: StrCollection, events_with_membership_auth: Mapping[str, EventBase], found_state_ids: StateMap[str], ) -> StateMap[str]: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6153a48257..d22dd19d38 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1158,7 +1158,7 @@ class ModuleApi: # Send to remote destinations. destination = UserID.from_string(user).domain presence_handler.get_federation_queue().send_presence_to_destinations( - presence_events, destination + presence_events, [destination] ) def looping_background_call( diff --git a/synapse/notifier.py b/synapse/notifier.py index 2b0e52f23c..a8832a3f8e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -46,6 +46,7 @@ from synapse.types import ( JsonDict, PersistedEventPosition, RoomStreamToken, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -716,7 +717,7 @@ class Notifier: async def _get_room_ids( self, user: UserID, explicit_room_id: Optional[str] - ) -> Tuple[Collection[str], bool]: + ) -> Tuple[StrCollection, bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8191b4e32c..ad5c10c99d 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Sequence, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple, Union from synapse.api.errors import ( NotFoundError, @@ -169,7 +169,7 @@ class PushRuleRestServlet(RestServlet): raise UnrecognizedRequestError() -def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: +def _rule_spec_from_path(path: List[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec Args: diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 61375651bc..3f40f1874a 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple from typing_extensions import ParamSpec +from twisted.internet.defer import Deferred from twisted.python.failure import Failure from twisted.web.server import Request @@ -90,7 +91,7 @@ class HttpTransactionCache: fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], *args: P.args, **kwargs: P.kwargs, - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> "Deferred[Tuple[int, JsonDict]]": """Fetches the response for this transaction, or executes the given function to produce a response for this transaction. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 84f844b79e..0018d6f7ab 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -30,7 +30,7 @@ from typing import ( import attr -from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -40,9 +40,13 @@ from synapse.storage.database import ( LoggingTransaction, make_in_list_sql_clause, ) -from synapse.storage.databases.main.stream import generate_pagination_where_clause +from synapse.storage.databases.main.stream import ( + generate_next_token, + generate_pagination_bounds, + generate_pagination_where_clause, +) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken +from synapse.types import JsonDict, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -164,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Optional[str] = None, event_type: Optional[str] = None, limit: int = 5, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: @@ -177,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. limit: Only fetch the most recent `limit` events. - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`). + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. @@ -207,24 +211,23 @@ class RelationsWorkerStore(SQLBaseStore): where_clause.append("type = ?") where_args.append(event_type) + order, from_bound, to_bound = generate_pagination_bounds( + direction, + from_token.room_key if from_token else None, + to_token.room_key if to_token else None, + ) + pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.room_key.as_historical_tuple() - if from_token - else None, - to_token=to_token.room_key.as_historical_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) if pagination_clause: where_clause.append(pagination_clause) - if direction == "b": - order = "DESC" - else: - order = "ASC" - sql = """ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering FROM event_relations @@ -266,16 +269,9 @@ class RelationsWorkerStore(SQLBaseStore): topo_orderings = topo_orderings[:limit] stream_orderings = stream_orderings[:limit] - topo = topo_orderings[-1] - token = stream_orderings[-1] - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_key = RoomStreamToken(topo, token) + next_key = generate_next_token( + direction, topo_orderings[-1], stream_orderings[-1] + ) if from_token: next_token = from_token.copy_and_replace( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index d28fc65df9..818c46182e 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -55,6 +55,7 @@ from typing_extensions import Literal from twisted.internet import defer +from synapse.api.constants import Direction from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" - # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: @@ -104,7 +104,7 @@ class _EventsAround: def generate_pagination_where_clause( - direction: str, + direction: Direction, column_names: Tuple[str, str], from_token: Optional[Tuple[Optional[int], int]], to_token: Optional[Tuple[Optional[int], int]], @@ -130,27 +130,26 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction: Whether we're paginating backwards("b") or forwards ("f"). + direction: Whether we're paginating backwards or forwards. column_names: The column names to bound. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. from_token: The start point for the pagination. This is an exclusive - minimum bound if direction is "f", and an inclusive maximum bound if - direction is "b". + minimum bound if direction is forwards, and an inclusive maximum bound if + direction is backwards. to_token: The endpoint point for the pagination. This is an inclusive - maximum bound if direction is "f", and an exclusive minimum bound if - direction is "b". + maximum bound if direction is forwards, and an exclusive minimum bound if + direction is backwards. engine: The database engine to generate the clauses for Returns: The sql expression """ - assert direction in ("b", "f") where_clause = [] if from_token: where_clause.append( _make_generic_sql_bound( - bound=">=" if direction == "b" else "<", + bound=">=" if direction == Direction.BACKWARDS else "<", column_names=column_names, values=from_token, engine=engine, @@ -160,7 +159,7 @@ def generate_pagination_where_clause( if to_token: where_clause.append( _make_generic_sql_bound( - bound="<" if direction == "b" else ">=", + bound="<" if direction == Direction.BACKWARDS else ">=", column_names=column_names, values=to_token, engine=engine, @@ -170,6 +169,104 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) +def generate_pagination_bounds( + direction: Direction, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], +) -> Tuple[ + str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]] +]: + """ + Generate a start and end point for this page of events. + + Args: + direction: Whether pagination is going forwards or backwards. + from_token: The token to start pagination at, or None to start at the first value. + to_token: The token to end pagination at, or None to not limit the end point. + + Returns: + A three tuple of: + + ASC or DESC for sorting of the query. + + The starting position as a tuple of ints representing + (topological position, stream position) or None if no from_token was + provided. The topological position may be None for live tokens. + + The end position in the same format as the starting position, or None + if no to_token was provided. + """ + + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + if direction == Direction.BACKWARDS: + order = "DESC" + else: + order = "ASC" + + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + from_bound: Optional[Tuple[Optional[int], int]] = None + if from_token: + if from_token.topological is not None: + from_bound = from_token.as_historical_tuple() + elif direction == Direction.BACKWARDS: + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound: Optional[Tuple[Optional[int], int]] = None + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == Direction.BACKWARDS: + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + + return order, from_bound, to_bound + + +def generate_next_token( + direction: Direction, last_topo_ordering: int, last_stream_ordering: int +) -> RoomStreamToken: + """ + Generate the next room stream token based on the currently returned data. + + Args: + direction: Whether pagination is going forwards or backwards. + last_topo_ordering: The last topological ordering being returned. + last_stream_ordering: The last stream ordering being returned. + + Returns: + A new RoomStreamToken to return to the client. + """ + if direction == Direction.BACKWARDS: + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + last_stream_ordering -= 1 + return RoomStreamToken(last_topo_ordering, last_stream_ordering) + + def _make_generic_sql_bound( bound: str, column_names: Tuple[str, str], @@ -1103,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, before_token, - direction="b", + direction=Direction.BACKWARDS, limit=before_limit, event_filter=event_filter, ) @@ -1113,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, after_token, - direction="f", + direction=Direction.FORWARDS, limit=after_limit, event_filter=event_filter, ) @@ -1276,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: @@ -1287,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_token: The token used to stream from to_token: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. @@ -1300,47 +1397,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. args = [False, room_id] - if direction == "b": - order = "DESC" - else: - order = "ASC" - - # The bounds for the stream tokens are complicated by the fact - # that we need to handle the instance_map part of the tokens. We do this - # by fetching all events between the min stream token and the maximum - # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and - # then filtering the results. - if from_token.topological is not None: - from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple() - elif direction == "b": - from_bound = ( - None, - from_token.get_max_stream_pos(), - ) - else: - from_bound = ( - None, - from_token.stream, - ) - to_bound: Optional[Tuple[Optional[int], int]] = None - if to_token: - if to_token.topological is not None: - to_bound = to_token.as_historical_tuple() - elif direction == "b": - to_bound = ( - None, - to_token.stream, - ) - else: - to_bound = ( - None, - to_token.get_max_stream_pos(), - ) + order, from_bound, to_bound = generate_pagination_bounds( + direction, from_token, to_token + ) bounds = generate_pagination_where_clause( direction=direction, @@ -1427,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): _EventDictReturn(event_id, topological_ordering, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - lower_token=to_token if direction == "b" else from_token, - upper_token=from_token if direction == "b" else to_token, + lower_token=to_token + if direction == Direction.BACKWARDS + else from_token, + upper_token=from_token + if direction == Direction.BACKWARDS + else to_token, instance_name=instance_name, topological_ordering=topological_ordering, stream_ordering=stream_ordering, @@ -1436,16 +1501,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ][:limit] if rows: - topo = rows[-1].topological_ordering - token = rows[-1].stream_ordering - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_token = RoomStreamToken(topo, token) + assert rows[-1].topological_ordering is not None + next_token = generate_next_token( + direction, rows[-1].topological_ordering, rows[-1].stream_ordering + ) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token @@ -1458,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_key: RoomStreamToken, to_key: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: @@ -1468,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_key: The token used to stream from to_key: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 2dcd43d0a2..c6c8a0315c 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Generic, List, Optional, Tuple, TypeVar +from typing import Generic, List, Optional, Tuple, TypeVar -from synapse.types import UserID +from synapse.types import StrCollection, UserID # The key, this is either a stream token or int. K = TypeVar("K") @@ -28,7 +28,7 @@ class EventSource(Generic[K, R]): user: UserID, from_key: K, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[R], K]: diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 6df2de919c..5cb7875181 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -16,6 +16,7 @@ from typing import Optional import attr +from synapse.api.constants import Direction from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -34,7 +35,7 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] - direction: str + direction: Direction limit: int @classmethod @@ -45,9 +46,13 @@ class PaginationConfig: default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": - direction = parse_string( - request, "dir", default=default_dir, allowed_values=["f", "b"] + direction_str = parse_string( + request, + "dir", + default=default_dir, + allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value], ) + direction = Direction(direction_str) from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index b703e4472e..a9893def74 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -16,6 +16,8 @@ from unittest.mock import Mock import attr +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EduTypes from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router from synapse.federation.units import Transaction @@ -23,11 +25,13 @@ from synapse.handlers.presence import UserPresenceState from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login, presence, room +from synapse.server import HomeServer from synapse.types import JsonDict, StreamToken, create_requester +from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config from tests.test_utils import simple_async_mock -from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config +from tests.unittest import FederatingHomeserverTestCase, override_config @attr.s @@ -49,9 +53,7 @@ class LegacyPresenceRouterTestModule: } return users_to_state - async def get_interested_users( - self, user_id: str - ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -71,9 +73,14 @@ class LegacyPresenceRouterTestModule: # Initialise a typed config object config = PresenceRouterTestConfig() - config.users_who_should_receive_all_presence = config_dict.get( + users_who_should_receive_all_presence = config_dict.get( "users_who_should_receive_all_presence" ) + assert isinstance(users_who_should_receive_all_presence, list) + + config.users_who_should_receive_all_presence = ( + users_who_should_receive_all_presence + ) return config @@ -96,9 +103,7 @@ class PresenceRouterTestModule: } return users_to_state - async def get_interested_users( - self, user_id: str - ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -118,9 +123,14 @@ class PresenceRouterTestModule: # Initialise a typed config object config = PresenceRouterTestConfig() - config.users_who_should_receive_all_presence = config_dict.get( + users_who_should_receive_all_presence = config_dict.get( "users_who_should_receive_all_presence" ) + assert isinstance(users_who_should_receive_all_presence, list) + + config.users_who_should_receive_all_presence = ( + users_who_should_receive_all_presence + ) return config @@ -140,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): presence.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. fed_transport_client = Mock(spec=["send_transaction"]) fed_transport_client.send_transaction = simple_async_mock({}) @@ -153,7 +163,9 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): return hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.sync_handler = self.hs.get_sync_handler() self.module_api = homeserver.get_module_api() @@ -176,7 +188,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } ) - def test_receiving_all_presence_legacy(self): + def test_receiving_all_presence_legacy(self) -> None: self.receiving_all_presence_test_body() @override_config( @@ -193,10 +205,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ], } ) - def test_receiving_all_presence(self): + def test_receiving_all_presence(self) -> None: self.receiving_all_presence_test_body() - def receiving_all_presence_test_body(self): + def receiving_all_presence_test_body(self) -> None: """Test that a user that does not share a room with another other can receive presence for them, due to presence routing. """ @@ -302,7 +314,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } ) - def test_send_local_online_presence_to_with_module_legacy(self): + def test_send_local_online_presence_to_with_module_legacy(self) -> None: self.send_local_online_presence_to_with_module_test_body() @override_config( @@ -321,10 +333,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ], } ) - def test_send_local_online_presence_to_with_module(self): + def test_send_local_online_presence_to_with_module(self) -> None: self.send_local_online_presence_to_with_module_test_body() - def send_local_online_presence_to_with_module_test_body(self): + def send_local_online_presence_to_with_module_test_body(self) -> None: """Tests that send_local_presence_to_users sends local online presence to a set of specified local and remote users, with a custom PresenceRouter module enabled. """ @@ -447,18 +459,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): continue # EDUs can contain multiple presence updates - for presence_update in edu["content"]["push"]: + for presence_edu in edu["content"]["push"]: # Check for presence updates that contain the user IDs we're after - found_users.add(presence_update["user_id"]) + found_users.add(presence_edu["user_id"]) # Ensure that no offline states are being sent out - self.assertNotEqual(presence_update["presence"], "offline") + self.assertNotEqual(presence_edu["presence"], "offline") self.assertEqual(found_users, expected_users) def send_presence_update( - testcase: TestCase, + testcase: FederatingHomeserverTestCase, user_id: str, access_token: str, presence_state: str, @@ -479,7 +491,7 @@ def send_presence_update( def sync_presence( - testcase: TestCase, + testcase: FederatingHomeserverTestCase, user_id: str, since_token: Optional[StreamToken] = None, ) -> Tuple[List[UserPresenceState], StreamToken]: @@ -500,7 +512,7 @@ def sync_presence( requester = create_requester(user_id) sync_config = generate_sync_config(requester.user.to_string()) sync_result = testcase.get_success( - testcase.sync_handler.wait_for_sync_for_user( + testcase.hs.get_sync_handler().wait_for_sync_for_user( requester, sync_config, since_token ) ) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 8ddce83b83..6687c28e8f 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.test_utils.event_injection import create_event @@ -27,7 +32,7 @@ class TestEventContext(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() @@ -35,7 +40,7 @@ class TestEventContext(unittest.HomeserverTestCase): self.user_tok = self.login("u1", "pass") self.room_id = self.helper.create_room_as(tok=self.user_tok) - def test_serialize_deserialize_msg(self): + def test_serialize_deserialize_msg(self) -> None: """Test that an EventContext for a message event is the same after serialize/deserialize. """ @@ -51,7 +56,7 @@ class TestEventContext(unittest.HomeserverTestCase): self._check_serialize_deserialize(event, context) - def test_serialize_deserialize_state_no_prev(self): + def test_serialize_deserialize_state_no_prev(self) -> None: """Test that an EventContext for a state event (with not previous entry) is the same after serialize/deserialize. """ @@ -67,7 +72,7 @@ class TestEventContext(unittest.HomeserverTestCase): self._check_serialize_deserialize(event, context) - def test_serialize_deserialize_state_prev(self): + def test_serialize_deserialize_state_prev(self) -> None: """Test that an EventContext for a state event (which replaces a previous entry) is the same after serialize/deserialize. """ @@ -84,7 +89,9 @@ class TestEventContext(unittest.HomeserverTestCase): self._check_serialize_deserialize(event, context) - def _check_serialize_deserialize(self, event, context): + def _check_serialize_deserialize( + self, event: EventBase, context: EventContext + ) -> None: serialized = self.get_success(context.serialize(event, self.store)) d_context = EventContext.deserialize(self._storage_controllers, serialized) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index a79256846f..ff7b349d75 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,21 +13,24 @@ # limitations under the License. import unittest as stdlib_unittest +from typing import Any, List, Mapping, Optional from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import ( + PowerLevelsContent, SerializeEventConfig, copy_and_fixup_power_levels_contents, maybe_upsert_event_field, prune_event, serialize_event, ) +from synapse.types import JsonDict from synapse.util.frozenutils import freeze -def MockEvent(**kwargs): +def MockEvent(**kwargs: Any) -> EventBase: if "event_id" not in kwargs: kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: @@ -60,7 +63,7 @@ class TestMaybeUpsertEventField(stdlib_unittest.TestCase): class PruneEventTestCase(stdlib_unittest.TestCase): - def run_test(self, evdict, matchdict, **kwargs): + def run_test(self, evdict: JsonDict, matchdict: JsonDict, **kwargs: Any) -> None: """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. @@ -74,7 +77,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict ) - def test_minimal(self): + def test_minimal(self) -> None: self.run_test( {"type": "A", "event_id": "$test:domain"}, { @@ -86,7 +89,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - def test_basic_keys(self): + def test_basic_keys(self) -> None: """Ensure that the keys that should be untouched are kept.""" # Note that some of the values below don't really make sense, but the # pruning of events doesn't worry about the values of any fields (with @@ -138,7 +141,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_unsigned(self): + def test_unsigned(self) -> None: """Ensure that unsigned properties get stripped (except age_ts and replaces_state).""" self.run_test( { @@ -159,7 +162,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - def test_content(self): + def test_content(self) -> None: """The content dictionary should be stripped in most cases.""" self.run_test( {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}}, @@ -194,7 +197,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - def test_create(self): + def test_create(self) -> None: """Create events are partially redacted until MSC2176.""" self.run_test( { @@ -223,7 +226,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_power_levels(self): + def test_power_levels(self) -> None: """Power level events keep a variety of content keys.""" self.run_test( { @@ -273,7 +276,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_alias_event(self): + def test_alias_event(self) -> None: """Alias events have special behavior up through room version 6.""" self.run_test( { @@ -302,7 +305,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.V6, ) - def test_redacts(self): + def test_redacts(self) -> None: """Redaction events have no special behaviour until MSC2174/MSC2176.""" self.run_test( @@ -328,7 +331,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_join_rules(self): + def test_join_rules(self) -> None: """Join rules events have changed behavior starting with MSC3083.""" self.run_test( { @@ -371,7 +374,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.V8, ) - def test_member(self): + def test_member(self) -> None: """Member events have changed behavior starting with MSC3375.""" self.run_test( { @@ -417,12 +420,12 @@ class PruneEventTestCase(stdlib_unittest.TestCase): class SerializeEventTestCase(stdlib_unittest.TestCase): - def serialize(self, ev, fields): + def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict: return serialize_event( ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) ) - def test_event_fields_works_with_keys(self): + def test_event_fields_works_with_keys(self) -> None: self.assertEqual( self.serialize( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"] @@ -430,7 +433,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"room_id": "!foo:bar"}, ) - def test_event_fields_works_with_nested_keys(self): + def test_event_fields_works_with_nested_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -443,7 +446,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"body": "A message"}}, ) - def test_event_fields_works_with_dot_keys(self): + def test_event_fields_works_with_dot_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -456,7 +459,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"key.with.dots": {}}}, ) - def test_event_fields_works_with_nested_dot_keys(self): + def test_event_fields_works_with_nested_dot_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -472,7 +475,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"nested.dot.key": {"leaf.key": 42}}}, ) - def test_event_fields_nops_with_unknown_keys(self): + def test_event_fields_nops_with_unknown_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -485,7 +488,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"foo": "bar"}}, ) - def test_event_fields_nops_with_non_dict_keys(self): + def test_event_fields_nops_with_non_dict_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -498,7 +501,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {}, ) - def test_event_fields_nops_with_array_keys(self): + def test_event_fields_nops_with_array_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -511,7 +514,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {}, ) - def test_event_fields_all_fields_if_empty(self): + def test_event_fields_all_fields_if_empty(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -531,16 +534,16 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): }, ) - def test_event_fields_fail_if_fields_not_str(self): + def test_event_fields_fail_if_fields_not_str(self) -> None: with self.assertRaises(TypeError): self.serialize( - MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] + MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] # type: ignore[list-item] ) class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def setUp(self) -> None: - self.test_content = { + self.test_content: PowerLevelsContent = { "ban": 50, "events": {"m.room.name": 100, "m.room.power_levels": 100}, "events_default": 0, @@ -553,10 +556,11 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): "users_default": 0, } - def _test(self, input): + def _test(self, input: PowerLevelsContent) -> None: a = copy_and_fixup_power_levels_contents(input) self.assertEqual(a["ban"], 50) + assert isinstance(a["events"], Mapping) self.assertEqual(a["events"]["m.room.name"], 100) # make sure that changing the copy changes the copy and not the orig @@ -564,18 +568,19 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): a["events"]["m.room.power_levels"] = 20 self.assertEqual(input["ban"], 50) + assert isinstance(input["events"], Mapping) self.assertEqual(input["events"]["m.room.power_levels"], 100) - def test_unfrozen(self): + def test_unfrozen(self) -> None: self._test(self.test_content) - def test_frozen(self): + def test_frozen(self) -> None: input = freeze(self.test_content) self._test(input) - def test_stringy_integers(self): + def test_stringy_integers(self) -> None: """String representations of decimal integers are converted to integers.""" - input = { + input: PowerLevelsContent = { "a": "100", "b": { "foo": 99, @@ -603,9 +608,9 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def test_invalid_types_raise_type_error(self) -> None: with self.assertRaises(TypeError): - copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[arg-type] - copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[arg-type] + copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[dict-item] + copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[dict-item] def test_invalid_nesting_raises_type_error(self) -> None: with self.assertRaises(TypeError): - copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) + copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) # type: ignore[dict-item] diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index d21c11b716..ff589c0b6c 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -23,10 +23,10 @@ from synapse.server import HomeServer from synapse.types import RoomAlias from tests.test_utils import event_injection -from tests.unittest import FederatingHomeserverTestCase, TestCase +from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase -class KnockingStrippedStateEventHelperMixin(TestCase): +class KnockingStrippedStateEventHelperMixin(HomeserverTestCase): def send_example_state_events_to_room( self, hs: "HomeServer", @@ -49,7 +49,7 @@ class KnockingStrippedStateEventHelperMixin(TestCase): # To set a canonical alias, we'll need to point an alias at the room first. canonical_alias = "#fancy_alias:test" self.get_success( - self.store.create_room_alias_association( + self.hs.get_datastores().main.create_room_alias_association( RoomAlias.from_string(canonical_alias), room_id, ["test"] ) ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index efbb5a8dbb..1fe9563c98 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -14,21 +14,22 @@ import json -from typing import Dict +from typing import Dict, List, Set from unittest.mock import ANY, Mock, call -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.api.constants import EduTypes from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer +from synapse.handlers.typing import TypingWriterHandler from synapse.server import HomeServer from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util import Clock from tests import unittest +from tests.server import ThreadedMemoryReactorClock from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -62,7 +63,11 @@ def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes: class TypingNotificationsTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + def make_homeserver( + self, + reactor: ThreadedMemoryReactorClock, + clock: Clock, + ) -> HomeServer: # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) @@ -75,8 +80,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) + self.mock_hs_notifier = Mock() hs = self.setup_test_homeserver( - notifier=Mock(), + notifier=self.mock_hs_notifier, federation_http_client=mock_federation_client, keyring=mock_keyring, replication_streams={}, @@ -90,32 +96,38 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): return d def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - mock_notifier = hs.get_notifier() - self.on_new_event = mock_notifier.on_new_event + self.on_new_event = self.mock_hs_notifier.on_new_event - self.handler = hs.get_typing_handler() + # hs.get_typing_handler will return a TypingWriterHandler when calling it + # from the main process, and a FollowerTypingHandler on workers. + # We rely on methods only available on the former, so assert we have the + # correct type here. We have to assign self.handler after the assert, + # otherwise mypy will treat it as a FollowerTypingHandler + handler = hs.get_typing_handler() + assert isinstance(handler, TypingWriterHandler) + self.handler = handler self.event_source = hs.get_event_sources().sources.typing self.datastore = hs.get_datastores().main + self.datastore.get_destination_retry_timings = Mock( return_value=make_awaitable(None) ) - self.datastore.get_device_updates_by_remote = Mock( + self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment] return_value=make_awaitable((0, [])) ) - self.datastore.get_destination_last_successful_stream_ordering = Mock( + self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment] return_value=make_awaitable(None) ) - def get_received_txn_response(*args): - return defer.succeed(None) - - self.datastore.get_received_txn_response = get_received_txn_response + self.datastore.get_received_txn_response = Mock( # type: ignore[assignment] + return_value=make_awaitable(None) + ) - self.room_members = [] + self.room_members: List[UserID] = [] async def check_user_in_room(room_id: str, requester: Requester) -> None: if requester.user.to_string() not in [ @@ -124,47 +136,54 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): raise AuthError(401, "User is not in the room") return None - hs.get_auth().check_user_in_room = check_user_in_room + hs.get_auth().check_user_in_room = Mock( # type: ignore[assignment] + side_effect=check_user_in_room + ) async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID - hs.get_event_auth_handler().is_host_in_room = check_host_in_room + hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[assignment] + side_effect=check_host_in_room + ) - async def get_current_hosts_in_room(room_id: str): + async def get_current_hosts_in_room(room_id: str) -> Set[str]: return {member.domain for member in self.room_members} - hs.get_storage_controllers().state.get_current_hosts_in_room = ( - get_current_hosts_in_room + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] + side_effect=get_current_hosts_in_room ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( - get_current_hosts_in_room + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[assignment] + side_effect=get_current_hosts_in_room ) - async def get_users_in_room(room_id: str): + async def get_users_in_room(room_id: str) -> Set[str]: return {str(u) for u in self.room_members} - self.datastore.get_users_in_room = get_users_in_room + self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room) - self.datastore.get_user_directory_stream_pos = Mock( + self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment] side_effect=( - # we deliberately return a non-None stream pos to avoid doing an initial_spam + # we deliberately return a non-None stream pos to avoid + # doing an initial_sync lambda: make_awaitable(1) ) ) - self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment] - self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = ( - lambda *args, **kargs: make_awaitable(([], 0)) + self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment] + side_effect=lambda: 0 + ) + self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kargs: make_awaitable(([], 0)) ) - self.datastore.delete_device_msgs_for_remote = ( - lambda *args, **kargs: make_awaitable(None) + self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kargs: make_awaitable(None) ) - self.datastore.set_received_txn_response = ( - lambda *args, **kwargs: make_awaitable(None) + self.datastore.set_received_txn_response = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kwargs: make_awaitable(None) ) def test_started_typing_local(self) -> None: @@ -186,7 +205,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( - user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False + user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False ) ) self.assertEqual( @@ -257,7 +276,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( - user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False + user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False ) ) self.assertEqual( @@ -298,7 +317,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=U_APPLE, from_key=0, - limit=None, + limit=0, room_ids=[OTHER_ROOM_ID], is_guest=False, ) @@ -351,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( - user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False + user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False ) ) self.assertEqual( @@ -387,7 +406,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=U_APPLE, from_key=0, - limit=None, + limit=0, room_ids=[ROOM_ID], is_guest=False, ) @@ -412,7 +431,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=U_APPLE, from_key=1, - limit=None, + limit=0, room_ids=[ROOM_ID], is_guest=False, ) @@ -447,7 +466,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=U_APPLE, from_key=0, - limit=None, + limit=0, room_ids=[ROOM_ID], is_guest=False, ) diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py index 1acf5666a8..1c5de95a80 100644 --- a/tests/logging/__init__.py +++ b/tests/logging/__init__.py @@ -13,9 +13,11 @@ # limitations under the License. import logging +from tests.unittest import TestCase -class LoggerCleanupMixin: - def get_logger(self, handler): + +class LoggerCleanupMixin(TestCase): + def get_logger(self, handler: logging.Handler) -> logging.Logger: """ Attach a handler to a logger and add clean-ups to remove revert this. """ diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index 0917e478a5..e28ba84cc2 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -153,7 +153,7 @@ class LogContextScopeManagerTestCase(TestCase): scopes = [] - async def task(i: int): + async def task(i: int) -> None: scope = start_active_span( f"task{i}", tracer=self._tracer, @@ -165,7 +165,7 @@ class LogContextScopeManagerTestCase(TestCase): self.assertEqual(self._tracer.active_span, scope.span) scope.close() - async def root(): + async def root() -> None: with start_active_span("root span", tracer=self._tracer) as root_scope: self.assertEqual(self._tracer.active_span, root_scope.span) scopes.append(root_scope) diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py index b0d046fe00..c08954d887 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.test.proto_helpers import AccumulatingProtocol +from typing import Tuple + +from twisted.internet.protocol import Protocol +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from synapse.logging import RemoteHandler @@ -20,7 +23,9 @@ from tests.server import FakeTransport, get_clock from tests.unittest import TestCase -def connect_logging_client(reactor, client_id): +def connect_logging_client( + reactor: MemoryReactorClock, client_id: int +) -> Tuple[Protocol, AccumulatingProtocol]: # This is essentially tests.server.connect_client, but disabling autoflush on # the client transport. This is necessary to avoid an infinite loop due to # sending of data via the logging transport causing additional logs to be @@ -35,10 +40,10 @@ def connect_logging_client(reactor, client_id): class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor, _ = get_clock() - def test_log_output(self): + def test_log_output(self) -> None: """ The remote handler delivers logs over TCP. """ @@ -51,6 +56,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): client, server = connect_logging_client(self.reactor, 0) # Trigger data being sent + assert isinstance(client.transport, FakeTransport) client.transport.flush() # One log message, with a single trailing newline @@ -61,7 +67,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Ensure the data passed through properly. self.assertEqual(logs[0], "Hello there, wally!") - def test_log_backpressure_debug(self): + def test_log_backpressure_debug(self) -> None: """ When backpressure is hit, DEBUG logs will be shed. """ @@ -83,6 +89,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # Only the 7 infos made it through, the debugs were elided @@ -90,7 +97,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(len(logs), 7) self.assertNotIn(b"debug", server.data) - def test_log_backpressure_info(self): + def test_log_backpressure_info(self) -> None: """ When backpressure is hit, DEBUG and INFO logs will be shed. """ @@ -116,6 +123,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # The 10 warnings made it through, the debugs and infos were elided @@ -124,7 +132,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): self.assertNotIn(b"debug", server.data) self.assertNotIn(b"info", server.data) - def test_log_backpressure_cut_middle(self): + def test_log_backpressure_cut_middle(self) -> None: """ When backpressure is hit, and no more DEBUG and INFOs cannot be culled, it will cut the middle messages out. @@ -140,6 +148,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # The first five and last five warnings made it through, the debugs and @@ -151,7 +160,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): logs, ) - def test_cancel_connection(self): + def test_cancel_connection(self) -> None: """ Gracefully handle the connection being cancelled. """ diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 0b0d8737c1..fa27f1279a 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -14,24 +14,28 @@ import json import logging from io import BytesIO, StringIO +from typing import cast from unittest.mock import Mock, patch +from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.http.site import SynapseRequest from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter from synapse.logging.context import LoggingContext, LoggingContextFilter +from synapse.types import JsonDict from tests.logging import LoggerCleanupMixin -from tests.server import FakeChannel +from tests.server import FakeChannel, get_clock from tests.unittest import TestCase class TerseJsonTestCase(LoggerCleanupMixin, TestCase): - def setUp(self): + def setUp(self) -> None: self.output = StringIO() + self.reactor, _ = get_clock() - def get_log_line(self): + def get_log_line(self) -> JsonDict: # One log message, with a single trailing newline. data = self.output.getvalue() logs = data.splitlines() @@ -39,7 +43,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(data.count("\n"), 1) return json.loads(logs[0]) - def test_terse_json_output(self): + def test_terse_json_output(self) -> None: """ The Terse JSON formatter converts log messages to JSON. """ @@ -61,7 +65,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - def test_extra_data(self): + def test_extra_data(self) -> None: """ Additional information can be included in the structured logging. """ @@ -93,7 +97,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["int"], 3) self.assertIs(log["bool"], True) - def test_json_output(self): + def test_json_output(self) -> None: """ The Terse JSON formatter converts log messages to JSON. """ @@ -114,7 +118,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - def test_with_context(self): + def test_with_context(self) -> None: """ The logging context should be added to the JSON response. """ @@ -139,7 +143,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["log"], "Hello there, wally!") self.assertEqual(log["request"], "name") - def test_with_request_context(self): + def test_with_request_context(self) -> None: """ Information from the logging context request should be added to the JSON response. """ @@ -154,11 +158,13 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site.server_version_string = "Server v1" site.reactor = Mock() site.experimental_cors_msc3886 = False - request = SynapseRequest(FakeChannel(site, None), site) + request = SynapseRequest( + cast(HTTPChannel, FakeChannel(site, self.reactor)), site + ) # Call requestReceived to finish instantiating the object. request.content = BytesIO() - # Partially skip some of the internal processing of SynapseRequest. - request._started_processing = Mock() + # Partially skip some internal processing of SynapseRequest. + request._started_processing = Mock() # type: ignore[assignment] request.request_metrics = Mock(spec=["name"]) with patch.object(Request, "render"): request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") @@ -200,7 +206,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["protocol"], "1.1") self.assertEqual(log["user_agent"], "") - def test_with_exception(self): + def test_with_exception(self) -> None: """ The logging exception type & value should be added to the JSON response. """ diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index c9afa0f3dd..b9047194dd 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -294,9 +294,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.make_request("GET", sync_url % (access_token, next_batch)) -class SyncKnockTestCase( - unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin -): +class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 21a1ca2a68..3086e1b565 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -13,18 +13,22 @@ # limitations under the License. from http import HTTPStatus +from typing import Any, Generator, Tuple, cast from unittest.mock import Mock, call -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor as _reactor from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache +from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable from tests.utils import MockClock +reactor = cast(ISynapseReactor, _reactor) + class HttpTransactionCacheTestCase(unittest.TestCase): def setUp(self) -> None: @@ -34,11 +38,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.hs.get_auth = Mock() self.cache = HttpTransactionCache(self.hs) - self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!") + self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"}) self.mock_key = "foo" @defer.inlineCallbacks - def test_executes_given_function(self): + def test_executes_given_function( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg" @@ -47,7 +53,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertEqual(res, self.mock_http_response) @defer.inlineCallbacks - def test_deduplicates_based_on_key(self): + def test_deduplicates_based_on_key( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute( @@ -58,18 +66,20 @@ class HttpTransactionCacheTestCase(unittest.TestCase): cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0) @defer.inlineCallbacks - def test_logcontexts_with_async_result(self): + def test_logcontexts_with_async_result( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: @defer.inlineCallbacks - def cb(): + def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]: yield Clock(reactor).sleep(0) - return "yay" + return 1, {} @defer.inlineCallbacks - def test(): + def test() -> Generator["defer.Deferred[Any]", object, None]: with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertIs(current_context(), c1) - self.assertEqual(res, "yay") + self.assertEqual(res, (1, {})) # run the test twice in parallel d = defer.gatherResults([test(), test()]) @@ -78,13 +88,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks - def test_does_not_cache_exceptions(self): + def test_does_not_cache_exceptions( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: """Checks that, if the callback throws an exception, it is called again for the next request. """ called = [False] - def cb(): + def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": if called[0]: # return a valid result the second time return defer.succeed(self.mock_http_response) @@ -104,13 +116,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertIs(current_context(), test_context) @defer.inlineCallbacks - def test_does_not_cache_failures(self): + def test_does_not_cache_failures( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: """Checks that, if the callback returns a failure, it is called again for the next request. """ called = [False] - def cb(): + def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": if called[0]: # return a valid result the second time return defer.succeed(self.mock_http_response) @@ -130,7 +144,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertIs(current_context(), test_context) @defer.inlineCallbacks - def test_cleans_up(self): + def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # should NOT have cleaned up yet diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 3ba896ecf3..f1ca523d23 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -28,6 +28,7 @@ from synapse.storage.background_updates import _BackgroundUpdateHandler from synapse.storage.roommember import ProfileInfo from synapse.util import Clock +from tests.server import ThreadedMemoryReactorClock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config @@ -138,7 +139,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + def make_homeserver( + self, reactor: ThreadedMemoryReactorClock, clock: Clock + ) -> HomeServer: self.appservice = ApplicationService( token="i_am_an_app_service", id="1234", diff --git a/tests/unittest.py b/tests/unittest.py index a120c2976c..fa92dd94eb 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -75,6 +75,7 @@ from synapse.util.httpresourcetree import create_resource_tree from tests.server import ( CustomHeaderType, FakeChannel, + ThreadedMemoryReactorClock, get_clock, make_request, setup_test_homeserver, @@ -360,7 +361,7 @@ class HomeserverTestCase(TestCase): store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver(self, reactor: MemoryReactor, clock: Clock): + def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock): """ Make and return a homeserver. |