diff options
Diffstat (limited to 'synapse')
25 files changed, 253 insertions, 177 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6432d32d83..6f9239d21c 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,6 +17,8 @@ """Contains constants from the specification.""" +import enum + from typing_extensions import Final # the max size of a (canonical-json-encoded) event @@ -290,3 +292,8 @@ class ApprovalNoticeMedium: NONE = "org.matrix.msc3866.none" EMAIL = "org.matrix.msc3866.email" + + +class Direction(enum.Enum): + BACKWARDS = "b" + FORWARDS = "f" diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ae57a4df5e..52e4b467e8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -605,10 +605,11 @@ class EventClientSerializer: _PowerLevel = Union[str, int] +PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] def copy_and_fixup_power_levels_contents( - old_power_levels: Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] + old_power_levels: PowerLevelsContent, ) -> Dict[str, Union[int, Dict[str, int]]]: """Copy the content of a power_levels event, unfreezing frozendicts along the way. diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 834006356a..d500b21809 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple from synapse.api.constants import AccountDataTypes from synapse.replication.http.account_data import ( @@ -26,7 +26,7 @@ from synapse.replication.http.account_data import ( ReplicationRemoveUserAccountDataRestServlet, ) from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -322,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): user: UserID, from_key: int, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[JsonDict], int]: diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5bf8e86387..c81ea34758 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -16,7 +16,7 @@ import abc import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set -from synapse.api.constants import Membership +from synapse.api.constants import Direction, Membership from synapse.events import EventBase from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client @@ -197,7 +197,7 @@ class AdminHandler: # efficient method perhaps but it does guarantee we get everything. while True: events, _ = await self.store.paginate_room_events( - room_id, from_key, to_key, limit=100, direction="f" + room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS ) if not events: break diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 58180ae2fa..5c06073901 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -18,7 +18,6 @@ from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, - Collection, Dict, Iterable, List, @@ -45,6 +44,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.types import ( JsonDict, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -146,7 +146,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( - self, user_id: str, room_ids: Collection[str], from_token: StreamToken + self, user_id: str, room_ids: StrCollection, from_token: StreamToken ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. @@ -551,7 +551,7 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") async def notify_device_update( - self, user_id: str, device_ids: Collection[str] + self, user_id: str, device_ids: StrCollection ) -> None: """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index f91dbbecb7..a23a8ce2a1 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( @@ -29,7 +29,7 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.types import StateMap, get_domain_from_id +from synapse.types import StateMap, StrCollection, get_domain_from_id if TYPE_CHECKING: from synapse.server import HomeServer @@ -290,7 +290,7 @@ class EventAuthHandler: async def get_rooms_that_allow_join( self, state_ids: StateMap[str] - ) -> Collection[str]: + ) -> StrCollection: """ Generate a list of rooms in which membership allows access to a room. @@ -331,7 +331,7 @@ class EventAuthHandler: return result - async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool: + async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool: """ Check whether a user is a member of any of the provided rooms. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 233f8c113d..dc1cbf5c3d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -20,17 +20,7 @@ import itertools import logging from enum import Enum from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union import attr from prometheus_client import Histogram @@ -70,7 +60,7 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -179,7 +169,7 @@ class FederationHandler: # A dictionary mapping room IDs to (initial destination, other destinations) # tuples. self._partial_state_syncs_maybe_needing_restart: Dict[ - str, Tuple[Optional[str], Collection[str]] + str, Tuple[Optional[str], StrCollection] ] = {} # A lock guarding the partial state flag for rooms. # When the lock is held for a given room, no other concurrent code may @@ -437,7 +427,7 @@ class FederationHandler: ) ) - async def try_backfill(domains: Collection[str]) -> bool: + async def try_backfill(domains: StrCollection) -> bool: # TODO: Should we try multiple of these at a time? # Number of contacted remote homeservers that have denied our backfill @@ -1730,7 +1720,7 @@ class FederationHandler: def _start_partial_state_room_sync( self, initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, ) -> None: """Starts the background process to resync the state of a partial state room, @@ -1812,7 +1802,7 @@ class FederationHandler: async def _sync_partial_state_room( self, initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, ) -> None: """Background process to resync the state of a partial-state room @@ -1949,9 +1939,9 @@ class FederationHandler: def _prioritise_destinations_for_partial_state_resync( initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, -) -> Collection[str]: +) -> StrCollection: """Work out the order in which we should ask servers to resync events. If an `initial_destination` is given, it takes top priority. Otherwise diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 904a721483..e037acbca2 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -80,6 +80,7 @@ from synapse.types import ( PersistedEventPosition, RoomStreamToken, StateMap, + StrCollection, UserID, get_domain_from_id, ) @@ -615,7 +616,7 @@ class FederationEventHandler: @trace async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Collection[str] + self, dest: str, room_id: str, limit: int, extremities: StrCollection ) -> None: """Trigger a backfill request to `dest` for the given `room_id` @@ -1565,7 +1566,7 @@ class FederationEventHandler: @trace @tag_args async def _get_events_and_persist( - self, destination: str, room_id: str, event_ids: Collection[str] + self, destination: str, room_id: str, event_ids: StrCollection ) -> None: """Fetch the given events from a server, and persist them as outliers. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 8c2260ad7d..191529bd8e 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -15,7 +15,13 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, cast -from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + Direction, + EduTypes, + EventTypes, + Membership, +) from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig @@ -57,7 +63,13 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool + str, + Optional[StreamToken], + Optional[StreamToken], + Direction, + int, + bool, + bool, ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8c8ff18a1a..ceefa16b49 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set import attr from twisted.python.failure import Failure -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import SerializeEventConfig @@ -28,7 +28,7 @@ from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamKeyType +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string @@ -391,7 +391,7 @@ class PaginationHandler: """ return self._delete_by_id.get(delete_id) - def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]: + def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]: """Get all active delete ids by room Args: @@ -448,7 +448,7 @@ class PaginationHandler: if pagin_config.from_token: from_token = pagin_config.from_token - elif pagin_config.direction == "f": + elif pagin_config.direction == Direction.FORWARDS: from_token = ( await self.hs.get_event_sources().get_start_token_for_pagination( room_id @@ -476,7 +476,7 @@ class PaginationHandler: room_id, requester, allow_departed_users=True ) - if pagin_config.direction == "b": + if pagin_config.direction == Direction.BACKWARDS: # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 43e4e7b1b4..87af31aa27 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -64,7 +64,13 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.storage.databases.main import DataStore from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + StrCollection, + StreamKeyType, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -320,7 +326,7 @@ class BasePresenceHandler(abc.ABC): for destination, host_states in hosts_to_states.items(): self._federation.send_presence_to_destinations(host_states, [destination]) - async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None: + async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: """ Adds to the list of users who should receive a full snapshot of presence upon their next sync. Note that this only works for local users. @@ -1601,7 +1607,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # Having a default limit doesn't match the EventSource API, but some # callers do not provide it. It is unused in this class. limit: int = 0, - room_ids: Optional[Collection[str]] = None, + room_ids: Optional[StrCollection] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, include_offline: bool = True, @@ -1688,7 +1694,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # The set of users that we're interested in and that have had a presence update. # We'll actually pull the presence updates for these users at the end. - interested_and_updated_users: Collection[str] + interested_and_updated_users: StrCollection if from_key is not None: # First get all users that have had a presence update @@ -2120,7 +2126,7 @@ class PresenceFederationQueue: # stream_id, destinations, user_ids)`. We don't store the full states # for efficiency, and remote workers will already have the full states # cached. - self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = [] + self._queue: List[Tuple[int, int, StrCollection, Set[str]]] = [] self._next_id = 1 @@ -2142,7 +2148,7 @@ class PresenceFederationQueue: self._queue = self._queue[index:] def send_presence_to_destinations( - self, states: Collection[UserPresenceState], destinations: Collection[str] + self, states: Collection[UserPresenceState], destinations: StrCollection ) -> None: """Send the presence states to the given destinations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index e96f9999a8..0fb15391e0 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O import attr -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -413,7 +413,11 @@ class RelationsHandler: # Attempt to find another event to use as the latest event. potential_events, _ = await self._main_store.get_relations_for_event( - event_id, event, room_id, RelationTypes.THREAD, direction="f" + event_id, + event, + room_id, + RelationTypes.THREAD, + direction=Direction.FORWARDS, ) # Filter out ignored users. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 572c7b4db3..60a6d9cf3c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,16 +20,7 @@ import random import string from collections import OrderedDict from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Collection, - Dict, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple import attr from typing_extensions import TypedDict @@ -72,6 +63,7 @@ from synapse.types import ( RoomID, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1644,7 +1636,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): user: UserID, from_key: RoomStreamToken, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index c6b869c6f4..4472019fbc 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -36,7 +36,7 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase -from synapse.types import JsonDict, Requester +from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -870,7 +870,7 @@ class _RoomQueueEntry: # The room ID of this entry. room_id: str # The server to query if the room is not known locally. - via: Sequence[str] + via: StrCollection # The minimum number of hops necessary to get to this room (compared to the # originally requested room). depth: int = 0 diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 40f4635c4e..9bbf83047d 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -14,7 +14,7 @@ import itertools import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr from unpaddedbase64 import decode_base64, encode_base64 @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client @@ -418,7 +418,7 @@ class SearchHandler: async def _search_by_rank( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, @@ -491,7 +491,7 @@ class SearchHandler: async def _search_by_recent( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 44e70fc4b8..4a27c0f051 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -20,7 +20,6 @@ from typing import ( Any, Awaitable, Callable, - Collection, Dict, Iterable, List, @@ -47,6 +46,7 @@ from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest from synapse.types import ( JsonDict, + StrCollection, UserID, contains_invalid_mxid_characters, create_requester, @@ -141,7 +141,8 @@ class UserAttributes: confirm_localpart: bool = False display_name: Optional[str] = None picture: Optional[str] = None - emails: Collection[str] = attr.Factory(list) + # mypy thinks these are incompatible for some reason. + emails: StrCollection = attr.Factory(list) # type: ignore[assignment] @attr.s(slots=True, auto_attribs=True) @@ -159,7 +160,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name: Optional[str] - emails: Collection[str] + emails: StrCollection # An optional dictionary of extra attributes to be provided to the client in the # login response. @@ -174,7 +175,7 @@ class UsernameMappingSession: # choices made by the user chosen_localpart: Optional[str] = None use_display_name: bool = True - emails_to_use: Collection[str] = () + emails_to_use: StrCollection = () terms_accepted_version: Optional[str] = None diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5ebd3ea855..5235e29460 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, AbstractSet, Any, - Collection, Dict, FrozenSet, List, @@ -62,6 +61,7 @@ from synapse.types import ( Requester, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1179,7 +1179,7 @@ class SyncHandler: async def _find_missing_partial_state_memberships( self, room_id: str, - members_to_fetch: Collection[str], + members_to_fetch: StrCollection, events_with_membership_auth: Mapping[str, EventBase], found_state_ids: StateMap[str], ) -> StateMap[str]: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6153a48257..d22dd19d38 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1158,7 +1158,7 @@ class ModuleApi: # Send to remote destinations. destination = UserID.from_string(user).domain presence_handler.get_federation_queue().send_presence_to_destinations( - presence_events, destination + presence_events, [destination] ) def looping_background_call( diff --git a/synapse/notifier.py b/synapse/notifier.py index 2b0e52f23c..a8832a3f8e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -46,6 +46,7 @@ from synapse.types import ( JsonDict, PersistedEventPosition, RoomStreamToken, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -716,7 +717,7 @@ class Notifier: async def _get_room_ids( self, user: UserID, explicit_room_id: Optional[str] - ) -> Tuple[Collection[str], bool]: + ) -> Tuple[StrCollection, bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8191b4e32c..ad5c10c99d 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Sequence, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple, Union from synapse.api.errors import ( NotFoundError, @@ -169,7 +169,7 @@ class PushRuleRestServlet(RestServlet): raise UnrecognizedRequestError() -def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: +def _rule_spec_from_path(path: List[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec Args: diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 61375651bc..3f40f1874a 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple from typing_extensions import ParamSpec +from twisted.internet.defer import Deferred from twisted.python.failure import Failure from twisted.web.server import Request @@ -90,7 +91,7 @@ class HttpTransactionCache: fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], *args: P.args, **kwargs: P.kwargs, - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> "Deferred[Tuple[int, JsonDict]]": """Fetches the response for this transaction, or executes the given function to produce a response for this transaction. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 84f844b79e..0018d6f7ab 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -30,7 +30,7 @@ from typing import ( import attr -from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -40,9 +40,13 @@ from synapse.storage.database import ( LoggingTransaction, make_in_list_sql_clause, ) -from synapse.storage.databases.main.stream import generate_pagination_where_clause +from synapse.storage.databases.main.stream import ( + generate_next_token, + generate_pagination_bounds, + generate_pagination_where_clause, +) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken +from synapse.types import JsonDict, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -164,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Optional[str] = None, event_type: Optional[str] = None, limit: int = 5, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: @@ -177,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. limit: Only fetch the most recent `limit` events. - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`). + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. @@ -207,24 +211,23 @@ class RelationsWorkerStore(SQLBaseStore): where_clause.append("type = ?") where_args.append(event_type) + order, from_bound, to_bound = generate_pagination_bounds( + direction, + from_token.room_key if from_token else None, + to_token.room_key if to_token else None, + ) + pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.room_key.as_historical_tuple() - if from_token - else None, - to_token=to_token.room_key.as_historical_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) if pagination_clause: where_clause.append(pagination_clause) - if direction == "b": - order = "DESC" - else: - order = "ASC" - sql = """ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering FROM event_relations @@ -266,16 +269,9 @@ class RelationsWorkerStore(SQLBaseStore): topo_orderings = topo_orderings[:limit] stream_orderings = stream_orderings[:limit] - topo = topo_orderings[-1] - token = stream_orderings[-1] - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_key = RoomStreamToken(topo, token) + next_key = generate_next_token( + direction, topo_orderings[-1], stream_orderings[-1] + ) if from_token: next_token = from_token.copy_and_replace( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index d28fc65df9..818c46182e 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -55,6 +55,7 @@ from typing_extensions import Literal from twisted.internet import defer +from synapse.api.constants import Direction from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" - # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: @@ -104,7 +104,7 @@ class _EventsAround: def generate_pagination_where_clause( - direction: str, + direction: Direction, column_names: Tuple[str, str], from_token: Optional[Tuple[Optional[int], int]], to_token: Optional[Tuple[Optional[int], int]], @@ -130,27 +130,26 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction: Whether we're paginating backwards("b") or forwards ("f"). + direction: Whether we're paginating backwards or forwards. column_names: The column names to bound. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. from_token: The start point for the pagination. This is an exclusive - minimum bound if direction is "f", and an inclusive maximum bound if - direction is "b". + minimum bound if direction is forwards, and an inclusive maximum bound if + direction is backwards. to_token: The endpoint point for the pagination. This is an inclusive - maximum bound if direction is "f", and an exclusive minimum bound if - direction is "b". + maximum bound if direction is forwards, and an exclusive minimum bound if + direction is backwards. engine: The database engine to generate the clauses for Returns: The sql expression """ - assert direction in ("b", "f") where_clause = [] if from_token: where_clause.append( _make_generic_sql_bound( - bound=">=" if direction == "b" else "<", + bound=">=" if direction == Direction.BACKWARDS else "<", column_names=column_names, values=from_token, engine=engine, @@ -160,7 +159,7 @@ def generate_pagination_where_clause( if to_token: where_clause.append( _make_generic_sql_bound( - bound="<" if direction == "b" else ">=", + bound="<" if direction == Direction.BACKWARDS else ">=", column_names=column_names, values=to_token, engine=engine, @@ -170,6 +169,104 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) +def generate_pagination_bounds( + direction: Direction, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], +) -> Tuple[ + str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]] +]: + """ + Generate a start and end point for this page of events. + + Args: + direction: Whether pagination is going forwards or backwards. + from_token: The token to start pagination at, or None to start at the first value. + to_token: The token to end pagination at, or None to not limit the end point. + + Returns: + A three tuple of: + + ASC or DESC for sorting of the query. + + The starting position as a tuple of ints representing + (topological position, stream position) or None if no from_token was + provided. The topological position may be None for live tokens. + + The end position in the same format as the starting position, or None + if no to_token was provided. + """ + + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + if direction == Direction.BACKWARDS: + order = "DESC" + else: + order = "ASC" + + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + from_bound: Optional[Tuple[Optional[int], int]] = None + if from_token: + if from_token.topological is not None: + from_bound = from_token.as_historical_tuple() + elif direction == Direction.BACKWARDS: + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound: Optional[Tuple[Optional[int], int]] = None + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == Direction.BACKWARDS: + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + + return order, from_bound, to_bound + + +def generate_next_token( + direction: Direction, last_topo_ordering: int, last_stream_ordering: int +) -> RoomStreamToken: + """ + Generate the next room stream token based on the currently returned data. + + Args: + direction: Whether pagination is going forwards or backwards. + last_topo_ordering: The last topological ordering being returned. + last_stream_ordering: The last stream ordering being returned. + + Returns: + A new RoomStreamToken to return to the client. + """ + if direction == Direction.BACKWARDS: + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + last_stream_ordering -= 1 + return RoomStreamToken(last_topo_ordering, last_stream_ordering) + + def _make_generic_sql_bound( bound: str, column_names: Tuple[str, str], @@ -1103,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, before_token, - direction="b", + direction=Direction.BACKWARDS, limit=before_limit, event_filter=event_filter, ) @@ -1113,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, after_token, - direction="f", + direction=Direction.FORWARDS, limit=after_limit, event_filter=event_filter, ) @@ -1276,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: @@ -1287,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_token: The token used to stream from to_token: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. @@ -1300,47 +1397,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. args = [False, room_id] - if direction == "b": - order = "DESC" - else: - order = "ASC" - - # The bounds for the stream tokens are complicated by the fact - # that we need to handle the instance_map part of the tokens. We do this - # by fetching all events between the min stream token and the maximum - # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and - # then filtering the results. - if from_token.topological is not None: - from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple() - elif direction == "b": - from_bound = ( - None, - from_token.get_max_stream_pos(), - ) - else: - from_bound = ( - None, - from_token.stream, - ) - to_bound: Optional[Tuple[Optional[int], int]] = None - if to_token: - if to_token.topological is not None: - to_bound = to_token.as_historical_tuple() - elif direction == "b": - to_bound = ( - None, - to_token.stream, - ) - else: - to_bound = ( - None, - to_token.get_max_stream_pos(), - ) + order, from_bound, to_bound = generate_pagination_bounds( + direction, from_token, to_token + ) bounds = generate_pagination_where_clause( direction=direction, @@ -1427,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): _EventDictReturn(event_id, topological_ordering, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - lower_token=to_token if direction == "b" else from_token, - upper_token=from_token if direction == "b" else to_token, + lower_token=to_token + if direction == Direction.BACKWARDS + else from_token, + upper_token=from_token + if direction == Direction.BACKWARDS + else to_token, instance_name=instance_name, topological_ordering=topological_ordering, stream_ordering=stream_ordering, @@ -1436,16 +1501,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ][:limit] if rows: - topo = rows[-1].topological_ordering - token = rows[-1].stream_ordering - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_token = RoomStreamToken(topo, token) + assert rows[-1].topological_ordering is not None + next_token = generate_next_token( + direction, rows[-1].topological_ordering, rows[-1].stream_ordering + ) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token @@ -1458,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_key: RoomStreamToken, to_key: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: @@ -1468,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_key: The token used to stream from to_key: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 2dcd43d0a2..c6c8a0315c 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Generic, List, Optional, Tuple, TypeVar +from typing import Generic, List, Optional, Tuple, TypeVar -from synapse.types import UserID +from synapse.types import StrCollection, UserID # The key, this is either a stream token or int. K = TypeVar("K") @@ -28,7 +28,7 @@ class EventSource(Generic[K, R]): user: UserID, from_key: K, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[R], K]: diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 6df2de919c..5cb7875181 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -16,6 +16,7 @@ from typing import Optional import attr +from synapse.api.constants import Direction from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -34,7 +35,7 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] - direction: str + direction: Direction limit: int @classmethod @@ -45,9 +46,13 @@ class PaginationConfig: default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": - direction = parse_string( - request, "dir", default=default_dir, allowed_values=["f", "b"] + direction_str = parse_string( + request, + "dir", + default=default_dir, + allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value], ) + direction = Direction(direction_str) from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") |