summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/14929.misc1
-rw-r--r--synapse/event_auth.py23
-rw-r--r--synapse/events/__init__.py6
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py6
5 files changed, 19 insertions, 24 deletions
diff --git a/changelog.d/14929.misc b/changelog.d/14929.misc
new file mode 100644
index 0000000000..2cc3614dfd
--- /dev/null
+++ b/changelog.d/14929.misc
@@ -0,0 +1 @@
+Use `StrCollection` to avoid potential bugs with `Collection[str]`.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index e0be9f88cc..4d6d1b8ebd 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -16,18 +16,7 @@
 import collections.abc
 import logging
 import typing
-from typing import (
-    Any,
-    Collection,
-    Dict,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -56,7 +45,13 @@ from synapse.api.room_versions import (
     RoomVersions,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import MutableStateMap, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+    MutableStateMap,
+    StateMap,
+    StrCollection,
+    UserID,
+    get_domain_from_id,
+)
 
 if typing.TYPE_CHECKING:
     # conditional imports to avoid import cycle
@@ -69,7 +64,7 @@ logger = logging.getLogger(__name__)
 class _EventSourceStore(Protocol):
     async def get_events(
         self,
-        event_ids: Collection[str],
+        event_ids: StrCollection,
         redact_behaviour: EventRedactBehaviour,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8aca9a3ab9..91118a8d84 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -39,7 +39,7 @@ from unpaddedbase64 import encode_base64
 
 from synapse.api.constants import RelationTypes
 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
-from synapse.types import JsonDict, RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken, StrCollection
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
 from synapse.util.stringutils import strtobool
@@ -413,7 +413,7 @@ class EventBase(metaclass=abc.ABCMeta):
         """
         return [e for e, _ in self._dict["prev_events"]]
 
-    def auth_event_ids(self) -> Sequence[str]:
+    def auth_event_ids(self) -> StrCollection:
         """Returns the list of auth event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
@@ -558,7 +558,7 @@ class FrozenEventV2(EventBase):
         """
         return self._dict["prev_events"]
 
-    def auth_event_ids(self) -> Sequence[str]:
+    def auth_event_ids(self) -> StrCollection:
         """Returns the list of auth event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ffe766fd56..7996cbb557 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -25,7 +25,6 @@ from typing import (
     Iterable,
     List,
     Optional,
-    Sequence,
     Set,
     Tuple,
 )
@@ -51,7 +50,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 from synapse.util.stringutils import non_null_str_or_none
@@ -552,7 +551,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, Sequence[str]],
+        event_to_auth_chain: Dict[str, StrCollection],
     ) -> None:
         """Calculate the chain cover index for the given events.
 
@@ -846,7 +845,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, Sequence[str]],
+        event_to_auth_chain: Dict[str, StrCollection],
         events_to_calc_chain_id_for: Set[str],
         chain_map: Dict[str, Tuple[int, int]],
     ) -> Dict[str, Tuple[int, int]]:
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index b9d3c36d60..584536111d 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
 
 import attr
 
@@ -29,7 +29,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -1061,7 +1061,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self.event_chain_id_gen,  # type: ignore[attr-defined]
             event_to_room_id,
             event_to_types,
-            cast(Dict[str, Sequence[str]], event_to_auth_chain),
+            cast(Dict[str, StrCollection], event_to_auth_chain),
         )
 
         return _CalculateChainCover(