summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-12-30 13:47:12 -0500
committerGitHub <noreply@github.com>2021-12-30 18:47:12 +0000
commitcbd82d0b2db069400b5d43373838817d8a0209e7 (patch)
tree5380e1b2c6ad6e112754f45d0d85274f7a8641e8 /synapse/storage/databases
parentAdd type hints to `synapse/storage/databases/main/events_bg_updates.py` (#11654) (diff)
downloadsynapse-cbd82d0b2db069400b5d43373838817d8a0209e7.tar.xz
Convert all namedtuples to attrs. (#11665)
To improve type hints throughout the code.
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/directory.py10
-rw-r--r--synapse/storage/databases/main/events.py13
-rw-r--r--synapse/storage/databases/main/room.py26
-rw-r--r--synapse/storage/databases/main/search.py16
-rw-r--r--synapse/storage/databases/main/state.py14
-rw-r--r--synapse/storage/databases/main/stream.py12
6 files changed, 54 insertions, 37 deletions
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index a3442814d7..f76c6121e8 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,16 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
 from typing import Iterable, List, Optional, Tuple
 
+import attr
+
 from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.types import RoomAlias
 from synapse.util.caches.descriptors import cached
 
-RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomAliasMapping:
+    room_id: str
+    room_alias: str
+    servers: List[str]
 
 
 class DirectoryWorkerStore(CacheInvalidationWorkerStore):
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 81e67ece55..dd255aefb9 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1976,14 +1976,17 @@ class PersistEventsStore:
                 txn, self.store.get_retention_policy_for_room, (event.room_id,)
             )
 
-    def store_event_search_txn(self, txn, event, key, value):
+    def store_event_search_txn(
+        self, txn: LoggingTransaction, event: EventBase, key: str, value: str
+    ) -> None:
         """Add event to the search table
 
         Args:
-            txn (cursor):
-            event (EventBase):
-            key (str):
-            value (str):
+            txn: The database transaction.
+            event: The event being added to the search table.
+            key: A key describing the search value (one of "content.name",
+                "content.topic", or "content.body")
+            value: The value from the event's content.
         """
         self.store.store_search_entries_txn(
             txn,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4472335af9..c0e837854a 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -13,11 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
 import logging
 from abc import abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
+
+import attr
 
 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
@@ -43,9 +54,10 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-RatelimitOverride = collections.namedtuple(
-    "RatelimitOverride", ("messages_per_second", "burst_count")
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RatelimitOverride:
+    messages_per_second: int
+    burst_count: int
 
 
 class RoomSortOrder(Enum):
@@ -207,6 +219,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                         WHERE appservice_id = ? AND network_id = ?
                     """
                     query_args.append(network_tuple.appservice_id)
+                    assert network_tuple.network_id is not None
                     query_args.append(network_tuple.network_id)
                 else:
                     published_sql = """
@@ -284,7 +297,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         """
 
         where_clauses = []
-        query_args = []
+        query_args: List[Union[str, int]] = []
 
         if network_tuple:
             if network_tuple.appservice_id:
@@ -293,6 +306,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                     WHERE appservice_id = ? AND network_id = ?
                 """
                 query_args.append(network_tuple.appservice_id)
+                assert network_tuple.network_id is not None
                 query_args.append(network_tuple.network_id)
             else:
                 published_sql = """
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index f87acfb866..2d085a5764 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,9 +14,10 @@
 
 import logging
 import re
-from collections import namedtuple
 from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
 
+import attr
+
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -33,10 +34,15 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-SearchEntry = namedtuple(
-    "SearchEntry",
-    ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
-)
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SearchEntry:
+    key: str
+    value: str
+    event_id: str
+    room_id: str
+    stream_ordering: Optional[int]
+    origin_server_ts: int
 
 
 def _clean_value_for_search(value: str) -> str:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 4bc044fb16..7e5a6aae18 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 import collections.abc
 import logging
-from collections import namedtuple
 from typing import TYPE_CHECKING, Iterable, Optional, Set
 
 from synapse.api.constants import EventTypes, Membership
@@ -43,19 +42,6 @@ logger = logging.getLogger(__name__)
 MAX_STATE_DELTA_HOPS = 100
 
 
-class _GetStateGroupDelta(
-    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
-    """Return type of get_state_group_delta that implements __len__, which lets
-    us use the itrable flag when caching
-    """
-
-    __slots__ = []
-
-    def __len__(self):
-        return len(self.delta_ids) if self.delta_ids else 0
-
-
 # this inherits from EventsWorkerStore because it calls self.get_events
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers."""
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 9488fd5094..b0642ca69f 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -36,9 +36,9 @@ what sort order was used:
 """
 import abc
 import logging
-from collections import namedtuple
 from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
 
+import attr
 from frozendict import frozendict
 
 from twisted.internet import defer
@@ -74,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological"
 
 
 # Used as return values for pagination APIs
-_EventDictReturn = namedtuple(
-    "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventDictReturn:
+    event_id: str
+    topological_ordering: Optional[int]
+    stream_ordering: int
 
 
 def generate_pagination_where_clause(
@@ -825,7 +827,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         for event, row in zip(events, rows):
             stream = row.stream_ordering
             if topo_order and row.topological_ordering:
-                topo = row.topological_ordering
+                topo: Optional[int] = row.topological_ordering
             else:
                 topo = None
             internal = event.internal_metadata