summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-10-28 14:35:12 -0400
committerGitHub <noreply@github.com>2021-10-28 14:35:12 -0400
commit56e281bf6c4f58929d56e3901856f6d0fa4b1816 (patch)
tree5077a686b297ee14412e717529963d855cb5e48d /synapse/storage/databases/main/relations.py
parentAdd knock information in admin exported data (#11171) (diff)
downloadsynapse-56e281bf6c4f58929d56e3901856f6d0fa4b1816.tar.xz
Additional type hints for relations database class. (#11205)
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py38
1 files changed, 23 insertions, 15 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 40760fbd1b..53576ad52f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,13 +13,14 @@
 # limitations under the License.
 
 import logging
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import attr
 
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore):
         """
 
         where_clause = ["relates_to_id = ?"]
-        where_args = [event_id]
+        where_args: List[Union[str, int]] = [event_id]
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
@@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore):
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
+            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
+            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
             engine=self.database_engine,
         )
 
@@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore):
             order,
         )
 
-        def _get_recent_references_for_event_txn(txn):
+        def _get_recent_references_for_event_txn(
+            txn: LoggingTransaction,
+        ) -> PaginationChunk:
             txn.execute(sql, where_args + [limit + 1])
 
             last_topo_id = None
@@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
         """
 
         where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args = [event_id, RelationTypes.ANNOTATION]
+        where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
 
         if event_type:
             where_clause.append("type = ?")
@@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore):
         having_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("COUNT(*)", "MAX(stream_ordering)"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
+            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
+            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
             engine=self.database_engine,
         )
 
@@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore):
             having_clause=having_clause,
         )
 
-        def _get_aggregation_groups_for_event_txn(txn):
+        def _get_aggregation_groups_for_event_txn(
+            txn: LoggingTransaction,
+        ) -> PaginationChunk:
             txn.execute(sql, where_args + [limit + 1])
 
             next_batch = None
@@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore):
             LIMIT 1
         """
 
-        def _get_applicable_edit_txn(txn):
+        def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
             txn.execute(sql, (event_id, RelationTypes.REPLACE))
             row = txn.fetchone()
             if row:
                 return row[0]
+            return None
 
         edit_id = await self.db_pool.runInteraction(
             "get_applicable_edit", _get_applicable_edit_txn
@@ -267,7 +273,7 @@ class RelationsWorkerStore(SQLBaseStore):
         if not edit_id:
             return None
 
-        return await self.get_event(edit_id, allow_none=True)
+        return await self.get_event(edit_id, allow_none=True)  # type: ignore[attr-defined]
 
     @cached()
     async def get_thread_summary(
@@ -283,7 +289,9 @@ class RelationsWorkerStore(SQLBaseStore):
             The number of items in the thread and the most recent response, if any.
         """
 
-        def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
+        def _get_thread_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, Optional[str]]:
             # Fetch the count of threaded events and the latest event ID.
             # TODO Should this only allow m.room.message events.
             sql = """
@@ -312,7 +320,7 @@ class RelationsWorkerStore(SQLBaseStore):
                     AND relation_type = ?
             """
             txn.execute(sql, (event_id, RelationTypes.THREAD))
-            count = txn.fetchone()[0]
+            count = txn.fetchone()[0]  # type: ignore[index]
 
             return count, latest_event_id
 
@@ -322,7 +330,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         latest_event = None
         if latest_event_id:
-            latest_event = await self.get_event(latest_event_id, allow_none=True)
+            latest_event = await self.get_event(latest_event_id, allow_none=True)  # type: ignore[attr-defined]
 
         return count, latest_event
 
@@ -354,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore):
             LIMIT 1;
         """
 
-        def _get_if_user_has_annotated_event(txn):
+        def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
             txn.execute(
                 sql,
                 (