diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 6432d32d83..6f9239d21c 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -17,6 +17,8 @@
"""Contains constants from the specification."""
+import enum
+
from typing_extensions import Final
# the max size of a (canonical-json-encoded) event
@@ -290,3 +292,8 @@ class ApprovalNoticeMedium:
NONE = "org.matrix.msc3866.none"
EMAIL = "org.matrix.msc3866.email"
+
+
+class Direction(enum.Enum):
+ BACKWARDS = "b"
+ FORWARDS = "f"
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index ae57a4df5e..52e4b467e8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -605,10 +605,11 @@ class EventClientSerializer:
_PowerLevel = Union[str, int]
+PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
def copy_and_fixup_power_levels_contents(
- old_power_levels: Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
+ old_power_levels: PowerLevelsContent,
) -> Dict[str, Union[int, Dict[str, int]]]:
"""Copy the content of a power_levels event, unfreezing frozendicts along the way.
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 834006356a..d500b21809 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import random
-from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
from synapse.api.constants import AccountDataTypes
from synapse.replication.http.account_data import (
@@ -26,7 +26,7 @@ from synapse.replication.http.account_data import (
ReplicationRemoveUserAccountDataRestServlet,
)
from synapse.streams import EventSource
-from synapse.types import JsonDict, StreamKeyType, UserID
+from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -322,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
user: UserID,
from_key: int,
limit: int,
- room_ids: Collection[str],
+ room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 5bf8e86387..c81ea34758 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -16,7 +16,7 @@ import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
-from synapse.api.constants import Membership
+from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
@@ -197,7 +197,7 @@ class AdminHandler:
# efficient method perhaps but it does guarantee we get everything.
while True:
events, _ = await self.store.paginate_room_events(
- room_id, from_key, to_key, limit=100, direction="f"
+ room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
)
if not events:
break
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 58180ae2fa..5c06073901 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -18,7 +18,6 @@ from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
- Collection,
Dict,
Iterable,
List,
@@ -45,6 +44,7 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.types import (
JsonDict,
+ StrCollection,
StreamKeyType,
StreamToken,
UserID,
@@ -146,7 +146,7 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
- self, user_id: str, room_ids: Collection[str], from_token: StreamToken
+ self, user_id: str, room_ids: StrCollection, from_token: StreamToken
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
@@ -551,7 +551,7 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
@measure_func("notify_device_update")
async def notify_device_update(
- self, user_id: str, device_ids: Collection[str]
+ self, user_id: str, device_ids: StrCollection
) -> None:
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index f91dbbecb7..a23a8ce2a1 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union
+from typing import TYPE_CHECKING, List, Mapping, Optional, Union
from synapse import event_auth
from synapse.api.constants import (
@@ -29,7 +29,7 @@ from synapse.event_auth import (
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import StateMap, StrCollection, get_domain_from_id
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -290,7 +290,7 @@ class EventAuthHandler:
async def get_rooms_that_allow_join(
self, state_ids: StateMap[str]
- ) -> Collection[str]:
+ ) -> StrCollection:
"""
Generate a list of rooms in which membership allows access to a room.
@@ -331,7 +331,7 @@ class EventAuthHandler:
return result
- async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool:
+ async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool:
"""
Check whether a user is a member of any of the provided rooms.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 233f8c113d..dc1cbf5c3d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -20,17 +20,7 @@ import itertools
import logging
from enum import Enum
from http import HTTPStatus
-from typing import (
- TYPE_CHECKING,
- Collection,
- Dict,
- Iterable,
- List,
- Optional,
- Set,
- Tuple,
- Union,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Histogram
@@ -70,7 +60,7 @@ from synapse.replication.http.federation import (
)
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import JsonDict, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@@ -179,7 +169,7 @@ class FederationHandler:
# A dictionary mapping room IDs to (initial destination, other destinations)
# tuples.
self._partial_state_syncs_maybe_needing_restart: Dict[
- str, Tuple[Optional[str], Collection[str]]
+ str, Tuple[Optional[str], StrCollection]
] = {}
# A lock guarding the partial state flag for rooms.
# When the lock is held for a given room, no other concurrent code may
@@ -437,7 +427,7 @@ class FederationHandler:
)
)
- async def try_backfill(domains: Collection[str]) -> bool:
+ async def try_backfill(domains: StrCollection) -> bool:
# TODO: Should we try multiple of these at a time?
# Number of contacted remote homeservers that have denied our backfill
@@ -1730,7 +1720,7 @@ class FederationHandler:
def _start_partial_state_room_sync(
self,
initial_destination: Optional[str],
- other_destinations: Collection[str],
+ other_destinations: StrCollection,
room_id: str,
) -> None:
"""Starts the background process to resync the state of a partial state room,
@@ -1812,7 +1802,7 @@ class FederationHandler:
async def _sync_partial_state_room(
self,
initial_destination: Optional[str],
- other_destinations: Collection[str],
+ other_destinations: StrCollection,
room_id: str,
) -> None:
"""Background process to resync the state of a partial-state room
@@ -1949,9 +1939,9 @@ class FederationHandler:
def _prioritise_destinations_for_partial_state_resync(
initial_destination: Optional[str],
- other_destinations: Collection[str],
+ other_destinations: StrCollection,
room_id: str,
-) -> Collection[str]:
+) -> StrCollection:
"""Work out the order in which we should ask servers to resync events.
If an `initial_destination` is given, it takes top priority. Otherwise
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 904a721483..e037acbca2 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -80,6 +80,7 @@ from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
StateMap,
+ StrCollection,
UserID,
get_domain_from_id,
)
@@ -615,7 +616,7 @@ class FederationEventHandler:
@trace
async def backfill(
- self, dest: str, room_id: str, limit: int, extremities: Collection[str]
+ self, dest: str, room_id: str, limit: int, extremities: StrCollection
) -> None:
"""Trigger a backfill request to `dest` for the given `room_id`
@@ -1565,7 +1566,7 @@ class FederationEventHandler:
@trace
@tag_args
async def _get_events_and_persist(
- self, destination: str, room_id: str, event_ids: Collection[str]
+ self, destination: str, room_id: str, event_ids: StrCollection
) -> None:
"""Fetch the given events from a server, and persist them as outliers.
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 8c2260ad7d..191529bd8e 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -15,7 +15,13 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
-from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
+from synapse.api.constants import (
+ AccountDataTypes,
+ Direction,
+ EduTypes,
+ EventTypes,
+ Membership,
+)
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
@@ -57,7 +63,13 @@ class InitialSyncHandler:
self.validator = EventValidator()
self.snapshot_cache: ResponseCache[
Tuple[
- str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
+ str,
+ Optional[StreamToken],
+ Optional[StreamToken],
+ Direction,
+ int,
+ bool,
+ bool,
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8c8ff18a1a..ceefa16b49 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Optional, Set
import attr
from twisted.python.failure import Failure
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
@@ -28,7 +28,7 @@ from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.admin._base import assert_user_is_admin
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamKeyType
+from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType
from synapse.types.state import StateFilter
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
@@ -391,7 +391,7 @@ class PaginationHandler:
"""
return self._delete_by_id.get(delete_id)
- def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]:
+ def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]:
"""Get all active delete ids by room
Args:
@@ -448,7 +448,7 @@ class PaginationHandler:
if pagin_config.from_token:
from_token = pagin_config.from_token
- elif pagin_config.direction == "f":
+ elif pagin_config.direction == Direction.FORWARDS:
from_token = (
await self.hs.get_event_sources().get_start_token_for_pagination(
room_id
@@ -476,7 +476,7 @@ class PaginationHandler:
room_id, requester, allow_departed_users=True
)
- if pagin_config.direction == "b":
+ if pagin_config.direction == Direction.BACKWARDS:
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 43e4e7b1b4..87af31aa27 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -64,7 +64,13 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
-from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ StrCollection,
+ StreamKeyType,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -320,7 +326,7 @@ class BasePresenceHandler(abc.ABC):
for destination, host_states in hosts_to_states.items():
self._federation.send_presence_to_destinations(host_states, [destination])
- async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None:
+ async def send_full_presence_to_users(self, user_ids: StrCollection) -> None:
"""
Adds to the list of users who should receive a full snapshot of presence
upon their next sync. Note that this only works for local users.
@@ -1601,7 +1607,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# Having a default limit doesn't match the EventSource API, but some
# callers do not provide it. It is unused in this class.
limit: int = 0,
- room_ids: Optional[Collection[str]] = None,
+ room_ids: Optional[StrCollection] = None,
is_guest: bool = False,
explicit_room_id: Optional[str] = None,
include_offline: bool = True,
@@ -1688,7 +1694,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
- interested_and_updated_users: Collection[str]
+ interested_and_updated_users: StrCollection
if from_key is not None:
# First get all users that have had a presence update
@@ -2120,7 +2126,7 @@ class PresenceFederationQueue:
# stream_id, destinations, user_ids)`. We don't store the full states
# for efficiency, and remote workers will already have the full states
# cached.
- self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
+ self._queue: List[Tuple[int, int, StrCollection, Set[str]]] = []
self._next_id = 1
@@ -2142,7 +2148,7 @@ class PresenceFederationQueue:
self._queue = self._queue[index:]
def send_presence_to_destinations(
- self, states: Collection[UserPresenceState], destinations: Collection[str]
+ self, states: Collection[UserPresenceState], destinations: StrCollection
) -> None:
"""Send the presence states to the given destinations.
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index e96f9999a8..0fb15391e0 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O
import attr
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -413,7 +413,11 @@ class RelationsHandler:
# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
- event_id, event, room_id, RelationTypes.THREAD, direction="f"
+ event_id,
+ event,
+ room_id,
+ RelationTypes.THREAD,
+ direction=Direction.FORWARDS,
)
# Filter out ignored users.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 572c7b4db3..60a6d9cf3c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,16 +20,7 @@ import random
import string
from collections import OrderedDict
from http import HTTPStatus
-from typing import (
- TYPE_CHECKING,
- Any,
- Awaitable,
- Collection,
- Dict,
- List,
- Optional,
- Tuple,
-)
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
import attr
from typing_extensions import TypedDict
@@ -72,6 +63,7 @@ from synapse.types import (
RoomID,
RoomStreamToken,
StateMap,
+ StrCollection,
StreamKeyType,
StreamToken,
UserID,
@@ -1644,7 +1636,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
user: UserID,
from_key: RoomStreamToken,
limit: int,
- room_ids: Collection[str],
+ room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index c6b869c6f4..4472019fbc 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -36,7 +36,7 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
-from synapse.types import JsonDict, Requester
+from synapse.types import JsonDict, Requester, StrCollection
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -870,7 +870,7 @@ class _RoomQueueEntry:
# The room ID of this entry.
room_id: str
# The server to query if the room is not known locally.
- via: Sequence[str]
+ via: StrCollection
# The minimum number of hops necessary to get to this room (compared to the
# originally requested room).
depth: int = 0
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 40f4635c4e..9bbf83047d 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -14,7 +14,7 @@
import itertools
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
from unpaddedbase64 import decode_base64, encode_base64
@@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
-from synapse.types import JsonDict, StreamKeyType, UserID
+from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
from synapse.types.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -418,7 +418,7 @@ class SearchHandler:
async def _search_by_rank(
self,
user: UserID,
- room_ids: Collection[str],
+ room_ids: StrCollection,
search_term: str,
keys: Iterable[str],
search_filter: Filter,
@@ -491,7 +491,7 @@ class SearchHandler:
async def _search_by_recent(
self,
user: UserID,
- room_ids: Collection[str],
+ room_ids: StrCollection,
search_term: str,
keys: Iterable[str],
search_filter: Filter,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 44e70fc4b8..4a27c0f051 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -20,7 +20,6 @@ from typing import (
Any,
Awaitable,
Callable,
- Collection,
Dict,
Iterable,
List,
@@ -47,6 +46,7 @@ from synapse.http.server import respond_with_html, respond_with_redirect
from synapse.http.site import SynapseRequest
from synapse.types import (
JsonDict,
+ StrCollection,
UserID,
contains_invalid_mxid_characters,
create_requester,
@@ -141,7 +141,8 @@ class UserAttributes:
confirm_localpart: bool = False
display_name: Optional[str] = None
picture: Optional[str] = None
- emails: Collection[str] = attr.Factory(list)
+ # mypy thinks these are incompatible for some reason.
+ emails: StrCollection = attr.Factory(list) # type: ignore[assignment]
@attr.s(slots=True, auto_attribs=True)
@@ -159,7 +160,7 @@ class UsernameMappingSession:
# attributes returned by the ID mapper
display_name: Optional[str]
- emails: Collection[str]
+ emails: StrCollection
# An optional dictionary of extra attributes to be provided to the client in the
# login response.
@@ -174,7 +175,7 @@ class UsernameMappingSession:
# choices made by the user
chosen_localpart: Optional[str] = None
use_display_name: bool = True
- emails_to_use: Collection[str] = ()
+ emails_to_use: StrCollection = ()
terms_accepted_version: Optional[str] = None
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5ebd3ea855..5235e29460 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -17,7 +17,6 @@ from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
- Collection,
Dict,
FrozenSet,
List,
@@ -62,6 +61,7 @@ from synapse.types import (
Requester,
RoomStreamToken,
StateMap,
+ StrCollection,
StreamKeyType,
StreamToken,
UserID,
@@ -1179,7 +1179,7 @@ class SyncHandler:
async def _find_missing_partial_state_memberships(
self,
room_id: str,
- members_to_fetch: Collection[str],
+ members_to_fetch: StrCollection,
events_with_membership_auth: Mapping[str, EventBase],
found_state_ids: StateMap[str],
) -> StateMap[str]:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 6153a48257..d22dd19d38 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1158,7 +1158,7 @@ class ModuleApi:
# Send to remote destinations.
destination = UserID.from_string(user).domain
presence_handler.get_federation_queue().send_presence_to_destinations(
- presence_events, destination
+ presence_events, [destination]
)
def looping_background_call(
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 2b0e52f23c..a8832a3f8e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -46,6 +46,7 @@ from synapse.types import (
JsonDict,
PersistedEventPosition,
RoomStreamToken,
+ StrCollection,
StreamKeyType,
StreamToken,
UserID,
@@ -716,7 +717,7 @@ class Notifier:
async def _get_room_ids(
self, user: UserID, explicit_room_id: Optional[str]
- ) -> Tuple[Collection[str], bool]:
+ ) -> Tuple[StrCollection, bool]:
joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 8191b4e32c..ad5c10c99d 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, List, Tuple, Union
from synapse.api.errors import (
NotFoundError,
@@ -169,7 +169,7 @@ class PushRuleRestServlet(RestServlet):
raise UnrecognizedRequestError()
-def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
+def _rule_spec_from_path(path: List[str]) -> RuleSpec:
"""Turn a sequence of path components into a rule spec
Args:
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 61375651bc..3f40f1874a 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
from typing_extensions import ParamSpec
+from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.web.server import Request
@@ -90,7 +91,7 @@ class HttpTransactionCache:
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args,
**kwargs: P.kwargs,
- ) -> Awaitable[Tuple[int, JsonDict]]:
+ ) -> "Deferred[Tuple[int, JsonDict]]":
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 84f844b79e..0018d6f7ab 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -30,7 +30,7 @@ from typing import (
import attr
-from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
@@ -40,9 +40,13 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
-from synapse.storage.databases.main.stream import generate_pagination_where_clause
+from synapse.storage.databases.main.stream import (
+ generate_next_token,
+ generate_pagination_bounds,
+ generate_pagination_where_clause,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
+from synapse.types import JsonDict, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -164,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
@@ -177,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the most recent `limit` events.
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`).
+ direction: Whether to fetch the most recent first (backwards) or the
+ oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
@@ -207,24 +211,23 @@ class RelationsWorkerStore(SQLBaseStore):
where_clause.append("type = ?")
where_args.append(event_type)
+ order, from_bound, to_bound = generate_pagination_bounds(
+ direction,
+ from_token.room_key if from_token else None,
+ to_token.room_key if to_token else None,
+ )
+
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.room_key.as_historical_tuple()
- if from_token
- else None,
- to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
if pagination_clause:
where_clause.append(pagination_clause)
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
sql = """
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
@@ -266,16 +269,9 @@ class RelationsWorkerStore(SQLBaseStore):
topo_orderings = topo_orderings[:limit]
stream_orderings = stream_orderings[:limit]
- topo = topo_orderings[-1]
- token = stream_orderings[-1]
- if direction == "b":
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- token -= 1
- next_key = RoomStreamToken(topo, token)
+ next_key = generate_next_token(
+ direction, topo_orderings[-1], stream_orderings[-1]
+ )
if from_token:
next_token = from_token.copy_and_replace(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index d28fc65df9..818c46182e 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -55,6 +55,7 @@ from typing_extensions import Literal
from twisted.internet import defer
+from synapse.api.constants import Direction
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
-
# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
@@ -104,7 +104,7 @@ class _EventsAround:
def generate_pagination_where_clause(
- direction: str,
+ direction: Direction,
column_names: Tuple[str, str],
from_token: Optional[Tuple[Optional[int], int]],
to_token: Optional[Tuple[Optional[int], int]],
@@ -130,27 +130,26 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
- direction: Whether we're paginating backwards("b") or forwards ("f").
+ direction: Whether we're paginating backwards or forwards.
column_names: The column names to bound. Must *not* be user defined as
these get inserted directly into the SQL statement without escapes.
from_token: The start point for the pagination. This is an exclusive
- minimum bound if direction is "f", and an inclusive maximum bound if
- direction is "b".
+ minimum bound if direction is forwards, and an inclusive maximum bound if
+ direction is backwards.
to_token: The endpoint point for the pagination. This is an inclusive
- maximum bound if direction is "f", and an exclusive minimum bound if
- direction is "b".
+ maximum bound if direction is forwards, and an exclusive minimum bound if
+ direction is backwards.
engine: The database engine to generate the clauses for
Returns:
The sql expression
"""
- assert direction in ("b", "f")
where_clause = []
if from_token:
where_clause.append(
_make_generic_sql_bound(
- bound=">=" if direction == "b" else "<",
+ bound=">=" if direction == Direction.BACKWARDS else "<",
column_names=column_names,
values=from_token,
engine=engine,
@@ -160,7 +159,7 @@ def generate_pagination_where_clause(
if to_token:
where_clause.append(
_make_generic_sql_bound(
- bound="<" if direction == "b" else ">=",
+ bound="<" if direction == Direction.BACKWARDS else ">=",
column_names=column_names,
values=to_token,
engine=engine,
@@ -170,6 +169,104 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
+def generate_pagination_bounds(
+ direction: Direction,
+ from_token: Optional[RoomStreamToken],
+ to_token: Optional[RoomStreamToken],
+) -> Tuple[
+ str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]]
+]:
+ """
+ Generate a start and end point for this page of events.
+
+ Args:
+ direction: Whether pagination is going forwards or backwards.
+ from_token: The token to start pagination at, or None to start at the first value.
+ to_token: The token to end pagination at, or None to not limit the end point.
+
+ Returns:
+ A three tuple of:
+
+ ASC or DESC for sorting of the query.
+
+ The starting position as a tuple of ints representing
+ (topological position, stream position) or None if no from_token was
+ provided. The topological position may be None for live tokens.
+
+ The end position in the same format as the starting position, or None
+ if no to_token was provided.
+ """
+
+ # Tokens really represent positions between elements, but we use
+ # the convention of pointing to the event before the gap. Hence
+ # we have a bit of asymmetry when it comes to equalities.
+ if direction == Direction.BACKWARDS:
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ # The bounds for the stream tokens are complicated by the fact
+ # that we need to handle the instance_map part of the tokens. We do this
+ # by fetching all events between the min stream token and the maximum
+ # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+ # then filtering the results.
+ from_bound: Optional[Tuple[Optional[int], int]] = None
+ if from_token:
+ if from_token.topological is not None:
+ from_bound = from_token.as_historical_tuple()
+ elif direction == Direction.BACKWARDS:
+ from_bound = (
+ None,
+ from_token.get_max_stream_pos(),
+ )
+ else:
+ from_bound = (
+ None,
+ from_token.stream,
+ )
+
+ to_bound: Optional[Tuple[Optional[int], int]] = None
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_historical_tuple()
+ elif direction == Direction.BACKWARDS:
+ to_bound = (
+ None,
+ to_token.stream,
+ )
+ else:
+ to_bound = (
+ None,
+ to_token.get_max_stream_pos(),
+ )
+
+ return order, from_bound, to_bound
+
+
+def generate_next_token(
+ direction: Direction, last_topo_ordering: int, last_stream_ordering: int
+) -> RoomStreamToken:
+ """
+ Generate the next room stream token based on the currently returned data.
+
+ Args:
+ direction: Whether pagination is going forwards or backwards.
+ last_topo_ordering: The last topological ordering being returned.
+ last_stream_ordering: The last stream ordering being returned.
+
+ Returns:
+ A new RoomStreamToken to return to the client.
+ """
+ if direction == Direction.BACKWARDS:
+ # Tokens are positions between events.
+ # This token points *after* the last event in the chunk.
+ # We need it to point to the event before it in the chunk
+ # when we are going backwards so we subtract one from the
+ # stream part.
+ last_stream_ordering -= 1
+ return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+
+
def _make_generic_sql_bound(
bound: str,
column_names: Tuple[str, str],
@@ -1103,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
before_token,
- direction="b",
+ direction=Direction.BACKWARDS,
limit=before_limit,
event_filter=event_filter,
)
@@ -1113,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
after_token,
- direction="f",
+ direction=Direction.FORWARDS,
limit=after_limit,
event_filter=event_filter,
)
@@ -1276,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
@@ -1287,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
from_token: The token used to stream from
to_token: A token which if given limits the results to only those before
- direction: Either 'b' or 'f' to indicate whether we are paginating
- forwards or backwards from `from_key`.
+ direction: Indicates whether we are paginating forwards or backwards
+ from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to
those that match the filter.
@@ -1300,47 +1397,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
`to_token`), or `limit` is zero.
"""
- # Tokens really represent positions between elements, but we use
- # the convention of pointing to the event before the gap. Hence
- # we have a bit of asymmetry when it comes to equalities.
args = [False, room_id]
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
- # The bounds for the stream tokens are complicated by the fact
- # that we need to handle the instance_map part of the tokens. We do this
- # by fetching all events between the min stream token and the maximum
- # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
- # then filtering the results.
- if from_token.topological is not None:
- from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
- elif direction == "b":
- from_bound = (
- None,
- from_token.get_max_stream_pos(),
- )
- else:
- from_bound = (
- None,
- from_token.stream,
- )
- to_bound: Optional[Tuple[Optional[int], int]] = None
- if to_token:
- if to_token.topological is not None:
- to_bound = to_token.as_historical_tuple()
- elif direction == "b":
- to_bound = (
- None,
- to_token.stream,
- )
- else:
- to_bound = (
- None,
- to_token.get_max_stream_pos(),
- )
+ order, from_bound, to_bound = generate_pagination_bounds(
+ direction, from_token, to_token
+ )
bounds = generate_pagination_where_clause(
direction=direction,
@@ -1427,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
_EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
- lower_token=to_token if direction == "b" else from_token,
- upper_token=from_token if direction == "b" else to_token,
+ lower_token=to_token
+ if direction == Direction.BACKWARDS
+ else from_token,
+ upper_token=from_token
+ if direction == Direction.BACKWARDS
+ else to_token,
instance_name=instance_name,
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
@@ -1436,16 +1501,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
][:limit]
if rows:
- topo = rows[-1].topological_ordering
- token = rows[-1].stream_ordering
- if direction == "b":
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- token -= 1
- next_token = RoomStreamToken(topo, token)
+ assert rows[-1].topological_ordering is not None
+ next_token = generate_next_token(
+ direction, rows[-1].topological_ordering, rows[-1].stream_ordering
+ )
else:
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
@@ -1458,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str,
from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
@@ -1468,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
from_key: The token used to stream from
to_key: A token which if given limits the results to only those before
- direction: Either 'b' or 'f' to indicate whether we are paginating
- forwards or backwards from `from_key`.
+ direction: Indicates whether we are paginating forwards or backwards
+ from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to those that match the filter.
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index 2dcd43d0a2..c6c8a0315c 100644
--- a/synapse/streams/__init__.py
+++ b/synapse/streams/__init__.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Collection, Generic, List, Optional, Tuple, TypeVar
+from typing import Generic, List, Optional, Tuple, TypeVar
-from synapse.types import UserID
+from synapse.types import StrCollection, UserID
# The key, this is either a stream token or int.
K = TypeVar("K")
@@ -28,7 +28,7 @@ class EventSource(Generic[K, R]):
user: UserID,
from_key: K,
limit: int,
- room_ids: Collection[str],
+ room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[R], K]:
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 6df2de919c..5cb7875181 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -16,6 +16,7 @@ from typing import Optional
import attr
+from synapse.api.constants import Direction
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
@@ -34,7 +35,7 @@ class PaginationConfig:
from_token: Optional[StreamToken]
to_token: Optional[StreamToken]
- direction: str
+ direction: Direction
limit: int
@classmethod
@@ -45,9 +46,13 @@ class PaginationConfig:
default_limit: int,
default_dir: str = "f",
) -> "PaginationConfig":
- direction = parse_string(
- request, "dir", default=default_dir, allowed_values=["f", "b"]
+ direction_str = parse_string(
+ request,
+ "dir",
+ default=default_dir,
+ allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
)
+ direction = Direction(direction_str)
from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")
|