diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index bbff3c8d5b..c0d9d1240f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
+from typing import TYPE_CHECKING
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
@@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
-__all__ = ["DataStores", "DataStore"]
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+
+__all__ = ["Databases", "DataStore"]
class Storage:
"""The high level interfaces for talking to various storage layers.
"""
- def __init__(self, hs, stores: Databases):
+ def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2b196ded1b..a25c4093bc 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,14 +17,18 @@
import logging
import random
from abc import ABCMeta
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
-from synapse.types import Collection, get_domain_from_id
+from synapse.storage.types import Connection
+from synapse.types import Collection, StreamToken, get_domain_from_id
from synapse.util import json_decoder
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
self.rand = random.SystemRandom()
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: StreamToken,
+ rows: Iterable[Any],
+ ) -> None:
pass
- def _invalidate_state_caches(self, room_id, members_changed):
+ def _invalidate_state_caches(
+ self, room_id: str, members_changed: Iterable[str]
+ ) -> None:
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
Args:
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have
- changed
+ room_id: Room where state changed
+ members_changed: The user_ids of members that have changed
"""
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
@@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
- ):
+ ) -> None:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
cache.invalidate(tuple(key))
-def db_to_json(db_content):
+def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
Take some data from a database row and return a JSON-decoded object.
Args:
- db_content (memoryview|buffer|bytes|bytearray|unicode)
+ db_content: The JSON-encoded contents from the database.
+
+ Returns:
+ The object decoded from JSON.
"""
# psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode
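
As an aside on the widened `db_to_json` signature: it now names every concrete type the database drivers hand back. A minimal sketch of what that means in practice (the JSON payload is made up, and this assumes `synapse` is importable in the environment):

```python
# Illustration only: db_to_json accepts str, bytes, bytearray and memoryview
# and returns the decoded Python object in every case.
from synapse.storage._base import db_to_json

assert db_to_json('{"a": 1}') == {"a": 1}
assert db_to_json(b'{"a": 1}') == {"a": 1}
assert db_to_json(bytearray(b'{"a": 1}')) == {"a": 1}
assert db_to_json(memoryview(b'{"a": 1}')) == {"a": 1}
```
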
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 810721ebe9..29b8ca676a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,29 +12,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
from synapse.util import json_encoder
from . import engines
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.storage.database import DatabasePool, LoggingTransaction
+
logger = logging.getLogger(__name__)
class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items"""
- def __init__(self, name):
+ def __init__(self, name: str):
self.name = name
self.total_item_count = 0
- self.total_duration_ms = 0
- self.avg_item_count = 0
- self.avg_duration_ms = 0
+ self.total_duration_ms = 0.0
+ self.avg_item_count = 0.0
+ self.avg_duration_ms = 0.0
- def update(self, item_count, duration_ms):
+ def update(self, item_count: int, duration_ms: float) -> None:
"""Update the stats after doing an update"""
self.total_item_count += item_count
self.total_duration_ms += duration_ms
@@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
- def average_items_per_ms(self):
+ def average_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
@@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
# changes in how long the update process takes.
return float(self.avg_item_count) / float(self.avg_duration_ms)
- def total_items_per_ms(self):
+ def total_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
@@ -83,21 +88,25 @@ class BackgroundUpdater:
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
- def __init__(self, hs, database):
+ def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
# if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str]
- self._background_update_performance = {}
- self._background_update_handlers = {}
+ self._background_update_performance = (
+ {}
+ ) # type: Dict[str, BackgroundUpdatePerformance]
+ self._background_update_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
self._all_done = False
- def start_doing_background_updates(self):
+ def start_doing_background_updates(self) -> None:
run_as_background_process("background_updates", self.run_background_updates)
- async def run_background_updates(self, sleep=True):
+ async def run_background_updates(self, sleep: bool = True) -> None:
logger.info("Starting background schema updates")
while True:
if sleep:
@@ -148,7 +157,7 @@ class BackgroundUpdater:
return False
- async def has_completed_background_update(self, update_name) -> bool:
+ async def has_completed_background_update(self, update_name: str) -> bool:
"""Check if the given background update has finished running.
"""
if self._all_done:
@@ -173,8 +182,7 @@ class BackgroundUpdater:
Returns once some amount of work is done.
Args:
- desired_duration_ms(float): How long we want to spend
- updating.
+ desired_duration_ms: How long we want to spend updating.
Returns:
True if we have finished running all the background updates, otherwise False
"""
@@ -220,6 +228,7 @@ class BackgroundUpdater:
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
+ assert self._current_background_update is not None
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
@@ -273,7 +282,11 @@ class BackgroundUpdater:
return len(self._background_update_performance)
- def register_background_update_handler(self, update_name, update_handler):
+ def register_background_update_handler(
+ self,
+ update_name: str,
+ update_handler: Callable[[JsonDict, int], Awaitable[int]],
+ ):
"""Register a handler for doing a background update.
The handler should take two arguments:
@@ -287,12 +300,12 @@ class BackgroundUpdater:
The handler is responsible for updating the progress of the update.
Args:
- update_name(str): The name of the update that this code handles.
- update_handler(function): The function that does the update.
+ update_name: The name of the update that this code handles.
+ update_handler: The function that does the update.
"""
self._background_update_handlers[update_name] = update_handler
- def register_noop_background_update(self, update_name):
+ def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
This is useful when we previously did a background update, but no
@@ -302,10 +315,10 @@ class BackgroundUpdater:
also be called to clear the update.
Args:
- update_name (str): Name of update
+ update_name: Name of update
"""
- async def noop_update(progress, batch_size):
+ async def noop_update(progress: JsonDict, batch_size: int) -> int:
await self._end_background_update(update_name)
return 1
@@ -313,14 +326,14 @@ class BackgroundUpdater:
def register_background_index_update(
self,
- update_name,
- index_name,
- table,
- columns,
- where_clause=None,
- unique=False,
- psql_only=False,
- ):
+ update_name: str,
+ index_name: str,
+ table: str,
+ columns: Iterable[str],
+ where_clause: Optional[str] = None,
+ unique: bool = False,
+ psql_only: bool = False,
+ ) -> None:
"""Helper for store classes to do a background index addition
To use:
@@ -332,19 +345,19 @@ class BackgroundUpdater:
2. In the Store constructor, call this method
Args:
- update_name (str): update_name to register for
- index_name (str): name of index to add
- table (str): table to add index to
- columns (list[str]): columns/expressions to include in index
- unique (bool): true to make a UNIQUE index
+ update_name: update_name to register for
+ index_name: name of index to add
+ table: table to add index to
+ columns: columns/expressions to include in index
+ unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
"""
- def create_index_psql(conn):
+ def create_index_psql(conn: Connection) -> None:
conn.rollback()
# postgres insists on autocommit for the index
- conn.set_session(autocommit=True)
+ conn.set_session(autocommit=True) # type: ignore
try:
c = conn.cursor()
@@ -371,9 +384,9 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
finally:
- conn.set_session(autocommit=False)
+ conn.set_session(autocommit=False) # type: ignore
- def create_index_sqlite(conn):
+ def create_index_sqlite(conn: Connection) -> None:
# Sqlite doesn't support concurrent creation of indexes.
#
# We don't use partial indices on SQLite as it wasn't introduced
@@ -399,7 +412,7 @@ class BackgroundUpdater:
c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine):
- runner = create_index_psql
+ runner = create_index_psql # type: Optional[Callable[[Connection], None]]
elif psql_only:
runner = None
else:
@@ -433,7 +446,9 @@ class BackgroundUpdater:
"background_updates", keyvalues={"update_name": update_name}
)
- async def _background_update_progress(self, update_name: str, progress: dict):
+ async def _background_update_progress(
+ self, update_name: str, progress: dict
+ ) -> None:
"""Update the progress of a background update
Args:
@@ -441,20 +456,22 @@ class BackgroundUpdater:
progress: The progress of the update.
"""
- return await self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
progress,
)
- def _background_update_progress_txn(self, txn, update_name, progress):
+ def _background_update_progress_txn(
+ self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
+ ) -> None:
"""Update the progress of a background update
Args:
- txn(cursor): The transaction.
- update_name(str): The name of the background update task
- progress(dict): The progress of the update.
+ txn: The transaction.
+ update_name: The name of the background update task
+ progress: The progress of the update.
"""
progress_json = json_encoder.encode(progress)
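
The `Callable[[JsonDict, int], Awaitable[int]]` handler type registered above is easiest to see with a concrete registration. A rough sketch, not taken from this change: the store, update name and table are hypothetical, and it follows the noop handler's pattern of returning the number of processed items and calling `_end_background_update` once nothing is left.

```python
# Hypothetical store registering a typed background-update handler.
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.types import JsonDict


class ExampleBackfillStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)
        # db_pool.updates is the BackgroundUpdater owned by the pool.
        self.db_pool.updates.register_background_update_handler(
            "example_backfill", self._example_backfill
        )

    async def _example_backfill(self, progress: JsonDict, batch_size: int) -> int:
        last_id = progress.get("last_id", 0)

        def _txn(txn):
            txn.execute(
                "SELECT id FROM example_table WHERE id > ? ORDER BY id LIMIT ?",
                (last_id, batch_size),
            )
            return [row[0] for row in txn]

        ids = await self.db_pool.runInteraction("example_backfill", _txn)
        if not ids:
            # Nothing left to do: mark the update as finished.
            await self.db_pool.updates._end_background_update("example_backfill")
            return 0

        # Record how far we got so the next batch can resume from here.
        await self.db_pool.updates._background_update_progress(
            "example_backfill", {"last_id": ids[-1]}
        )
        return len(ids)
```
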
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 43660ec4fb..701748f93b 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -149,9 +149,6 @@ class DataStore(
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._pushers_id_gen = StreamIdGenerator(
- db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
- )
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
@@ -342,12 +339,13 @@ class DataStore(
filters = []
args = [self.hs.config.server_name]
+            # `name` is already stored in lower case in the database
if name:
- filters.append("(name LIKE ? OR displayname LIKE ?)")
- args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
+ args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
elif user_id:
filters.append("name LIKE ?")
- args.extend(["%" + user_id + "%"])
+ args.extend(["%" + user_id.lower() + "%"])
if not guests:
filters.append("is_guest = 0")
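
The intent of the lower-casing, sketched with made-up values: SQL `LIKE` is case-sensitive on postgres, so both sides of the comparison must be normalised. `name` (the user ID) is already stored lower-cased, so only the search term needs lowering; `displayname` is not, so it additionally gets `LOWER()` in the SQL.

```python
# Illustrative only: how the pattern and arguments line up for a
# case-insensitive admin user search.
search = "Alice"
filters = ["(name LIKE ? OR LOWER(displayname) LIKE ?)"]
args = ["@%" + search.lower() + "%:%", "%" + search.lower() + "%"]
assert args == ["@%alice%:%", "%alice%"]
```
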
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 2408432738..c5468c7b0d 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,11 +14,12 @@
# limitations under the License.
import logging
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.types import UserID
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -546,7 +547,9 @@ class ClientIpStore(ClientIpWorkerStore):
}
return ret
- async def get_user_ip_and_agents(self, user):
+ async def get_user_ip_and_agents(
+ self, user: UserID
+ ) -> List[Dict[str, Union[str, int]]]:
user_id = user.to_string()
results = {}
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index f8f4bb9b3f..04ac2d0ced 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
from synapse.storage.keys import FetchKeyResult
+from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore):
)
async def get_server_verify_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
- ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
+ ) -> Dict[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
@@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore):
"""
keys = {}
- def _get_keys(txn, batch):
+        def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
# batch_iter always returns tuples so it's safe to do len(batch)
@@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore):
# `ts_valid_until_ms`.
ts_valid_until_ms = 0
- res = FetchKeyResult(
+ keys[(server_name, key_id)] = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)
- keys[(server_name, key_id)] = res
- def _txn(txn):
+ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
for batch in batch_iter(server_name_and_key_ids, 50):
_get_keys(txn, batch)
return keys
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7997242d90..77ba9d819e 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -15,18 +15,32 @@
# limitations under the License.
import logging
-from typing import Iterable, Iterator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
from canonicaljson import encode_canonical_json
+from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class PusherWorkerStore(SQLBaseStore):
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+ self._pushers_id_gen = StreamIdGenerator(
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+ )
+
+ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
@@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore):
)
continue
- yield r
+ yield PusherConfig(**r)
- async def user_has_pusher(self, user_id):
+ async def user_has_pusher(self, user_id: str) -> bool:
ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
- def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
- return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
+ async def get_pushers_by_app_id_and_pushkey(
+ self, app_id: str, pushkey: str
+ ) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
- def get_pushers_by_user_id(self, user_id):
- return self.get_pushers_by({"user_name": user_id})
+ async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
+ return await self.get_pushers_by({"user_name": user_id})
- async def get_pushers_by(self, keyvalues):
+ async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
@@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- async def get_all_pushers(self):
+ async def get_all_pushers(self) -> Iterator[PusherConfig]:
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
@@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore):
)
@cached(num_args=1, max_entries=15000)
- async def get_if_user_has_pusher(self, user_id):
+ async def get_if_user_has_pusher(self, user_id: str):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
- async def get_if_users_have_pushers(self, user_ids):
+ async def get_if_users_have_pushers(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, bool]:
rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
@@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
async def update_pusher_failing_since(
- self, app_id, pushkey, user_id, failing_since
+ self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
) -> None:
await self.db_pool.simple_update(
table="pushers",
@@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore):
desc="update_pusher_failing_since",
)
- async def get_throttle_params_by_room(self, pusher_id):
+ async def get_throttle_params_by_room(
+ self, pusher_id: str
+ ) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
@@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore):
params_by_room = {}
for row in res:
- params_by_room[row["room_id"]] = {
- "last_sent_ts": row["last_sent_ts"],
- "throttle_ms": row["throttle_ms"],
- }
+ params_by_room[row["room_id"]] = ThrottleParams(
+ row["last_sent_ts"], row["throttle_ms"],
+ )
return params_by_room
- async def set_throttle_params(self, pusher_id, room_id, params) -> None:
+ async def set_throttle_params(
+ self, pusher_id: str, room_id: str, params: ThrottleParams
+ ) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
- params,
+ {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
desc="set_throttle_params",
lock=False,
)
class PusherStore(PusherWorkerStore):
- def get_pushers_stream_token(self):
+ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
async def add_pusher(
self,
- user_id,
- access_token,
- kind,
- app_id,
- app_display_name,
- device_display_name,
- pushkey,
- pushkey_ts,
- lang,
- data,
- last_stream_ordering,
- profile_tag="",
+ user_id: str,
+ access_token: Optional[int],
+ kind: str,
+ app_id: str,
+ app_display_name: str,
+ device_display_name: str,
+ pushkey: str,
+ pushkey_ts: int,
+ lang: Optional[str],
+ data: Optional[JsonDict],
+ last_stream_ordering: int,
+ profile_tag: str = "",
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
@@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore):
        # invalidate, since the user might not have had a pusher before
await self.db_pool.runInteraction(
"add_pusher",
- self._invalidate_cache_and_stream,
+ self._invalidate_cache_and_stream, # type: ignore
self.get_if_user_has_pusher,
(user_id,),
)
async def delete_pusher_by_app_id_pushkey_user_id(
- self, app_id, pushkey, user_id
+ self, app_id: str, pushkey: str, user_id: str
) -> None:
def delete_pusher_txn(txn, stream_id):
- self._invalidate_cache_and_stream(
+ self._invalidate_cache_and_stream( # type: ignore
txn, self.get_if_user_has_pusher, (user_id,)
)
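
Since `set_throttle_params` now takes a `ThrottleParams` instead of a raw dict, callers build the attrs object explicitly. A hedged sketch of such a caller; the doubling backoff and the 60-second floor are invented for illustration, only the store methods and the two fields come from the change above.

```python
# Illustrative caller of the now-typed throttle helpers; the backoff policy
# here is made up.
from synapse.push import ThrottleParams


async def bump_room_throttle(store, pusher_id: str, room_id: str) -> None:
    params_by_room = await store.get_throttle_params_by_room(pusher_id)
    current = params_by_room.get(room_id, ThrottleParams(0, 0))
    await store.set_throttle_params(
        pusher_id,
        room_id,
        ThrottleParams(
            current.last_sent_ts,
            max(current.throttle_ms * 2, 60_000),  # hypothetical backoff
        ),
    )
```
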
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index ff96c34c2e..8d05288ed4 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -943,6 +943,42 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="del_user_pending_deactivation",
)
+ async def get_access_token_last_validated(self, token_id: int) -> int:
+ """Retrieves the time (in milliseconds) of the last validation of an access token.
+
+ Args:
+            token_id: The ID of the access token to check.
+ Raises:
+ StoreError if the access token was not found.
+
+ Returns:
+ The last validation time.
+ """
+ result = await self.db_pool.simple_select_one_onecol(
+ "access_tokens", {"id": token_id}, "last_validated"
+ )
+
+ # If this token has not been validated (since starting to track this),
+ # return 0 instead of None.
+ return result or 0
+
+ async def update_access_token_last_validated(self, token_id: int) -> None:
+ """Updates the last time an access token was validated.
+
+ Args:
+ token_id: The ID of the access token to update.
+ Raises:
+ StoreError if there was a problem updating this.
+ """
+ now = self._clock.time_msec()
+
+ await self.db_pool.simple_update_one(
+ "access_tokens",
+ {"id": token_id},
+ {"last_validated": now},
+ desc="update_access_token_last_validated",
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
@@ -1150,6 +1186,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
The token ID
"""
next_id = self._access_tokens_id_gen.get_next()
+ now = self._clock.time_msec()
await self.db_pool.simple_insert(
"access_tokens",
@@ -1160,6 +1197,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"valid_until_ms": valid_until_ms,
"puppets_user_id": puppets_user_id,
+ "last_validated": now,
},
desc="add_access_token_to_user",
)
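
A plausible, purely illustrative consumer of the two new accessors; the 24-hour cutoff is an assumption for the sketch, not something this change defines.

```python
# Illustrative only: decide whether a token should be re-validated based on
# the new last_validated column. The cutoff is a made-up policy.
ONE_DAY_MS = 24 * 60 * 60 * 1000


async def maybe_revalidate_token(store, clock, token_id: int) -> None:
    last_validated = await store.get_access_token_last_validated(token_id)
    if clock.time_msec() - last_validated > ONE_DAY_MS:
        # ... re-check the token however is appropriate, then record it:
        await store.update_access_token_last_validated(token_id)
```
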
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6b89db15c9..4650d0689b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -379,14 +379,14 @@ class RoomWorkerStore(SQLBaseStore):
# Filter room names by a string
where_statement = ""
if search_term:
- where_statement = "WHERE state.name LIKE ?"
+ where_statement = "WHERE LOWER(state.name) LIKE ?"
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_term = "%" + search_term + "%"
+ search_term = "%" + search_term.lower() + "%"
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
diff --git a/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
new file mode 100644
index 0000000000..1a101cd5eb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The last time this access token was "validated" (i.e. logged in or succeeded
+-- at user-interactive authentication).
+ALTER TABLE access_tokens ADD COLUMN last_validated BIGINT;
diff --git a/synapse/storage/databases/main/schema/delta/58/27local_invites.sql b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
new file mode 100644
index 0000000000..44b2a0572f
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is unused since Synapse v1.17.0.
+DROP TABLE local_invites;
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index d87ceec6da..ef11f1c3b3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -17,7 +17,7 @@ import logging
import re
from typing import Any, Dict, Iterable, Optional, Set, Tuple
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
@@ -360,7 +360,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if hist_vis_id:
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
- if hist_vis_ev.content.get("history_visibility") == "world_readable":
+ if (
+ hist_vis_ev.content.get("history_visibility")
+ == HistoryVisibility.WORLD_READABLE
+ ):
return True
return False
@@ -393,9 +396,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
- setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.execute(
@@ -415,9 +418,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
- setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
)
"""
txn.execute(
@@ -432,9 +435,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
elif new_entry is False:
sql = """
UPDATE user_directory_search
- SET vector = setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ SET vector = setweight(to_tsvector('simple', ?), 'A')
+ || setweight(to_tsvector('simple', ?), 'D')
+ || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
WHERE user_id = ?
"""
txn.execute(
@@ -761,7 +764,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
INNER JOIN user_directory AS d USING (user_id)
WHERE
%s
- AND vector @@ to_tsquery('english', ?)
+ AND vector @@ to_tsquery('simple', ?)
ORDER BY
(CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
* (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
@@ -770,13 +773,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
3 * ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
- to_tsquery('english', ?),
+ to_tsquery('simple', ?),
8
)
+ ts_rank_cd(
'{0.1, 0.1, 0.9, 1.0}',
vector,
- to_tsquery('english', ?),
+ to_tsquery('simple', ?),
8
)
)
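
The switch from the `english` to the `simple` text search configuration matters because `english` stems words and drops stop words, which is the wrong behaviour for arbitrary display names and user IDs, whereas `simple` only lower-cases and splits on whitespace. A quick way to see the difference against a postgres instance (the DSN and sample name are placeholders):

```python
# Compare the two dictionaries directly; the output in the comments is the
# typical result, exact formatting may vary between postgres versions.
import psycopg2

with psycopg2.connect("dbname=synapse_test") as conn, conn.cursor() as cur:
    cur.execute(
        "SELECT to_tsvector('english', %s), to_tsvector('simple', %s)",
        ("The Running Fox", "The Running Fox"),
    )
    english_vec, simple_vec = cur.fetchone()
    print(english_vec)  # 'fox':3 'run':2              (stop word dropped, stemmed)
    print(simple_vec)   # 'fox':3 'running':2 'the':1  (just lower-cased tokens)
```
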
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index afd10f7bae..c03871f393 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -17,11 +17,12 @@
import logging
import attr
+from signedjson.types import VerifyKey
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True)
class FetchKeyResult:
- verify_key = attr.ib() # VerifyKey: the key itself
- valid_until_ts = attr.ib() # int: how long we can use this key for
+ verify_key = attr.ib(type=VerifyKey) # the key itself
+ valid_until_ts = attr.ib(type=int) # how long we can use this key for
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 70e636b0ba..61fc49c69c 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,14 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ Collection,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -68,6 +75,21 @@ stale_forward_extremities_counter = Histogram(
buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
)
+state_resolutions_during_persistence = Counter(
+ "synapse_storage_events_state_resolutions_during_persistence",
+ "Number of times we had to do state res to calculate new current state",
+)
+
+potential_times_prune_extremities = Counter(
+ "synapse_storage_events_potential_times_prune_extremities",
+ "Number of times we might be able to prune extremities",
+)
+
+times_pruned_extremities = Counter(
+ "synapse_storage_events_times_pruned_extremities",
+    "Number of times we were actually able to prune extremities",
+)
+
class _EventPeristenceQueue:
"""Queues up events so that they can be persisted in bulk with only one
@@ -454,7 +476,15 @@ class EventsPersistenceStorage:
latest_event_ids,
new_latest_event_ids,
)
- current_state, delta_ids = res
+ current_state, delta_ids, new_latest_event_ids = res
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremeties[room_id] = new_latest_event_ids
# If either are not None then there has been a change,
# and we need to work out the delta (or use that
@@ -573,29 +603,35 @@ class EventsPersistenceStorage:
self,
room_id: str,
events_context: List[Tuple[EventBase, EventContext]],
- old_latest_event_ids: Iterable[str],
- new_latest_event_ids: Iterable[str],
- ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
+ old_latest_event_ids: Set[str],
+ new_latest_event_ids: Set[str],
+ ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
"""Calculate the current state dict after adding some new events to
a room
Args:
- room_id (str):
+ room_id:
room to which the events are being added. Used for logging etc
- events_context (list[(EventBase, EventContext)]):
+ events_context:
events and contexts which are being added to the room
- old_latest_event_ids (iterable[str]):
+ old_latest_event_ids:
the old forward extremities for the room.
- new_latest_event_ids (iterable[str]):
+            new_latest_event_ids:
the new forward extremities for the room.
Returns:
- Returns a tuple of two state maps, the first being the full new current
- state and the second being the delta to the existing current state.
- If both are None then there has been no change.
+ Returns a tuple of two state maps and a set of new forward
+ extremities.
+
+ The first state map is the full new current state and the second
+ is the delta to the existing current state. If both are None then
+ there has been no change.
+
+ The function may prune some old entries from the set of new
+ forward extremities if it's safe to do so.
If there has been a change then we only return the delta if its
already been calculated. Conversely if we do know the delta then
@@ -672,7 +708,7 @@ class EventsPersistenceStorage:
        # If the old and new groups are the same then we don't need to do
# anything.
if old_state_groups == new_state_groups:
- return None, None
+ return None, None, new_latest_event_ids
if len(new_state_groups) == 1 and len(old_state_groups) == 1:
# If we're going from one state group to another, lets check if
@@ -689,7 +725,7 @@ class EventsPersistenceStorage:
# the current state in memory then lets also return that,
# but it doesn't matter if we don't.
new_state = state_groups_map.get(new_state_group)
- return new_state, delta_ids
+ return new_state, delta_ids, new_latest_event_ids
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
@@ -701,7 +737,7 @@ class EventsPersistenceStorage:
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- return state_groups_map[new_state_groups.pop()], None
+ return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids
# Ok, we need to defer to the state handler to resolve our state sets.
@@ -734,7 +770,139 @@ class EventsPersistenceStorage:
state_res_store=StateResolutionStore(self.main_store),
)
- return res.state, None
+ state_resolutions_during_persistence.inc()
+
+ # If the returned state matches the state group of one of the new
+ # forward extremities then we check if we are able to prune some state
+ # extremities.
+ if res.state_group and res.state_group in new_state_groups:
+ new_latest_event_ids = await self._prune_extremities(
+ room_id,
+ new_latest_event_ids,
+ res.state_group,
+ event_id_to_state_group,
+ events_context,
+ )
+
+ return res.state, None, new_latest_event_ids
+
+ async def _prune_extremities(
+ self,
+ room_id: str,
+ new_latest_event_ids: Set[str],
+ resolved_state_group: int,
+ event_id_to_state_group: Dict[str, int],
+ events_context: List[Tuple[EventBase, EventContext]],
+ ) -> Set[str]:
+ """See if we can prune any of the extremities after calculating the
+ resolved state.
+ """
+ potential_times_prune_extremities.inc()
+
+ # We keep all the extremities that have the same state group, and
+ # see if we can drop the others.
+ new_new_extrems = {
+ e
+ for e in new_latest_event_ids
+ if event_id_to_state_group[e] == resolved_state_group
+ }
+
+ dropped_extrems = set(new_latest_event_ids) - new_new_extrems
+
+ logger.debug("Might drop extremities: %s", dropped_extrems)
+
+ # We only drop events from the extremities list if:
+ # 1. we're not currently persisting them;
+ # 2. they're not our own events (or are dummy events); and
+ # 3. they're either:
+ # 1. over N hours old and more than N events ago (we use depth to
+ # calculate); or
+ # 2. we are persisting an event from the same domain and more than
+ # M events ago.
+ #
+ # The idea is that we don't want to drop events that are "legitimate"
+ # extremities (that we would want to include as prev events), only
+ # "stuck" extremities that are e.g. due to a gap in the graph.
+ #
+ # Note that we either drop all of them or none of them. If we only drop
+ # some of the events we don't know if state res would come to the same
+ # conclusion.
+
+ for ev, _ in events_context:
+ if ev.event_id in dropped_extrems:
+ logger.debug(
+ "Not dropping extremities: %s is being persisted", ev.event_id
+ )
+ return new_latest_event_ids
+
+ dropped_events = await self.main_store.get_events(
+ dropped_extrems,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+
+ new_senders = {get_domain_from_id(e.sender) for e, _ in events_context}
+
+ one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ current_depth = max(e.depth for e, _ in events_context)
+ for event in dropped_events.values():
+ # If the event is a local dummy event then we should check it
+ # doesn't reference any local events, as we want to reference those
+ # if we send any new events.
+ #
+ # Note we do this recursively to handle the case where a dummy event
+ # references a dummy event that only references remote events.
+ #
+ # Ideally we'd figure out a way of still being able to drop old
+ # dummy events that reference local events, but this is good enough
+ # as a first cut.
+ events_to_check = [event]
+ while events_to_check:
+ new_events = set()
+ for event_to_check in events_to_check:
+ if self.is_mine_id(event_to_check.sender):
+ if event_to_check.type != EventTypes.Dummy:
+ logger.debug("Not dropping own event")
+ return new_latest_event_ids
+ new_events.update(event_to_check.prev_event_ids())
+
+ prev_events = await self.main_store.get_events(
+ new_events,
+ allow_rejected=True,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ events_to_check = prev_events.values()
+
+ if (
+ event.origin_server_ts < one_day_ago
+ and event.depth < current_depth - 100
+ ):
+ continue
+
+ # We can be less conservative about dropping extremities from the
+ # same domain, though we do want to wait a little bit (otherwise
+ # we'll immediately remove all extremities from a given server).
+ if (
+ get_domain_from_id(event.sender) in new_senders
+ and event.depth < current_depth - 20
+ ):
+ continue
+
+ logger.debug(
+ "Not dropping as too new and not in new_senders: %s", new_senders,
+ )
+
+ return new_latest_event_ids
+
+ times_pruned_extremities.inc()
+
+ logger.info(
+ "Pruning forward extremities in room %s: from %s -> %s",
+ room_id,
+ new_latest_event_ids,
+ new_new_extrems,
+ )
+ return new_new_extrems
async def _calculate_state_delta(
self, room_id: str, current_state: StateMap[str]
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 459754feab..f91a2eae7a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,9 +18,10 @@ import logging
import os
import re
from collections import Counter
-from typing import Optional, TextIO
+from typing import Generator, Iterable, List, Optional, TextIO, Tuple
import attr
+from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
@@ -70,7 +71,7 @@ def prepare_database(
db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
- databases: Collection[str] = ["main", "state"],
+ databases: Collection[str] = ("main", "state"),
):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -155,7 +156,9 @@ def prepare_database(
raise
-def _setup_new_database(cur, database_engine, databases):
+def _setup_new_database(
+ cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
+) -> None:
"""Sets up the physical database by finding a base set of "full schemas" and
then applying any necessary deltas, including schemas from the given data
stores.
@@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases):
folder as well those in the data stores specified.
Args:
- cur (Cursor): a database cursor
- database_engine (DatabaseEngine)
- databases (list[str]): The names of the databases to instantiate
- on the given physical database.
+ cur: a database cursor
+ database_engine
+ databases: The names of the databases to instantiate on the given physical database.
"""
# We're about to set up a brand new database so we check that its
@@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases):
database_engine.check_new_database(cur)
current_dir = os.path.join(dir_path, "schema", "full_schemas")
- directory_entries = os.listdir(current_dir)
# First we find the highest full schema version we have
valid_versions = []
- for filename in directory_entries:
+ for filename in os.listdir(current_dir):
try:
ver = int(filename)
except ValueError:
@@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases):
for database in databases
)
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
directory_entries.extend(
_DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases):
def _upgrade_existing_database(
- cur,
- current_version,
- applied_delta_files,
- upgraded,
- database_engine,
- config,
- databases,
- is_empty=False,
-):
+ cur: Cursor,
+ current_version: int,
+ applied_delta_files: List[str],
+ upgraded: bool,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str],
+ is_empty: bool = False,
+) -> None:
"""Upgrades an existing physical database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -323,21 +324,20 @@ def _upgrade_existing_database(
for a version before applying those in the next version.
Args:
- cur (Cursor)
- current_version (int): The current version of the schema.
- applied_delta_files (list): A list of deltas that have already been
- applied.
- upgraded (bool): Whether the current version was generated by having
+ cur
+ current_version: The current version of the schema.
+ applied_delta_files: A list of deltas that have already been applied.
+ upgraded: Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
- database_engine (DatabaseEngine)
- config (synapse.config.homeserver.HomeServerConfig|None):
+ database_engine
+ config:
None if we are initialising a blank database, otherwise the application
config
- databases (list[str]): The names of the databases to instantiate
+ databases: The names of the databases to instantiate
on the given physical database.
- is_empty (bool): Is this a blank database? I.e. do we need to run the
+ is_empty: Is this a blank database? I.e. do we need to run the
upgrade portions of the delta scripts.
"""
if is_empty:
@@ -358,6 +358,7 @@ def _upgrade_existing_database(
if not is_empty and "main" in databases:
from synapse.storage.databases.main import check_database_before_upgrade
+ assert config is not None
check_database_before_upgrade(cur, database_engine, config)
start_ver = current_version
@@ -388,10 +389,10 @@ def _upgrade_existing_database(
)
# Used to check if we have any duplicate file names
- file_name_counter = Counter()
+ file_name_counter = Counter() # type: CounterType[str]
# Now find which directories have anything of interest.
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
logger.debug("Looking for schema deltas in %s", directory)
try:
@@ -445,11 +446,11 @@ def _upgrade_existing_database(
module_name = "synapse.storage.v%d_%s" % (v, root_name)
with open(absolute_path) as python_file:
- module = imp.load_source(module_name, absolute_path, python_file)
+ module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
logger.info("Running script %s", relative_path)
- module.run_create(cur, database_engine)
+ module.run_create(cur, database_engine) # type: ignore
if not is_empty:
- module.run_upgrade(cur, database_engine, config=config)
+ module.run_upgrade(cur, database_engine, config=config) # type: ignore
elif ext == ".pyc" or file_name == "__pycache__":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
@@ -497,14 +498,15 @@ def _upgrade_existing_database(
logger.info("Schema now up to date")
-def _apply_module_schemas(txn, database_engine, config):
+def _apply_module_schemas(
+ txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
"""Apply the module schemas for the dynamic modules, if any
Args:
cur: database cursor
- database_engine: synapse database engine class
- config (synapse.config.homeserver.HomeServerConfig):
- application config
+ database_engine:
+ config: application config
"""
for (mod, _config) in config.password_providers:
if not hasattr(mod, "get_db_schema_files"):
@@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
)
-def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+def _apply_module_schema_files(
+ cur: Cursor,
+ database_engine: BaseDatabaseEngine,
+ modname: str,
+ names_and_streams: Iterable[Tuple[str, TextIO]],
+) -> None:
"""Apply the module schemas for a single module
Args:
cur: database cursor
database_engine: synapse database engine class
- modname (str): fully qualified name of the module
- names_and_streams (Iterable[(str, file)]): the names and streams of
- schemas to be applied
+ modname: fully qualified name of the module
+ names_and_streams: the names and streams of schemas to be applied
"""
cur.execute(
"SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
@@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
)
-def get_statements(f):
+def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
@@ -594,17 +600,19 @@ def get_statements(f):
statement_buffer = statements[-1].strip()
-def executescript(txn, schema_path):
+def executescript(txn: Cursor, schema_path: str) -> None:
with open(schema_path, "r") as f:
execute_statements_from_stream(txn, f)
-def execute_statements_from_stream(cur: Cursor, f: TextIO):
+def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
for statement in get_statements(f):
cur.execute(statement)
-def _get_or_create_schema_state(txn, database_engine):
+def _get_or_create_schema_state(
+ txn: Cursor, database_engine: BaseDatabaseEngine
+) -> Optional[Tuple[int, List[str], bool]]:
# Bluntly try creating the schema_version tables.
schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
executescript(txn, schema_path)
@@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
current_version = int(row[0]) if row else None
- upgraded = bool(row[1]) if row else None
if current_version:
txn.execute(
@@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
(current_version,),
)
applied_deltas = [d for d, in txn]
+ upgraded = bool(row[1])
return current_version, applied_deltas, upgraded
return None
@@ -634,5 +642,5 @@ class _DirectoryListing:
`file_name` attr is kept first.
"""
- file_name = attr.ib()
- absolute_path = attr.ib()
+ file_name = attr.ib(type=str)
+ absolute_path = attr.ib(type=str)
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index bfa0a9fd06..6c359c1aae 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,7 +15,12 @@
import itertools
import logging
-from typing import Set
+from typing import TYPE_CHECKING, Set
+
+from synapse.storage.databases import Databases
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -24,10 +29,10 @@ class PurgeEventsStorage:
"""High level interface for purging rooms and event history.
"""
- def __init__(self, hs, stores):
+ def __init__(self, hs: "HomeServer", stores: Databases):
self.stores = stores
- async def purge_room(self, room_id: str):
+ async def purge_room(self, room_id: str) -> None:
"""Deletes all record of a room
"""
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index cec96ad6a7..2564f34b47 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+from typing import Any, Dict, List, Optional, Tuple
import attr
from synapse.api.errors import SynapseError
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -27,18 +29,18 @@ class PaginationChunk:
"""Returned by relation pagination APIs.
Attributes:
- chunk (list): The rows returned by pagination
- next_batch (Any|None): Token to fetch next set of results with, if
+ chunk: The rows returned by pagination
+ next_batch: Token to fetch next set of results with, if
None then there are no more results.
- prev_batch (Any|None): Token to fetch previous set of results with, if
+ prev_batch: Token to fetch previous set of results with, if
None then there are no previous results.
"""
- chunk = attr.ib()
- next_batch = attr.ib(default=None)
- prev_batch = attr.ib(default=None)
+ chunk = attr.ib(type=List[JsonDict])
+ next_batch = attr.ib(type=Optional[Any], default=None)
+ prev_batch = attr.ib(type=Optional[Any], default=None)
- def to_dict(self):
+ def to_dict(self) -> Dict[str, Any]:
d = {"chunk": self.chunk}
if self.next_batch:
@@ -59,25 +61,25 @@ class RelationPaginationToken:
boundaries of the chunk as pagination tokens.
Attributes:
- topological (int): The topological ordering of the boundary event
- stream (int): The stream ordering of the boundary event.
+ topological: The topological ordering of the boundary event
+ stream: The stream ordering of the boundary event.
"""
- topological = attr.ib()
- stream = attr.ib()
+ topological = attr.ib(type=int)
+ stream = attr.ib(type=int)
@staticmethod
- def from_string(string):
+ def from_string(string: str) -> "RelationPaginationToken":
try:
t, s = string.split("-")
return RelationPaginationToken(int(t), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
- def to_string(self):
+ def to_string(self) -> str:
return "%d-%d" % (self.topological, self.stream)
- def as_tuple(self):
+ def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self)
@@ -89,23 +91,23 @@ class AggregationPaginationToken:
aggregation groups, we can just use them as our pagination token.
Attributes:
- count (int): The count of relations in the boundar group.
- stream (int): The MAX stream ordering in the boundary group.
+ count: The count of relations in the boundary group.
+ stream: The MAX stream ordering in the boundary group.
"""
- count = attr.ib()
- stream = attr.ib()
+ count = attr.ib(type=int)
+ stream = attr.ib(type=int)
@staticmethod
- def from_string(string):
+ def from_string(string: str) -> "AggregationPaginationToken":
try:
c, s = string.split("-")
return AggregationPaginationToken(int(c), int(s))
except ValueError:
raise SynapseError(400, "Invalid token")
- def to_string(self):
+ def to_string(self) -> str:
return "%d-%d" % (self.count, self.stream)
- def as_tuple(self):
+ def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self)
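
A round-trip check of the now-typed pagination token, following directly from the code above (the values are arbitrary):

```python
# Round-trip a relation pagination boundary token.
from synapse.storage.relations import RelationPaginationToken

token = RelationPaginationToken.from_string("42-17")
assert token.topological == 42 and token.stream == 17
assert token.to_string() == "42-17"
assert token.as_tuple() == (42, 17)
```
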
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f96..31ccbf23dc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
import attr
@@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.storage.databases import Databases
+
logger = logging.getLogger(__name__)
# Used for generic functions below
@@ -330,10 +343,12 @@ class StateGroupStorage:
"""High level interface to fetching state for event.
"""
- def __init__(self, hs, stores):
+ def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
- async def get_state_group_delta(self, state_group: int):
+ async def get_state_group_delta(
+ self, state_group: int
+ ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -341,8 +356,8 @@ class StateGroupStorage:
state_group: The state group used to retrieve state deltas.
Returns:
- Tuple[Optional[int], Optional[StateMap[str]]]:
- (prev_group, delta_ids)
+ A tuple of the previous group and a state map of the event IDs which
+ make up the delta between the old and new state groups.
"""
return await self.stores.state.get_state_group_delta(state_group)
@@ -436,7 +451,7 @@ class StateGroupStorage:
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@@ -472,7 +487,7 @@ class StateGroupStorage:
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
@@ -500,7 +515,7 @@ class StateGroupStorage:
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
@@ -516,7 +531,7 @@ class StateGroupStorage:
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 02d71302ea..133c0e7a28 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -153,12 +153,12 @@ class StreamIdGenerator:
return _AsyncCtxManagerWrapper(manager())
- def get_current_token(self):
+ def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
- int
+ The maximum stream id.
"""
with self._lock:
if self._unfinished_ids:
|