diff --git a/changelog.d/16301.misc b/changelog.d/16301.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16301.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index a94b57a671..9ac7e4313e 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -27,9 +27,7 @@ from typing import (
Any,
Awaitable,
Callable,
- Collection,
Dict,
- Iterable,
List,
NoReturn,
Optional,
@@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
load_legacy_third_party_event_rules,
)
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrCollection
from synapse.util import SYNAPSE_VERSION
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
from synapse.util.daemonize import daemonize_process
@@ -278,7 +276,7 @@ def register_start(
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
-def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
+def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
"""
Start Prometheus metrics server.
"""
@@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
def listen_manhole(
- bind_addresses: Collection[str],
+ bind_addresses: StrCollection,
port: int,
manhole_settings: ManholeConfig,
manhole_globals: dict,
@@ -339,7 +337,7 @@ def listen_manhole(
def listen_tcp(
- bind_addresses: Collection[str],
+ bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
reactor: IReactorTCP = reactor,
@@ -448,7 +446,7 @@ def listen_http(
def listen_ssl(
- bind_addresses: Collection[str],
+ bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
context_factory: IOpenSSLContextFactory,
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 58856839e1..c5816105f4 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -26,7 +26,6 @@ from textwrap import dedent
from typing import (
Any,
ClassVar,
- Collection,
Dict,
Iterable,
Iterator,
@@ -384,7 +383,7 @@ class RootConfig:
config_classes: List[Type[Config]] = []
- def __init__(self, config_files: Collection[str] = ()):
+ def __init__(self, config_files: StrSequence = ()):
# Capture absolute paths here, so we can reload config after we daemonize.
self.config_files = [os.path.abspath(path) for path in config_files]
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 35257a3b1b..3c1777b7ec 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -25,7 +25,6 @@ from typing import (
Iterable,
List,
Optional,
- Sequence,
Tuple,
Type,
TypeVar,
@@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
def keys(self) -> Iterable[str]:
return self._dict.keys()
- def prev_event_ids(self) -> Sequence[str]:
+ def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
@@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
- def prev_event_ids(self) -> Sequence[str]:
+ def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 14ea0e6640..1165c017ba 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
from signedjson.types import SigningKey
@@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
-from synapse.types import EventID, JsonDict
+from synapse.types import EventID, JsonDict, StrCollection
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -103,7 +103,7 @@ class EventBuilder:
async def build(
self,
- prev_event_ids: Collection[str],
+ prev_event_ids: StrCollection,
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:
format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
- prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
+ prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 34625dd7a1..5da50cb0d2 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
-from typing import Iterable, List, Type, Union, cast
+from typing import List, Type, Union, cast
import jsonschema
from pydantic import Field, StrictBool, StrictStr
@@ -36,7 +36,7 @@ from synapse.events.utils import (
from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
-from synapse.types import EventID, JsonDict, RoomID, UserID
+from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
class EventValidator:
@@ -225,7 +225,7 @@ class EventValidator:
self._ensure_state_event(event)
- def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
+ def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ca2cdbc6e2..c750e03b36 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrSequence
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
# the value actually has to be a List, but List is invariant so we can't specify that
# the entries can either be Lists or bytes.
RawHeaderValue = Union[
- List[str],
+ StrSequence,
List[bytes],
List[Union[str, bytes]],
- Tuple[str, ...],
Tuple[bytes, ...],
Tuple[Union[str, bytes], ...],
]
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fc62793628..5d79d31579 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -18,7 +18,6 @@ import logging
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
- Iterable,
List,
Mapping,
Optional,
@@ -38,7 +37,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http import redact_uri
from synapse.http.server import HttpServer
-from synapse.types import JsonDict, RoomAlias, RoomID
+from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -340,7 +339,7 @@ def parse_string(
name: str,
default: str,
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -352,7 +351,7 @@ def parse_string(
name: str,
*,
required: Literal[True],
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -365,7 +364,7 @@ def parse_string(
*,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -376,7 +375,7 @@ def parse_string(
name: str,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@@ -485,7 +484,7 @@ def parse_enum(
def _parse_string_value(
value: bytes,
- allowed_values: Optional[Iterable[str]],
+ allowed_values: Optional[StrCollection],
name: str,
encoding: str,
) -> str:
@@ -511,7 +510,7 @@ def parse_strings_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@@ -523,7 +522,7 @@ def parse_strings_from_args(
name: str,
default: List[str],
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@@ -535,7 +534,7 @@ def parse_strings_from_args(
name: str,
*,
required: Literal[True],
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@@ -548,7 +547,7 @@ def parse_strings_from_args(
default: Optional[List[str]] = None,
*,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@@ -559,7 +558,7 @@ def parse_strings_from_args(
name: str,
default: Optional[List[str]] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
"""
@@ -610,7 +609,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -623,7 +622,7 @@ def parse_string_from_args(
default: Optional[str] = None,
*,
required: Literal[True],
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -635,7 +634,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -646,7 +645,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
return validate_json_object(content, model_type)
-def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
+def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
absent = []
for k in required:
if k not in body:
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 39fc629937..3cf2fbc3e2 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -25,7 +25,6 @@ from typing import (
Iterable,
Mapping,
Optional,
- Sequence,
Set,
Tuple,
Type,
@@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
from synapse.metrics._types import Collector
+from synapse.types import StrSequence
from synapse.util import SYNAPSE_VERSION
logger = logging.getLogger(__name__)
@@ -81,7 +81,7 @@ class LaterGauge(Collector):
name: str
desc: str
- labels: Optional[Sequence[str]] = attr.ib(hash=False)
+ labels: Optional[StrSequence] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller: Callable[
@@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
self,
name: str,
desc: str,
- labels: Sequence[str],
- sub_metrics: Sequence[str],
+ labels: StrSequence,
+ sub_metrics: StrSequence,
):
self.name = name
self.desc = desc
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 68115bca70..fc39e5c963 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -104,7 +104,7 @@ class _NotifierUserStream:
def __init__(
self,
user_id: str,
- rooms: Collection[str],
+ rooms: StrCollection,
current_token: StreamToken,
time_now_ms: int,
):
@@ -457,7 +457,7 @@ class Notifier:
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
- rooms: Optional[Collection[str]] = None,
+ rooms: Optional[StrCollection] = None,
) -> None:
"""Used to inform listeners that something has happened event wise.
@@ -529,7 +529,7 @@ class Notifier:
user_id: str,
timeout: int,
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
- room_ids: Optional[Collection[str]] = None,
+ room_ids: Optional[StrCollection] = None,
from_token: StreamToken = StreamToken.START,
) -> T:
"""Wait until the callback returns a non empty response or the
diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index 5c1c19e1f3..73c568ef75 100644
--- a/synapse/rest/client/_base.py
+++ b/synapse/rest/client/_base.py
@@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
logger = logging.getLogger(__name__)
def client_patterns(
path_regex: str,
- releases: Iterable[str] = ("r0", "v3"),
+ releases: StrCollection = ("r0", "v3"),
unstable: bool = True,
v1: bool = False,
) -> Iterable[Pattern]:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1b91cf5eaa..e977ed1044 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -20,7 +20,6 @@ from typing import (
Any,
Awaitable,
Callable,
- Collection,
DefaultDict,
Dict,
FrozenSet,
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import StateMap
+from synapse.types import StateMap, StrCollection
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -197,7 +196,7 @@ class StateHandler:
async def compute_state_after_events(
self,
room_id: str,
- event_ids: Collection[str],
+ event_ids: StrCollection,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
@@ -231,7 +230,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_user_ids_in_room(
- self, room_id: str, latest_event_ids: Collection[str]
+ self, room_id: str, latest_event_ids: StrCollection
) -> Set[str]:
"""
Get the users IDs who are currently in a room.
@@ -256,7 +255,7 @@ class StateHandler:
return await self.store.get_joined_user_ids_from_state(room_id, state)
async def get_hosts_in_room_at_events(
- self, room_id: str, event_ids: Collection[str]
+ self, room_id: str, event_ids: StrCollection
) -> FrozenSet[str]:
"""Get the hosts that were in a room at the given event ids
@@ -470,7 +469,7 @@ class StateHandler:
@trace
@measure_func()
async def resolve_state_groups_for_events(
- self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
+ self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -882,7 +881,7 @@ class StateResolutionStore:
store: "DataStore"
def get_events(
- self, event_ids: Collection[str], allow_rejected: bool = False
+ self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 500e384695..c76a2f082e 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,7 +17,6 @@ import logging
from typing import (
Awaitable,
Callable,
- Collection,
Dict,
Iterable,
List,
@@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
logger = logging.getLogger(__name__)
@@ -45,7 +44,7 @@ async def resolve_events_with_store(
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
+ state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 44c49274a9..1752f95db8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -19,7 +19,6 @@ from typing import (
Any,
Awaitable,
Callable,
- Collection,
Dict,
Generator,
Iterable,
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
logger = logging.getLogger(__name__)
@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
# TestStateResolutionStore in tests.
def get_events(
- self, event_ids: Collection[str], allow_rejected: bool = False
+ self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
...
@@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
- auth_difference_unpersisted_part: Collection[str] = union - intersection
+ auth_difference_unpersisted_part: StrCollection = union - intersection
else:
auth_difference_unpersisted_part = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index fab7008a8f..09de8f55e2 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -47,7 +47,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, StrCollection
+from synapse.types import JsonDict, StrCollection, StrSequence
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@cached(max_entries=5000, iterable=True)
- async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
+ async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
diff --git a/synapse/visibility.py b/synapse/visibility.py
index eac10f6438..f15fdd8314 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
from synapse.logging.opentracing import trace
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
-from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
+from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util import Clock
@@ -150,12 +150,12 @@ async def filter_events_for_client(
async def filter_event_for_clients_with_state(
store: DataStore,
- user_ids: Collection[str],
+ user_ids: StrCollection,
event: EventBase,
context: EventContext,
is_peeking: bool = False,
filter_send_to_client: bool = True,
-) -> Collection[str]:
+) -> StrCollection:
"""
Checks to see if an event is visible to the users in the list at the time of
the event.
|