summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/16301.misc1
-rw-r--r--synapse/app/_base.py12
-rw-r--r--synapse/config/_base.py3
-rw-r--r--synapse/events/__init__.py5
-rw-r--r--synapse/events/builder.py8
-rw-r--r--synapse/events/validator.py6
-rw-r--r--synapse/http/client.py5
-rw-r--r--synapse/http/servlet.py33
-rw-r--r--synapse/metrics/__init__.py8
-rw-r--r--synapse/notifier.py6
-rw-r--r--synapse/rest/client/_base.py4
-rw-r--r--synapse/state/__init__.py13
-rw-r--r--synapse/state/v1.py5
-rw-r--r--synapse/state/v2.py7
-rw-r--r--synapse/storage/databases/main/event_federation.py4
-rw-r--r--synapse/visibility.py6
16 files changed, 59 insertions, 67 deletions
diff --git a/changelog.d/16301.misc b/changelog.d/16301.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16301.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index a94b57a671..9ac7e4313e 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -27,9 +27,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     Dict,
-    Iterable,
     List,
     NoReturn,
     Optional,
@@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
 from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
     load_legacy_third_party_event_rules,
 )
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrCollection
 from synapse.util import SYNAPSE_VERSION
 from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
 from synapse.util.daemonize import daemonize_process
@@ -278,7 +276,7 @@ def register_start(
     reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
 
 
-def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
+def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
     """
     Start Prometheus metrics server.
     """
@@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
 
 
 def listen_manhole(
-    bind_addresses: Collection[str],
+    bind_addresses: StrCollection,
     port: int,
     manhole_settings: ManholeConfig,
     manhole_globals: dict,
@@ -339,7 +337,7 @@ def listen_manhole(
 
 
 def listen_tcp(
-    bind_addresses: Collection[str],
+    bind_addresses: StrCollection,
     port: int,
     factory: ServerFactory,
     reactor: IReactorTCP = reactor,
@@ -448,7 +446,7 @@ def listen_http(
 
 
 def listen_ssl(
-    bind_addresses: Collection[str],
+    bind_addresses: StrCollection,
     port: int,
     factory: ServerFactory,
     context_factory: IOpenSSLContextFactory,
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 58856839e1..c5816105f4 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -26,7 +26,6 @@ from textwrap import dedent
 from typing import (
     Any,
     ClassVar,
-    Collection,
     Dict,
     Iterable,
     Iterator,
@@ -384,7 +383,7 @@ class RootConfig:
 
     config_classes: List[Type[Config]] = []
 
-    def __init__(self, config_files: Collection[str] = ()):
+    def __init__(self, config_files: StrSequence = ()):
         # Capture absolute paths here, so we can reload config after we daemonize.
         self.config_files = [os.path.abspath(path) for path in config_files]
 
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 35257a3b1b..3c1777b7ec 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -25,7 +25,6 @@ from typing import (
     Iterable,
     List,
     Optional,
-    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
     def keys(self) -> Iterable[str]:
         return self._dict.keys()
 
-    def prev_event_ids(self) -> Sequence[str]:
+    def prev_event_ids(self) -> List[str]:
         """Returns the list of prev event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
@@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
         self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
         return self._event_id
 
-    def prev_event_ids(self) -> Sequence[str]:
+    def prev_event_ids(self) -> List[str]:
         """Returns the list of prev event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 14ea0e6640..1165c017ba 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import attr
 from signedjson.types import SigningKey
@@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
-from synapse.types import EventID, JsonDict
+from synapse.types import EventID, JsonDict, StrCollection
 from synapse.types.state import StateFilter
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
@@ -103,7 +103,7 @@ class EventBuilder:
 
     async def build(
         self,
-        prev_event_ids: Collection[str],
+        prev_event_ids: StrCollection,
         auth_event_ids: Optional[List[str]],
         depth: Optional[int] = None,
     ) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:
 
         format_version = self.room_version.event_format
         # The types of auth/prev events changes between event versions.
-        prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
+        prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
         auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
         if format_version == EventFormatVersions.ROOM_V1_V2:
             auth_events = await self._store.add_event_hashes(auth_event_ids)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 34625dd7a1..5da50cb0d2 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections.abc
-from typing import Iterable, List, Type, Union, cast
+from typing import List, Type, Union, cast
 
 import jsonschema
 from pydantic import Field, StrictBool, StrictStr
@@ -36,7 +36,7 @@ from synapse.events.utils import (
 from synapse.federation.federation_server import server_matches_acl_event
 from synapse.http.servlet import validate_json_object
 from synapse.rest.models import RequestBodyModel
-from synapse.types import EventID, JsonDict, RoomID, UserID
+from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
 
 
 class EventValidator:
@@ -225,7 +225,7 @@ class EventValidator:
 
             self._ensure_state_event(event)
 
-    def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
+    def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
         for s in keys:
             if s not in d:
                 raise SynapseError(400, "'%s' not in content" % (s,))
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ca2cdbc6e2..c750e03b36 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
 from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrSequence
 from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 
@@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
 # the value actually has to be a List, but List is invariant so we can't specify that
 # the entries can either be Lists or bytes.
 RawHeaderValue = Union[
-    List[str],
+    StrSequence,
     List[bytes],
     List[Union[str, bytes]],
-    Tuple[str, ...],
     Tuple[bytes, ...],
     Tuple[Union[str, bytes], ...],
 ]
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fc62793628..5d79d31579 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -18,7 +18,6 @@ import logging
 from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
-    Iterable,
     List,
     Mapping,
     Optional,
@@ -38,7 +37,7 @@ from twisted.web.server import Request
 from synapse.api.errors import Codes, SynapseError
 from synapse.http import redact_uri
 from synapse.http.server import HttpServer
-from synapse.types import JsonDict, RoomAlias, RoomID
+from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -340,7 +339,7 @@ def parse_string(
     name: str,
     default: str,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -352,7 +351,7 @@ def parse_string(
     name: str,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -365,7 +364,7 @@ def parse_string(
     *,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -376,7 +375,7 @@ def parse_string(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     """
@@ -485,7 +484,7 @@ def parse_enum(
 
 def _parse_string_value(
     value: bytes,
-    allowed_values: Optional[Iterable[str]],
+    allowed_values: Optional[StrCollection],
     name: str,
     encoding: str,
 ) -> str:
@@ -511,7 +510,7 @@ def parse_strings_from_args(
     args: Mapping[bytes, Sequence[bytes]],
     name: str,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     ...
@@ -523,7 +522,7 @@ def parse_strings_from_args(
     name: str,
     default: List[str],
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> List[str]:
     ...
@@ -535,7 +534,7 @@ def parse_strings_from_args(
     name: str,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> List[str]:
     ...
@@ -548,7 +547,7 @@ def parse_strings_from_args(
     default: Optional[List[str]] = None,
     *,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     ...
@@ -559,7 +558,7 @@ def parse_strings_from_args(
     name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     """
@@ -610,7 +609,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -623,7 +622,7 @@ def parse_string_from_args(
     default: Optional[str] = None,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -635,7 +634,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -646,7 +645,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     """
@@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
     return validate_json_object(content, model_type)
 
 
-def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
+def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
     absent = []
     for k in required:
         if k not in body:
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 39fc629937..3cf2fbc3e2 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -25,7 +25,6 @@ from typing import (
     Iterable,
     Mapping,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Type,
@@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics  # noqa: F401
 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
 from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
 from synapse.metrics._types import Collector
+from synapse.types import StrSequence
 from synapse.util import SYNAPSE_VERSION
 
 logger = logging.getLogger(__name__)
@@ -81,7 +81,7 @@ class LaterGauge(Collector):
 
     name: str
     desc: str
-    labels: Optional[Sequence[str]] = attr.ib(hash=False)
+    labels: Optional[StrSequence] = attr.ib(hash=False)
     # callback: should either return a value (if there are no labels for this metric),
     # or dict mapping from a label tuple to a value
     caller: Callable[
@@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
         self,
         name: str,
         desc: str,
-        labels: Sequence[str],
-        sub_metrics: Sequence[str],
+        labels: StrSequence,
+        sub_metrics: StrSequence,
     ):
         self.name = name
         self.desc = desc
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 68115bca70..fc39e5c963 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -104,7 +104,7 @@ class _NotifierUserStream:
     def __init__(
         self,
         user_id: str,
-        rooms: Collection[str],
+        rooms: StrCollection,
         current_token: StreamToken,
         time_now_ms: int,
     ):
@@ -457,7 +457,7 @@ class Notifier:
         stream_key: str,
         new_token: Union[int, RoomStreamToken],
         users: Optional[Collection[Union[str, UserID]]] = None,
-        rooms: Optional[Collection[str]] = None,
+        rooms: Optional[StrCollection] = None,
     ) -> None:
         """Used to inform listeners that something has happened event wise.
 
@@ -529,7 +529,7 @@ class Notifier:
         user_id: str,
         timeout: int,
         callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
-        room_ids: Optional[Collection[str]] = None,
+        room_ids: Optional[StrCollection] = None,
         from_token: StreamToken = StreamToken.START,
     ) -> T:
         """Wait until the callback returns a non empty response or the
diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index 5c1c19e1f3..73c568ef75 100644
--- a/synapse/rest/client/_base.py
+++ b/synapse/rest/client/_base.py
@@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,
 
 from synapse.api.errors import InteractiveAuthIncompleteError
 from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
 
 logger = logging.getLogger(__name__)
 
 
 def client_patterns(
     path_regex: str,
-    releases: Iterable[str] = ("r0", "v3"),
+    releases: StrCollection = ("r0", "v3"),
     unstable: bool = True,
     v1: bool = False,
 ) -> Iterable[Pattern]:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1b91cf5eaa..e977ed1044 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -20,7 +20,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     DefaultDict,
     Dict,
     FrozenSet,
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import StateMap
+from synapse.types import StateMap, StrCollection
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -197,7 +196,7 @@ class StateHandler:
     async def compute_state_after_events(
         self,
         room_id: str,
-        event_ids: Collection[str],
+        event_ids: StrCollection,
         state_filter: Optional[StateFilter] = None,
         await_full_state: bool = True,
     ) -> StateMap[str]:
@@ -231,7 +230,7 @@ class StateHandler:
         return await ret.get_state(self._state_storage_controller, state_filter)
 
     async def get_current_user_ids_in_room(
-        self, room_id: str, latest_event_ids: Collection[str]
+        self, room_id: str, latest_event_ids: StrCollection
     ) -> Set[str]:
         """
         Get the users IDs who are currently in a room.
@@ -256,7 +255,7 @@ class StateHandler:
         return await self.store.get_joined_user_ids_from_state(room_id, state)
 
     async def get_hosts_in_room_at_events(
-        self, room_id: str, event_ids: Collection[str]
+        self, room_id: str, event_ids: StrCollection
     ) -> FrozenSet[str]:
         """Get the hosts that were in a room at the given event ids
 
@@ -470,7 +469,7 @@ class StateHandler:
     @trace
     @measure_func()
     async def resolve_state_groups_for_events(
-        self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
+        self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
     ) -> _StateCacheEntry:
         """Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
@@ -882,7 +881,7 @@ class StateResolutionStore:
     store: "DataStore"
 
     def get_events(
-        self, event_ids: Collection[str], allow_rejected: bool = False
+        self, event_ids: StrCollection, allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database
 
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 500e384695..c76a2f082e 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,7 +17,6 @@ import logging
 from typing import (
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Iterable,
     List,
@@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
 
 logger = logging.getLogger(__name__)
 
@@ -45,7 +44,7 @@ async def resolve_events_with_store(
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
+    state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
 ) -> StateMap[str]:
     """
     Args:
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 44c49274a9..1752f95db8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -19,7 +19,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Generator,
     Iterable,
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
     # This is usually synapse.state.StateResolutionStore, but it's replaced with a
     # TestStateResolutionStore in tests.
     def get_events(
-        self, event_ids: Collection[str], allow_rejected: bool = False
+        self, event_ids: StrCollection, allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         ...
 
@@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
-        auth_difference_unpersisted_part: Collection[str] = union - intersection
+        auth_difference_unpersisted_part: StrCollection = union - intersection
     else:
         auth_difference_unpersisted_part = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index fab7008a8f..09de8f55e2 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -47,7 +47,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, StrCollection
+from synapse.types import JsonDict, StrCollection, StrSequence
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     @cached(max_entries=5000, iterable=True)
-    async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
+    async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
         return await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
diff --git a/synapse/visibility.py b/synapse/visibility.py
index eac10f6438..f15fdd8314 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
 from synapse.logging.opentracing import trace
 from synapse.storage.controllers import StorageControllers
 from synapse.storage.databases.main import DataStore
-from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
+from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
 from synapse.types.state import StateFilter
 from synapse.util import Clock
 
@@ -150,12 +150,12 @@ async def filter_events_for_client(
 
 async def filter_event_for_clients_with_state(
     store: DataStore,
-    user_ids: Collection[str],
+    user_ids: StrCollection,
     event: EventBase,
     context: EventContext,
     is_peeking: bool = False,
     filter_send_to_client: bool = True,
-) -> Collection[str]:
+) -> StrCollection:
     """
     Checks to see if an event is visible to the users in the list at the time of
     the event.