diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 6daf8b8ffb..a3442814d7 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -13,17 +13,18 @@
# limitations under the License.
from collections import namedtuple
-from typing import Iterable, List, Optional
+from typing import Iterable, List, Optional, Tuple
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
-class DirectoryWorkerStore(SQLBaseStore):
+class DirectoryWorkerStore(CacheInvalidationWorkerStore):
async def get_association_from_room_alias(
self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]:
@@ -91,7 +92,7 @@ class DirectoryWorkerStore(SQLBaseStore):
creator: Optional user_id of creator.
"""
- def alias_txn(txn):
+ def alias_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn(
txn,
"room_aliases",
@@ -126,14 +127,16 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore):
- async def delete_room_alias(self, room_alias: RoomAlias) -> str:
+ async def delete_room_alias(self, room_alias: RoomAlias) -> Optional[str]:
room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
return room_id
- def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
+ def _delete_room_alias_txn(
+ self, txn: LoggingTransaction, room_alias: RoomAlias
+ ) -> Optional[str]:
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),),
@@ -173,9 +176,9 @@ class DirectoryStore(DirectoryWorkerStore):
If None, the creator will be left unchanged.
"""
- def _update_aliases_for_room_txn(txn):
+ def _update_aliases_for_room_txn(txn: LoggingTransaction) -> None:
update_creator_sql = ""
- sql_params = (new_room_id, old_room_id)
+ sql_params: Tuple[str, ...] = (new_room_id, old_room_id)
if creator:
update_creator_sql = ", creator = ?"
sql_params = (new_room_id, creator, old_room_id)
|