diff --git a/changelog.d/17616.misc b/changelog.d/17616.misc
new file mode 100644
index 0000000000..8250832dcd
--- /dev/null
+++ b/changelog.d/17616.misc
@@ -0,0 +1 @@
+Overload DatabasePool.simple_select_one_txn to return non-None when the allow_none parameter is False.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index cb4a5857be..8272e39340 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2159,10 +2159,26 @@ class DatabasePool:
if rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- # Ideally we could use the overload decorator here to specify that the
- # return type is only optional if allow_none is True, but this does not work
- # when you call a static method from an instance.
- # See https://github.com/python/mypy/issues/7781
+ @overload
+ @staticmethod
+ def simple_select_one_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Collection[str],
+ allow_none: Literal[False] = False,
+ ) -> Tuple[Any, ...]: ...
+
+ @overload
+ @staticmethod
+ def simple_select_one_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Collection[str],
+ allow_none: Literal[True] = True,
+ ) -> Optional[Tuple[Any, ...]]: ...
+
@staticmethod
def simple_select_one_txn(
txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index c2c93e12d9..a618a2de69 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -510,19 +510,16 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
# it isn't there.
raise StoreError(404, "No backup with that version exists")
- row = cast(
- Tuple[int, str, str, Optional[int]],
- self.db_pool.simple_select_one_txn(
- txn,
- table="e2e_room_keys_versions",
- keyvalues={
- "user_id": user_id,
- "version": this_version,
- "deleted": 0,
- },
- retcols=("version", "algorithm", "auth_data", "etag"),
- allow_none=False,
- ),
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ table="e2e_room_keys_versions",
+ keyvalues={
+ "user_id": user_id,
+ "version": this_version,
+ "deleted": 0,
+ },
+ retcols=("version", "algorithm", "auth_data", "etag"),
+ allow_none=False,
)
return {
"auth_data": db_to_json(row[2]),
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d7cbe33411..8380930c70 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1510,15 +1510,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
- pending, completed = cast(
- Tuple[int, int],
- self.db_pool.simple_select_one_txn(
- txn,
- "registration_tokens",
- keyvalues={"token": token},
- retcols=["pending", "completed"],
- ),
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=("pending", "completed"),
)
+ pending = int(row[0])
+ completed = int(row[1])
# Decrement pending and increment completed
self.db_pool.simple_update_one_txn(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index b4258a4436..40b0bff164 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1837,15 +1837,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- stream_ordering, topological_ordering = cast(
- Tuple[int, int],
- self.db_pool.simple_select_one_txn(
- txn,
- "events",
- keyvalues={"event_id": event_id, "room_id": room_id},
- retcols=["stream_ordering", "topological_ordering"],
- ),
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ "events",
+ keyvalues={"event_id": event_id, "room_id": room_id},
+ retcols=("stream_ordering", "topological_ordering"),
)
+ stream_ordering = int(row[0])
+ topological_ordering = int(row[1])
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
|