From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/storage/databases/main/ui_auth.py | 300 ++++++++++++++++++++++++++++++ 1 file changed, 300 insertions(+) create mode 100644 synapse/storage/databases/main/ui_auth.py (limited to 'synapse/storage/databases/main/ui_auth.py') diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py new file mode 100644 index 0000000000..37276f73f8 --- /dev/null +++ b/synapse/storage/databases/main/ui_auth.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Union + +import attr +from canonicaljson import json + +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict +from synapse.util import stringutils as stringutils + + +@attr.s +class UIAuthSessionData: + session_id = attr.ib(type=str) + # The dictionary from the client root level, not the 'auth' key. + clientdict = attr.ib(type=JsonDict) + # The URI and method the session was intiatied with. These are checked at + # each stage of the authentication to ensure that the asked for operation + # has not changed. + uri = attr.ib(type=str) + method = attr.ib(type=str) + # A string description of the operation that the current authentication is + # authorising. + description = attr.ib(type=str) + + +class UIAuthWorkerStore(SQLBaseStore): + """ + Manage user interactive authentication sessions. + """ + + async def create_ui_auth_session( + self, clientdict: JsonDict, uri: str, method: str, description: str, + ) -> UIAuthSessionData: + """ + Creates a new user interactive authentication session. + + The session can be used to track the stages necessary to authenticate a + user across multiple HTTP requests. + + Args: + clientdict: + The dictionary from the client root level, not the 'auth' key. + uri: + The URI this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + method: + The method this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + description: + A string description of the operation that the current + authentication is authorising. + Returns: + The newly created session. + Raises: + StoreError if a unique session ID cannot be generated. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + # autogen a session ID and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + while attempts < 5: + session_id = stringutils.random_string(24) + + try: + await self.db_pool.simple_insert( + table="ui_auth_sessions", + values={ + "session_id": session_id, + "clientdict": clientdict_json, + "uri": uri, + "method": method, + "description": description, + "serverdict": "{}", + "creation_time": self.hs.get_clock().time_msec(), + }, + desc="create_ui_auth_session", + ) + return UIAuthSessionData( + session_id, clientdict, uri, method, description + ) + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a session ID.") + + async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: + """Retrieve a UI auth session. + + Args: + session_id: The ID of the session. + Returns: + A dict containing the device information. + Raises: + StoreError if the session is not found. + """ + result = await self.db_pool.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("clientdict", "uri", "method", "description"), + desc="get_ui_auth_session", + ) + + result["clientdict"] = db_to_json(result["clientdict"]) + + return UIAuthSessionData(session_id, **result) + + async def mark_ui_auth_stage_complete( + self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], + ): + """ + Mark a session stage as completed. + + Args: + session_id: The ID of the corresponding session. + stage_type: The completed stage type. + result: The result of the stage verification. + Raises: + StoreError if the session cannot be found. + """ + # Add (or update) the results of the current stage to the database. + # + # Note that we need to allow for the same stage to complete multiple + # times here so that registration is idempotent. + try: + await self.db_pool.simple_upsert( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id, "stage_type": stage_type}, + values={"result": json.dumps(result)}, + desc="mark_ui_auth_stage_complete", + ) + except self.db_pool.engine.module.IntegrityError: + raise StoreError(400, "Unknown session ID: %s" % (session_id,)) + + async def get_completed_ui_auth_stages( + self, session_id: str + ) -> Dict[str, Union[str, bool, JsonDict]]: + """ + Retrieve the completed stages of a UI authentication session. + + Args: + session_id: The ID of the session. + Returns: + The completed stages mapped to the result of the verification of + that auth-type. + """ + results = {} + for row in await self.db_pool.simple_select_list( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id}, + retcols=("stage_type", "result"), + desc="get_completed_ui_auth_stages", + ): + results[row["stage_type"]] = db_to_json(row["result"]) + + return results + + async def set_ui_auth_clientdict( + self, session_id: str, clientdict: JsonDict + ) -> None: + """ + Store an updated clientdict for a given session ID. + + Args: + session_id: The ID of this session as returned from check_auth + clientdict: + The dictionary from the client root level, not the 'auth' key. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + await self.db_pool.simple_update_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"clientdict": clientdict_json}, + desc="set_ui_auth_client_dict", + ) + + async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): + """ + Store a key-value pair into the sessions data associated with this + request. This data is stored server-side and cannot be modified by + the client. + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store + Raises: + StoreError if the session cannot be found. + """ + await self.db_pool.runInteraction( + "set_ui_auth_session_data", + self._set_ui_auth_session_data_txn, + session_id, + key, + value, + ) + + def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + # Get the current value. + result = self.db_pool.simple_select_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + ) + + # Update it and add it back to the database. + serverdict = db_to_json(result["serverdict"]) + serverdict[key] = value + + self.db_pool.simple_update_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"serverdict": json.dumps(serverdict)}, + ) + + async def get_ui_auth_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: + """ + Retrieve data stored with set_session_data + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set + Raises: + StoreError if the session cannot be found. + """ + result = await self.db_pool.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + desc="get_ui_auth_session_data", + ) + + serverdict = db_to_json(result["serverdict"]) + + return serverdict.get(key, default) + + +class UIAuthStore(UIAuthWorkerStore): + def delete_old_ui_auth_sessions(self, expiration_time: int): + """ + Remove sessions which were last used earlier than the expiration time. + + Args: + expiration_time: The latest time that is still considered valid. + This is an epoch time in milliseconds. + + """ + return self.db_pool.runInteraction( + "delete_old_ui_auth_sessions", + self._delete_old_ui_auth_sessions_txn, + expiration_time, + ) + + def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + # Get the expired sessions. + sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" + txn.execute(sql, [expiration_time]) + session_ids = [r[0] for r in txn.fetchall()] + + # Delete the corresponding completed credentials. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + + # Finally, delete the sessions. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) -- cgit 1.5.1 From 5eac0b7e76c8316604480faf2d6158a0e1d68466 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 09:00:59 -0400 Subject: Add more types to synapse.storage.database. (#8127) --- changelog.d/8127.misc | 1 + synapse/storage/database.py | 577 ++++++++++++++++++------------ synapse/storage/databases/main/ui_auth.py | 11 +- 3 files changed, 367 insertions(+), 222 deletions(-) create mode 100644 changelog.d/8127.misc (limited to 'synapse/storage/databases/main/ui_auth.py') diff --git a/changelog.d/8127.misc b/changelog.d/8127.misc new file mode 100644 index 0000000000..cb557122aa --- /dev/null +++ b/changelog.d/8127.misc @@ -0,0 +1 @@ +Add type hints to `synapse.storage.database`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b9aef96b08..bc327e344e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -28,6 +28,7 @@ from typing import ( Optional, Tuple, TypeVar, + Union, ) from prometheus_client import Histogram @@ -125,7 +126,7 @@ class LoggingTransaction: method. Args: - txn: The database transcation object to wrap. + txn: The database transaction object to wrap. name: The name of this transactions for logging. database_engine after_callbacks: A list that callbacks will be appended to @@ -160,7 +161,7 @@ class LoggingTransaction: self.after_callbacks = after_callbacks self.exception_callbacks = exception_callbacks - def call_after(self, callback: "Callable[..., None]", *args, **kwargs): + def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. @@ -171,7 +172,9 @@ class LoggingTransaction: assert self.after_callbacks is not None self.after_callbacks.append((callback, args, kwargs)) - def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs): + def call_on_exception( + self, callback: "Callable[..., None]", *args: Any, **kwargs: Any + ): # if self.exception_callbacks is None, that means that whatever constructed the # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. @@ -195,7 +198,7 @@ class LoggingTransaction: def description(self) -> Any: return self.txn.description - def execute_batch(self, sql, args): + def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: if isinstance(self.database_engine, PostgresEngine): from psycopg2.extras import execute_batch # type: ignore @@ -204,17 +207,17 @@ class LoggingTransaction: for val in args: self.execute(sql, val) - def execute(self, sql: str, *args: Any): + def execute(self, sql: str, *args: Any) -> None: self._do_execute(self.txn.execute, sql, *args) - def executemany(self, sql: str, *args: Any): + def executemany(self, sql: str, *args: Any) -> None: self._do_execute(self.txn.executemany, sql, *args) def _make_sql_one_line(self, sql: str) -> str: "Strip newlines out of SQL so that the loggers in the DB are on one line" return " ".join(line.strip() for line in sql.splitlines() if line.strip()) - def _do_execute(self, func, sql, *args): + def _do_execute(self, func, sql: str, *args: Any) -> None: sql = self._make_sql_one_line(sql) # TODO(paul): Maybe use 'info' and 'debug' for values? @@ -240,7 +243,7 @@ class LoggingTransaction: sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) sql_query_timer.labels(sql.split()[0]).observe(secs) - def close(self): + def close(self) -> None: self.txn.close() @@ -249,13 +252,13 @@ class PerformanceCounters(object): self.current_counters = {} self.previous_counters = {} - def update(self, key, duration_secs): + def update(self, key: str, duration_secs: float) -> None: count, cum_time = self.current_counters.get(key, (0, 0)) count += 1 cum_time += duration_secs self.current_counters[key] = (count, cum_time) - def interval(self, interval_duration_secs, limit=3): + def interval(self, interval_duration_secs: float, limit: int = 3) -> str: counters = [] for name, (count, cum_time) in self.current_counters.items(): prev_count, prev_time = self.previous_counters.get(name, (0, 0)) @@ -279,6 +282,9 @@ class PerformanceCounters(object): return top_n_counters +R = TypeVar("R") + + class DatabasePool(object): """Wraps a single physical database and connection pool. @@ -327,12 +333,12 @@ class DatabasePool(object): self._check_safe_to_upsert, ) - def is_running(self): + def is_running(self) -> bool: """Is the database pool currently running """ return self._db_pool.running - async def _check_safe_to_upsert(self): + async def _check_safe_to_upsert(self) -> None: """ Is it safe to use native UPSERT? @@ -363,7 +369,7 @@ class DatabasePool(object): self._check_safe_to_upsert, ) - def start_profiling(self): + def start_profiling(self) -> None: self._previous_loop_ts = monotonic_time() def loop(): @@ -387,8 +393,15 @@ class DatabasePool(object): self._clock.looping_call(loop, 10000) def new_transaction( - self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs - ): + self, + conn: Connection, + desc: str, + after_callbacks: List[_CallbackListEntry], + exception_callbacks: List[_CallbackListEntry], + func: "Callable[..., R]", + *args: Any, + **kwargs: Any + ) -> R: start = monotonic_time() txn_id = self._TXN_ID @@ -537,7 +550,9 @@ class DatabasePool(object): return result - async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any: + async def runWithConnection( + self, func: "Callable[..., R]", *args: Any, **kwargs: Any + ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. Arguments: @@ -576,11 +591,11 @@ class DatabasePool(object): ) @staticmethod - def cursor_to_dict(cursor): + def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]: """Converts a SQL cursor into an list of dicts. Args: - cursor : The DBAPI cursor which has executed a query. + cursor: The DBAPI cursor which has executed a query. Returns: A list of dicts where the key is the column header. """ @@ -588,7 +603,7 @@ class DatabasePool(object): results = [dict(zip(col_headers, row)) for row in cursor] return results - def execute(self, desc, decoder, query, *args): + def execute(self, desc: str, decoder: Callable, query: str, *args: Any): """Runs a single query for a result set. Args: @@ -597,7 +612,7 @@ class DatabasePool(object): query - The query string to execute *args - Query args. Returns: - The result of decoder(results) + Deferred which results to the result of decoder(results) """ def interaction(txn): @@ -612,20 +627,25 @@ class DatabasePool(object): # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. - async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + async def simple_insert( + self, + table: str, + values: Dict[str, Any], + or_ignore: bool = False, + desc: str = "simple_insert", + ) -> bool: """Executes an INSERT query on the named table. Args: - table : string giving the table name - values : dict of new column names and values for them - or_ignore : bool stating whether an exception should be raised + table: string giving the table name + values: dict of new column names and values for them + or_ignore: bool stating whether an exception should be raised when a conflicting row already exists. If True, False will be returned by the function instead - desc : string giving a description of the transaction + desc: string giving a description of the transaction Returns: - bool: Whether the row was inserted or not. Only useful when - `or_ignore` is True + Whether the row was inserted or not. Only useful when `or_ignore` is True """ try: await self.runInteraction(desc, self.simple_insert_txn, table, values) @@ -638,7 +658,9 @@ class DatabasePool(object): return True @staticmethod - def simple_insert_txn(txn, table, values): + def simple_insert_txn( + txn: LoggingTransaction, table: str, values: Dict[str, Any] + ) -> None: keys, vals = zip(*values.items()) sql = "INSERT INTO %s (%s) VALUES(%s)" % ( @@ -649,11 +671,15 @@ class DatabasePool(object): txn.execute(sql, vals) - def simple_insert_many(self, table, values, desc): + def simple_insert_many( + self, table: str, values: List[Dict[str, Any]], desc: str + ) -> defer.Deferred: return self.runInteraction(desc, self.simple_insert_many_txn, table, values) @staticmethod - def simple_insert_many_txn(txn, table, values): + def simple_insert_many_txn( + txn: LoggingTransaction, table: str, values: List[Dict[str, Any]] + ) -> None: if not values: return @@ -683,13 +709,13 @@ class DatabasePool(object): async def simple_upsert( self, - table, - keyvalues, - values, - insertion_values={}, - desc="simple_upsert", - lock=True, - ): + table: str, + keyvalues: Dict[str, Any], + values: Dict[str, Any], + insertion_values: Dict[str, Any] = {}, + desc: str = "simple_upsert", + lock: bool = True, + ) -> Optional[bool]: """ `lock` should generally be set to True (the default), but can be set @@ -703,16 +729,14 @@ class DatabasePool(object): this table. Args: - table (str): The table to upsert into - keyvalues (dict): The unique key columns and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. + table: The table to upsert into + keyvalues: The unique key columns and their new values + values: The nonunique columns and their new values + insertion_values: additional key/values to use only when inserting + lock: True to lock the table when doing the upsert. Returns: - None or bool: Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. + Native upserts always return None. Emulated upserts return True if a + new entry was created, False if an existing one was updated. """ attempts = 0 while True: @@ -739,29 +763,34 @@ class DatabasePool(object): ) def simple_upsert_txn( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): + self, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + values: Dict[str, Any], + insertion_values: Dict[str, Any] = {}, + lock: bool = True, + ) -> Optional[bool]: """ Pick the UPSERT method which works best on the platform. Either the native one (Pg9.5+, recent SQLites), or fall back to an emulated method. Args: txn: The transaction to use. - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. + table: The table to upsert into + keyvalues: The unique key tables and their new values + values: The nonunique columns and their new values + insertion_values: additional key/values to use only when inserting + lock: True to lock the table when doing the upsert. Returns: - None or bool: Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. + Native upserts always return None. Emulated upserts return True if a + new entry was created, False if an existing one was updated. """ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: - return self.simple_upsert_txn_native_upsert( + self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) + return None else: return self.simple_upsert_txn_emulated( txn, @@ -773,18 +802,23 @@ class DatabasePool(object): ) def simple_upsert_txn_emulated( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): + self, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + values: Dict[str, Any], + insertion_values: Dict[str, Any] = {}, + lock: bool = True, + ) -> bool: """ Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. + table: The table to upsert into + keyvalues: The unique key tables and their new values + values: The nonunique columns and their new values + insertion_values: additional key/values to use only when inserting + lock: True to lock the table when doing the upsert. Returns: - bool: Return True if a new entry was created, False if an existing + Returns True if a new entry was created, False if an existing one was updated. """ # We need to lock the table :(, unless we're *really* careful @@ -842,19 +876,21 @@ class DatabasePool(object): return True def simple_upsert_txn_native_upsert( - self, txn, table, keyvalues, values, insertion_values={} - ): + self, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + values: Dict[str, Any], + insertion_values: Dict[str, Any] = {}, + ) -> None: """ Use the native UPSERT functionality in recent PostgreSQL versions. Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - Returns: - None + table: The table to upsert into + keyvalues: The unique key tables and their new values + values: The nonunique columns and their new values + insertion_values: additional key/values to use only when inserting """ allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) @@ -985,18 +1021,22 @@ class DatabasePool(object): return txn.execute_batch(sql, args) def simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" - ): + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: bool = False, + desc: str = "simple_select_one", + ) -> defer.Deferred: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcols : list of strings giving the names of the columns to return - - allow_none : If true, return None instead of failing if the SELECT - statement returns no rows + table: string giving the table name + keyvalues: dict of column names and values to select the row with + retcols: list of strings giving the names of the columns to return + allow_none: If true, return None instead of failing if the SELECT + statement returns no rows """ return self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none @@ -1004,19 +1044,22 @@ class DatabasePool(object): def simple_select_one_onecol( self, - table, - keyvalues, - retcol, - allow_none=False, - desc="simple_select_one_onecol", - ): + table: str, + keyvalues: Dict[str, Any], + retcol: Iterable[str], + allow_none: bool = False, + desc: str = "simple_select_one_onecol", + ) -> defer.Deferred: """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcol : string giving the name of the column to return + table: string giving the table name + keyvalues: dict of column names and values to select the row with + retcol: string giving the name of the column to return + allow_none: If true, return None instead of failing if the SELECT + statement returns no rows + desc: description of the transaction, for logging and metrics """ return self.runInteraction( desc, @@ -1029,8 +1072,13 @@ class DatabasePool(object): @classmethod def simple_select_one_onecol_txn( - cls, txn, table, keyvalues, retcol, allow_none=False - ): + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcol: Iterable[str], + allow_none: bool = False, + ) -> Optional[Any]: ret = cls.simple_select_onecol_txn( txn, table=table, keyvalues=keyvalues, retcol=retcol ) @@ -1044,7 +1092,12 @@ class DatabasePool(object): raise StoreError(404, "No row found") @staticmethod - def simple_select_onecol_txn(txn, table, keyvalues, retcol): + def simple_select_onecol_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcol: Iterable[str], + ) -> List[Any]: sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} if keyvalues: @@ -1056,15 +1109,19 @@ class DatabasePool(object): return [r[0] for r in txn] def simple_select_onecol( - self, table, keyvalues, retcol, desc="simple_select_onecol" - ): + self, + table: str, + keyvalues: Optional[Dict[str, Any]], + retcol: str, + desc: str = "simple_select_onecol", + ) -> defer.Deferred: """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. Args: - table (str): table name - keyvalues (dict|None): column names and values to select the rows with - retcol (str): column whos value we wish to retrieve. + table: table name + keyvalues: column names and values to select the rows with + retcol: column whos value we wish to retrieve. Returns: Deferred: Results in a list @@ -1073,16 +1130,22 @@ class DatabasePool(object): desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): + def simple_select_list( + self, + table: str, + keyvalues: Optional[Dict[str, Any]], + retcols: Iterable[str], + desc: str = "simple_select_list", + ) -> defer.Deferred: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Args: - table (str): the table name - keyvalues (dict[str, Any] | None): + table: the table name + keyvalues: column names and values to select the rows with, or None to not apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return + retcols: the names of the columns to return Returns: defer.Deferred: resolves to list[dict[str, Any]] """ @@ -1091,17 +1154,23 @@ class DatabasePool(object): ) @classmethod - def simple_select_list_txn(cls, txn, table, keyvalues, retcols): + def simple_select_list_txn( + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Optional[Dict[str, Any]], + retcols: Iterable[str], + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): + txn: Transaction object + table: the table name + keyvalues: column names and values to select the rows with, or None to not apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return + retcols: the names of the columns to return """ if keyvalues: sql = "SELECT %s FROM %s WHERE %s" % ( @@ -1118,25 +1187,25 @@ class DatabasePool(object): async def simple_select_many_batch( self, - table, - column, - iterable, - retcols, - keyvalues={}, - desc="simple_select_many_batch", - batch_size=100, - ): + table: str, + column: str, + iterable: Iterable[Any], + retcols: Iterable[str], + keyvalues: Dict[str, Any] = {}, + desc: str = "simple_select_many_batch", + batch_size: int = 100, + ) -> List[Any]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Filters rows by if value of `column` is in `iterable`. Args: - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + keyvalues: dict of column names and values to select the rows with + retcols: list of strings giving the names of the columns to return """ results = [] # type: List[Dict[str, Any]] @@ -1165,19 +1234,27 @@ class DatabasePool(object): return results @classmethod - def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + def simple_select_many_txn( + cls, + txn: LoggingTransaction, + table: str, + column: str, + iterable: Iterable[Any], + keyvalues: Dict[str, Any], + retcols: Iterable[str], + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Filters rows by if value of `column` is in `iterable`. Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return + txn: Transaction object + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + keyvalues: dict of column names and values to select the rows with + retcols: list of strings giving the names of the columns to return """ if not iterable: return [] @@ -1198,13 +1275,24 @@ class DatabasePool(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) - def simple_update(self, table, keyvalues, updatevalues, desc): + def simple_update( + self, + table: str, + keyvalues: Dict[str, Any], + updatevalues: Dict[str, Any], + desc: str, + ) -> defer.Deferred: return self.runInteraction( desc, self.simple_update_txn, table, keyvalues, updatevalues ) @staticmethod - def simple_update_txn(txn, table, keyvalues, updatevalues): + def simple_update_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + updatevalues: Dict[str, Any], + ) -> int: if keyvalues: where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) else: @@ -1221,31 +1309,32 @@ class DatabasePool(object): return txn.rowcount def simple_update_one( - self, table, keyvalues, updatevalues, desc="simple_update_one" - ): + self, + table: str, + keyvalues: Dict[str, Any], + updatevalues: Dict[str, Any], + desc: str = "simple_update_one", + ) -> defer.Deferred: """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - updatevalues : dict giving column names and values to update - retcols : optional list of column names to return - - If present, retcols gives a list of column names on which to perform - a SELECT statement *before* performing the UPDATE statement. The values - of these will be returned in a dict. - - These are performed within the same transaction, allowing an atomic - get-and-set. This can be used to implement compare-and-set by putting - the update column in the 'keyvalues' dict as well. + table: string giving the table name + keyvalues: dict of column names and values to select the row with + updatevalues: dict giving column names and values to update """ return self.runInteraction( desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) @classmethod - def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + def simple_update_one_txn( + cls, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + updatevalues: Dict[str, Any], + ) -> None: rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) if rowcount == 0: @@ -1253,8 +1342,18 @@ class DatabasePool(object): if rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) + # Ideally we could use the overload decorator here to specify that the + # return type is only optional if allow_none is True, but this does not work + # when you call a static method from an instance. + # See https://github.com/python/mypy/issues/7781 @staticmethod - def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + def simple_select_one_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: bool = False, + ) -> Optional[Dict[str, Any]]: select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, @@ -1273,24 +1372,28 @@ class DatabasePool(object): return dict(zip(retcols, row)) - def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): + def simple_delete_one( + self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" + ) -> defer.Deferred: """Executes a DELETE query on the named table, expecting to delete a single row. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with + table: string giving the table name + keyvalues: dict of column names and values to select the row with """ return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @staticmethod - def simple_delete_one_txn(txn, table, keyvalues): + def simple_delete_one_txn( + txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any] + ) -> None: """Executes a DELETE query on the named table, expecting to delete a single row. Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with + table: string giving the table name + keyvalues: dict of column names and values to select the row with """ sql = "DELETE FROM %s WHERE %s" % ( table, @@ -1303,11 +1406,13 @@ class DatabasePool(object): if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - def simple_delete(self, table, keyvalues, desc): + def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str): return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) @staticmethod - def simple_delete_txn(txn, table, keyvalues): + def simple_delete_txn( + txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any] + ) -> int: sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k,) for k in keyvalues), @@ -1316,26 +1421,39 @@ class DatabasePool(object): txn.execute(sql, list(keyvalues.values())) return txn.rowcount - def simple_delete_many(self, table, column, iterable, keyvalues, desc): + def simple_delete_many( + self, + table: str, + column: str, + iterable: Iterable[Any], + keyvalues: Dict[str, Any], + desc: str, + ) -> defer.Deferred: return self.runInteraction( desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) @staticmethod - def simple_delete_many_txn(txn, table, column, iterable, keyvalues): + def simple_delete_many_txn( + txn: LoggingTransaction, + table: str, + column: str, + iterable: Iterable[Any], + keyvalues: Dict[str, Any], + ) -> int: """Executes a DELETE query on the named table. Filters rows by if value of `column` is in `iterable`. Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with + txn: Transaction object + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + keyvalues: dict of column names and values to select the rows with Returns: - int: Number rows deleted + Number rows deleted """ if not iterable: return 0 @@ -1356,8 +1474,14 @@ class DatabasePool(object): return txn.rowcount def get_cache_dict( - self, db_conn, table, entity_column, stream_column, max_value, limit=100000 - ): + self, + db_conn: Connection, + table: str, + entity_column: str, + stream_column: str, + max_value: int, + limit: int = 100000, + ) -> Tuple[Dict[Any, int], int]: # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will # do the right thing to ensure it respects the max size of cache. @@ -1390,34 +1514,34 @@ class DatabasePool(object): def simple_select_list_paginate( self, - table, - orderby, - start, - limit, - retcols, - filters=None, - keyvalues=None, - order_direction="ASC", - desc="simple_select_list_paginate", - ): + table: str, + orderby: str, + start: int, + limit: int, + retcols: Iterable[str], + filters: Optional[Dict[str, Any]] = None, + keyvalues: Optional[Dict[str, Any]] = None, + order_direction: str = "ASC", + desc: str = "simple_select_list_paginate", + ) -> defer.Deferred: """ Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, returning the result as a list of dicts. Args: - table (str): the table name - filters (dict[str, T] | None): + table: the table name + orderby: Column to order the results by. + start: Index to begin the query at. + limit: Number of results to return. + retcols: the names of the columns to return + filters: column names and values to filter the rows with, or None to not apply a WHERE ? LIKE ? clause. - keyvalues (dict[str, T] | None): + keyvalues: column names and values to select the rows with, or None to not apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". + order_direction: Whether the results should be ordered "ASC" or "DESC". Returns: defer.Deferred: resolves to list[dict[str, Any]] """ @@ -1437,16 +1561,16 @@ class DatabasePool(object): @classmethod def simple_select_list_paginate_txn( cls, - txn, - table, - orderby, - start, - limit, - retcols, - filters=None, - keyvalues=None, - order_direction="ASC", - ): + txn: LoggingTransaction, + table: str, + orderby: str, + start: int, + limit: int, + retcols: Iterable[str], + filters: Optional[Dict[str, Any]] = None, + keyvalues: Optional[Dict[str, Any]] = None, + order_direction: str = "ASC", + ) -> List[Dict[str, Any]]: """ Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, @@ -1457,21 +1581,22 @@ class DatabasePool(object): using 'AND'. Args: - txn : Transaction object - table (str): the table name - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - filters (dict[str, T] | None): + txn: Transaction object + table: the table name + orderby: Column to order the results by. + start: Index to begin the query at. + limit: Number of results to return. + retcols: the names of the columns to return + filters: column names and values to filter the rows with, or None to not apply a WHERE ? LIKE ? clause. - keyvalues (dict[str, T] | None): + keyvalues: column names and values to select the rows with, or None to not apply a WHERE clause. - order_direction (str): Whether the results should be ordered "ASC" or "DESC". + order_direction: Whether the results should be ordered "ASC" or "DESC". + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + The result as a list of dictionaries. """ if order_direction not in ["ASC", "DESC"]: raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") @@ -1497,16 +1622,23 @@ class DatabasePool(object): return cls.cursor_to_dict(txn) - def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): + def simple_search_list( + self, + table: str, + term: Optional[str], + col: str, + retcols: Iterable[str], + desc="simple_search_list", + ): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Args: - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return + table: the table name + term: term for searching the table matched to a column. + col: column to query term should be matched to + retcols: the names of the columns to return + Returns: defer.Deferred: resolves to list[dict[str, Any]] or None """ @@ -1516,19 +1648,26 @@ class DatabasePool(object): ) @classmethod - def simple_search_list_txn(cls, txn, table, term, col, retcols): + def simple_search_list_txn( + cls, + txn: LoggingTransaction, + table: str, + term: Optional[str], + col: str, + retcols: Iterable[str], + ) -> Union[List[Dict[str, Any]], int]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. Args: - txn : Transaction object - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return + txn: Transaction object + table: the table name + term: term for searching the table matched to a column. + col: column to query term should be matched to + retcols: the names of the columns to return + Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None + 0 if no term is given, otherwise a list of dictionaries. """ if term: sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) @@ -1541,7 +1680,7 @@ class DatabasePool(object): def make_in_list_sql_clause( - database_engine, column: str, iterable: Iterable + database_engine: BaseDatabaseEngine, column: str, iterable: Iterable ) -> Tuple[str, list]: """Returns an SQL clause that checks the given column is in the iterable. diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 37276f73f8..d80d7da895 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -19,6 +19,7 @@ from canonicaljson import json from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict from synapse.util import stringutils as stringutils @@ -214,14 +215,16 @@ class UIAuthWorkerStore(SQLBaseStore): value, ) - def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + def _set_ui_auth_session_data_txn( + self, txn: LoggingTransaction, session_id: str, key: str, value: Any + ): # Get the current value. result = self.db_pool.simple_select_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), - ) + ) # type: Dict[str, Any] # type: ignore # Update it and add it back to the database. serverdict = db_to_json(result["serverdict"]) @@ -275,7 +278,9 @@ class UIAuthStore(UIAuthWorkerStore): expiration_time, ) - def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + def _delete_old_ui_auth_sessions_txn( + self, txn: LoggingTransaction, expiration_time: int + ): # Get the expired sessions. sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" txn.execute(sql, [expiration_time]) -- cgit 1.5.1 From dbc630a628e4fc6eb5eff09ce5edba062c0e9955 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 10:32:33 -0400 Subject: Use the JSON encoder without whitespace in more places. (#8124) --- changelog.d/8124.misc | 1 + synapse/handlers/devicemessage.py | 5 ++--- synapse/logging/opentracing.py | 5 ++--- synapse/rest/well_known.py | 4 ++-- synapse/storage/background_updates.py | 5 ++--- synapse/storage/databases/main/appservice.py | 5 ++--- synapse/storage/databases/main/room.py | 5 ++--- synapse/storage/databases/main/tags.py | 7 +++---- synapse/storage/databases/main/ui_auth.py | 11 +++++------ 9 files changed, 21 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8124.misc (limited to 'synapse/storage/databases/main/ui_auth.py') diff --git a/changelog.d/8124.misc b/changelog.d/8124.misc new file mode 100644 index 0000000000..9fac710205 --- /dev/null +++ b/changelog.d/8124.misc @@ -0,0 +1 @@ +Reduce the amount of whitespace in JSON stored and sent in responses. diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 610b08d00b..dcb4c82244 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -16,8 +16,6 @@ import logging from typing import Any, Dict -from canonicaljson import json - from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -27,6 +25,7 @@ from synapse.logging.opentracing import ( start_active_span, ) from synapse.types import UserID, get_domain_from_id +from synapse.util import json_encoder from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) @@ -174,7 +173,7 @@ class DeviceMessageHandler(object): "sender": sender_user_id, "type": message_type, "message_id": message_id, - "org.matrix.opentracing_context": json.dumps(context), + "org.matrix.opentracing_context": json_encoder.encode(context), } log_kv({"local_messages": local_messages}) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index abe532d350..d39ac62168 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -172,12 +172,11 @@ from functools import wraps from typing import TYPE_CHECKING, Dict, Optional, Type import attr -from canonicaljson import json from twisted.internet import defer from synapse.config import ConfigError -from synapse.util import json_decoder +from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: from synapse.http.site import SynapseRequest @@ -693,7 +692,7 @@ def active_span_context_as_string(): opentracing.tracer.inject( opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier ) - return json.dumps(carrier) + return json_encoder.encode(carrier) @only_if_tracing diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 20177b44e7..e15e13b756 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging from twisted.web.resource import Resource from synapse.http.server import set_cors_headers +from synapse.util import json_encoder logger = logging.getLogger(__name__) @@ -67,4 +67,4 @@ class WellKnownResource(Resource): logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") - return json.dumps(r).encode("utf-8") + return json_encoder.encode(r).encode("utf-8") diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 90a1f9e8b1..56818f4df8 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -16,9 +16,8 @@ import logging from typing import Optional -from canonicaljson import json - from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import json_encoder from . import engines @@ -457,7 +456,7 @@ class BackgroundUpdater(object): progress(dict): The progress of the update. """ - progress_json = json.dumps(progress) + progress_json = json_encoder.encode(progress) self.db_pool.simple_update_one_txn( txn, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 02568a2391..77723f7d4d 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -16,13 +16,12 @@ import logging import re -from canonicaljson import json - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util import json_encoder logger = logging.getLogger(__name__) @@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore( new_txn_id = max(highest_txn_id, last_txn_id) + 1 # Insert new txn into txn table - event_ids = json.dumps([e.event_id for e in events]) + event_ids = json_encoder.encode([e.event_id for e in events]) txn.execute( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "VALUES(?,?,?)", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index aef08c7e12..7d3ac47261 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -21,8 +21,6 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from canonicaljson import json - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions @@ -30,6 +28,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore from synapse.types import ThirdPartyInstanceID +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -1310,7 +1309,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "event_id": event_id, "user_id": user_id, "reason": reason, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, desc="add_event_report", ) diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index e4e0a0c433..ade7abc927 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -17,11 +17,10 @@ import logging from typing import Dict, List, Tuple -from canonicaljson import json - from synapse.storage._base import db_to_json from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.types import JsonDict +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (user_id, room_id)) tags = [] for tag, content in txn: - tags.append(json.dumps(tag) + ":" + content) + tags.append(json_encoder.encode(tag) + ":" + content) tag_json = "{" + ",".join(tags) + "}" results.append((stream_id, (user_id, room_id, tag_json))) @@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) def add_tag_txn(txn, next_id): self.db_pool.simple_upsert_txn( diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index d80d7da895..6281a41a3d 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -15,13 +15,12 @@ from typing import Any, Dict, Optional, Union import attr -from canonicaljson import json from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict -from synapse.util import stringutils as stringutils +from synapse.util import json_encoder, stringutils @attr.s @@ -73,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore): StoreError if a unique session ID cannot be generated. """ # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) # autogen a session ID and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -144,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore): await self.db_pool.simple_upsert( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id, "stage_type": stage_type}, - values={"result": json.dumps(result)}, + values={"result": json_encoder.encode(result)}, desc="mark_ui_auth_stage_complete", ) except self.db_pool.engine.module.IntegrityError: @@ -185,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore): The dictionary from the client root level, not the 'auth' key. """ # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) await self.db_pool.simple_update_one( table="ui_auth_sessions", @@ -234,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore): txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, - updatevalues={"serverdict": json.dumps(serverdict)}, + updatevalues={"serverdict": json_encoder.encode(serverdict)}, ) async def get_ui_auth_session_data( -- cgit 1.5.1 From 3f91638da6ea0aeaf789ddc8ca1e624a11b7ebb2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 15:42:58 -0400 Subject: Allow denying or shadow banning registrations via the spam checker (#8034) --- changelog.d/8034.feature | 1 + synapse/events/spamcheck.py | 35 ++++++++++++++- synapse/handlers/auth.py | 8 ++++ synapse/handlers/cas_handler.py | 11 ++++- synapse/handlers/oidc_handler.py | 21 +++++++-- synapse/handlers/register.py | 26 ++++++++++- synapse/handlers/saml_handler.py | 18 +++++++- synapse/rest/client/v2_alpha/register.py | 5 +++ synapse/spam_checker_api/__init__.py | 11 +++++ .../main/schema/delta/58/07persist_ui_auth_ips.sql | 25 +++++++++++ synapse/storage/databases/main/ui_auth.py | 39 +++++++++++++++- tests/handlers/test_oidc.py | 18 ++++++-- tests/handlers/test_register.py | 52 +++++++++++++++++++++- tests/handlers/test_user_directory.py | 6 +-- 14 files changed, 258 insertions(+), 18 deletions(-) create mode 100644 changelog.d/8034.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql (limited to 'synapse/storage/databases/main/ui_auth.py') diff --git a/changelog.d/8034.feature b/changelog.d/8034.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8034.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 1ffc9525d1..a7cddac974 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,9 +15,10 @@ # limitations under the License. import inspect -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple -from synapse.spam_checker_api import SpamCheckerApi +from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi +from synapse.types import Collection MYPY = False if MYPY: @@ -160,3 +161,33 @@ class SpamChecker(object): return True return False + + def check_registration_for_spam( + self, + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: + """Checks if we should allow the given registration request. + + Args: + email_threepid: The email threepid used for registering, if any + username: The request user name, if any + request_info: List of tuples of user agent and IP that + were used during the registration process. + + Returns: + Enum for how the request should be handled + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_registration_for_spam", None) + if checker: + behaviour = checker(email_threepid, username, request_info) + assert isinstance(behaviour, RegistrationBehaviour) + if behaviour != RegistrationBehaviour.ALLOW: + return behaviour + + return RegistrationBehaviour.ALLOW diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 68d6870e40..654f58ddae 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -364,6 +364,14 @@ class AuthHandler(BaseHandler): # authentication flow. await self.store.set_ui_auth_clientdict(sid, clientdict) + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + + await self.store.add_user_agent_ip_to_ui_auth_session( + session.session_id, user_agent, clientip + ) + if not authdict: raise InteractiveAuthIncompleteError( session.session_id, self._auth_dict_for_flows(flows, session.session_id) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 786e608fa2..a4cc4b9a5a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -35,6 +35,7 @@ class CasHandler: """ def __init__(self, hs): + self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -210,8 +211,16 @@ class CasHandler: else: if not registered_user_id: + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders( + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name + localpart=localpart, + default_display_name=user_display_name, + user_agent_ips=(user_agent, ip_address), ) await self._auth_handler.complete_sso_login( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index dd3703cbd2..c5bd2fea68 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -93,6 +93,7 @@ class OidcHandler: """ def __init__(self, hs: "HomeServer"): + self.hs = hs self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._client_auth = ClientAuth( @@ -689,9 +690,17 @@ class OidcHandler: self._render_error(request, "invalid_token", str(e)) return + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user(userinfo, token) + user_id = await self._map_userinfo_to_user( + userinfo, token, user_agent, ip_address + ) except MappingException as e: logger.exception("Could not map user") self._render_error(request, "mapping_error", str(e)) @@ -828,7 +837,9 @@ class OidcHandler: now = self._clock.time_msec() return now < expiry - async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + async def _map_userinfo_to_user( + self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str + ) -> str: """Maps a UserInfo object to a mxid. UserInfo should have a claim that uniquely identifies users. This claim @@ -843,6 +854,8 @@ class OidcHandler: Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Raises: MappingException: if there was an error while mapping some properties @@ -899,7 +912,9 @@ class OidcHandler: # It's the first time this user is logging in and the mapped mxid was # not taken, register the user registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=attributes["display_name"], + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ccd96e4626..cde2dbca92 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -26,6 +26,7 @@ from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, ReplicationRegisterServlet, ) +from synapse.spam_checker_api import RegistrationBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester @@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._server_notices_mxid = hs.config.server_notices_mxid + self.spam_checker = hs.get_spam_checker() + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( @@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, - shadow_banned=False, + user_agent_ips=None, ): """Registers a new client on the server. @@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. - shadow_banned (bool): Shadow-ban the created user. + user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + during the registration process. Returns: str: user_id Raises: @@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) + result = self.spam_checker.check_registration_for_spam( + threepid, localpart, user_agent_ips or [], + ) + + if result == RegistrationBehaviour.DENY: + logger.info( + "Blocked registration of %r", localpart, + ) + # We return a 429 to make it not obvious that they've been + # denied. + raise SynapseError(429, "Rate limited") + + shadow_banned = result == RegistrationBehaviour.SHADOW_BAN + if shadow_banned: + logger.info( + "Shadow banning registration of %r", localpart, + ) + # do not check_auth_blocking if the call is coming through the Admin API if not by_admin: await self.auth.check_auth_blocking(threepid=threepid) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index c1fcb98454..b426199aa6 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -54,6 +54,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "synapse.server.HomeServer"): + self.hs = hs self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -133,8 +134,14 @@ class SamlHandler: # the dict. self.expire_sessions() + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state + resp_bytes, relay_state, user_agent, ip_address ) # Complete the interactive auth session or the login. @@ -147,7 +154,11 @@ class SamlHandler: await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user( - self, resp_bytes: str, client_redirect_url: str + self, + resp_bytes: str, + client_redirect_url: str, + user_agent: str, + ip_address: str, ) -> Tuple[str, Optional[Saml2SessionData]]: """ Given a sample response, retrieve the cached session and user for it. @@ -155,6 +166,8 @@ class SamlHandler: Args: resp_bytes: The SAML response. client_redirect_url: The redirect URL passed in by the client. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Returns: Tuple of the user ID and SAML session associated with this response. @@ -291,6 +304,7 @@ class SamlHandler: localpart=localpart, default_display_name=displayname, bind_emails=emails, + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 7290fd0756..be0e680ac5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, + user_agent_ips=entries, ) # Necessary due to auth checks prior to the threepid being # written to the db diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 7f63f1bfa0..9be92e2565 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from enum import Enum from twisted.internet import defer @@ -25,6 +26,16 @@ if MYPY: logger = logging.getLogger(__name__) +class RegistrationBehaviour(Enum): + """ + Enum to define whether a registration request should allowed, denied, or shadow-banned. + """ + + ALLOW = "allow" + SHADOW_BAN = "shadow_ban" + DENY = "deny" + + class SpamCheckerApi(object): """A proxy object that gets passed to spam checkers so they can get access to rooms and other relevant information. diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql new file mode 100644 index 0000000000..4cc96a5341 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A table of the IP address and user-agent used to complete each step of a +-- user-interactive authentication session. +CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 6281a41a3d..9eef8e57c5 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import attr @@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore): return serverdict.get(key, default) + async def add_user_agent_ip_to_ui_auth_session( + self, session_id: str, user_agent: str, ip: str, + ): + """Add the given user agent / IP to the tracking table + """ + await self.db_pool.simple_upsert( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, + values={}, + desc="add_user_agent_ip_to_ui_auth_session", + ) + + async def get_user_agents_ips_to_ui_auth_session( + self, session_id: str, + ) -> List[Tuple[str, str]]: + """Get the given user agents / IPs used during the ui auth process + + Returns: + List of user_agent/ip pairs + """ + rows = await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ) + return [(row["user_agent"], row["ip"]) for row in rows] + class UIAuthStore(UIAuthWorkerStore): def delete_old_ui_auth_sessions(self, expiration_time: int): @@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore): txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] + # Delete the corresponding IP/user agents. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_ips", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1bb25ab684..f92f3b8c15 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) code = "code" state = "state" nonce = "nonce" client_redirect_url = "http://client/redirect" + user_agent = "Browser" + ip_address = "10.0.0.1" session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, @@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] + request.getClientIP.return_value = ip_address + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( @@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_not_called() self.handler._render_error.assert_not_called() @@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._render_error.assert_not_called() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e364b1bd62..5c92d0e8c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,18 +17,21 @@ from mock import Mock from twisted.internet import defer +from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler +from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester from tests.test_utils import make_awaitable from tests.unittest import override_config +from tests.utils import mock_getRawHeaders from .. import unittest -class RegistrationHandlers(object): +class RegistrationHandlers: def __init__(self, hs): self.registration_handler = RegistrationHandler(hs) @@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) + def test_spam_checker_deny(self): + """A spam checker can deny registration, which results in an error.""" + + class DenyAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.DENY + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [DenyAll()] + + self.get_failure(self.handler.register_user(localpart="user"), SynapseError) + + def test_spam_checker_shadow_ban(self): + """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" + + class BanAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.SHADOW_BAN + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanAll()] + + user_id = self.get_success(self.handler.register_user(localpart="user")) + + # Get an access token. + token = self.macaroon_generator.generate_access_token(user_id) + self.get_success( + self.store.add_access_token_to_user( + user_id=user_id, token=token, device_id=None, valid_until_ms=None + ) + ) + + # Ensure the user was marked as shadow-banned. + request = Mock(args={}) + request.args[b"access_token"] = [token.encode("ascii")] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + auth = Auth(self.hs) + requester = self.get_success(auth.get_user_by_req(request)) + + self.assertTrue(requester.shadow_banned) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 31ed89a5cd..87be94111f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_spam_checker(self): """ - A user which fails to the spam checks will not appear in search results. + A user which fails the spam checks will not appear in search results. """ u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") @@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that does not filter any users. spam_checker = self.hs.get_spam_checker() - class AllowAll(object): + class AllowAll: def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - class BlockAll(object): + class BlockAll: def check_username_for_spam(self, user_profile): # All users are spammy. return True -- cgit 1.5.1 From da77520cd1c414c9341da287967feb1bab14cbec Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 08:39:04 -0400 Subject: Convert additional databases to async/await part 2 (#8200) --- changelog.d/8200.misc | 1 + synapse/events/builder.py | 19 +++++---- synapse/handlers/message.py | 13 ++---- synapse/handlers/room_member.py | 12 +----- synapse/storage/databases/main/client_ips.py | 4 +- synapse/storage/databases/main/directory.py | 6 +-- synapse/storage/databases/main/filtering.py | 5 ++- synapse/storage/databases/main/openid.py | 8 +++- synapse/storage/databases/main/profile.py | 6 ++- synapse/storage/databases/main/push_rule.py | 10 ++--- synapse/storage/databases/main/room.py | 49 ++++++++++++---------- synapse/storage/databases/main/signatures.py | 40 ++++++++++++++---- synapse/storage/databases/main/ui_auth.py | 4 +- .../storage/databases/main/user_erasure_store.py | 8 ++-- tests/test_utils/event_injection.py | 7 ++-- 15 files changed, 111 insertions(+), 81 deletions(-) create mode 100644 changelog.d/8200.misc (limited to 'synapse/storage/databases/main/ui_auth.py') diff --git a/changelog.d/8200.misc b/changelog.d/8200.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8200.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 9ed24380dd..7878cd7044 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple, Union import attr from nacl.signing import SigningKey @@ -97,14 +97,14 @@ class EventBuilder(object): def is_state(self): return self._state_key is not None - async def build(self, prev_event_ids): + async def build(self, prev_event_ids: List[str]) -> EventBase: """Transform into a fully signed and hashed event Args: - prev_event_ids (list[str]): The event IDs to use as the prev events + prev_event_ids: The event IDs to use as the prev events Returns: - FrozenEvent + The signed and hashed event. """ state_ids = await self._state.get_current_state_ids( @@ -114,8 +114,13 @@ class EventBuilder(object): format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = await self._store.add_event_hashes(auth_ids) - prev_events = await self._store.add_event_hashes(prev_event_ids) + # The types of auth/prev events changes between event versions. + auth_events = await self._store.add_event_hashes( + auth_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events = await self._store.add_event_hashes( + prev_event_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: auth_events = auth_ids prev_events = prev_event_ids @@ -138,7 +143,7 @@ class EventBuilder(object): "unsigned": self.unsigned, "depth": depth, "prev_state": [], - } + } # type: Dict[str, Any] if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d0c38f4df..72bb638167 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import ( - Collection, - Requester, - RoomAlias, - StreamToken, - UserID, - create_requester, -) +from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -446,7 +439,7 @@ class EventCreationHandler(object): event_dict: dict, token_id: Optional[str] = None, txn_id: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, require_consent: bool = True, ) -> Tuple[EventBase, EventContext]: """ @@ -786,7 +779,7 @@ class EventCreationHandler(object): self, builder: EventBuilder, requester: Optional[Requester] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cae4d013b8..a7962b0ada 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser -from synapse.types import ( - Collection, - JsonDict, - Requester, - RoomAlias, - RoomID, - StateMap, - UserID, -) +from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -184,7 +176,7 @@ class RoomMemberHandler(object): target: UserID, room_id: str, membership: str, - prev_event_ids: Collection[str], + prev_event_ids: List[str], txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 216a5925fc..c2fc847fbc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): + async def _update_client_ips_batch(self) -> None: # If the DB pool has already terminated, don't try updating if not self.db_pool.is_running(): @@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 405b5eafa5..e5060d4c46 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room( + async def update_aliases_for_room( self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, - ): + ) -> None: """Repoint all of the aliases for a given room, to a different room. Args: @@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 45a1760170..d2f5b9a502 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) - def add_user_filter(self, user_localpart, user_filter): + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.db_pool.runInteraction("add_user_filter", _do_txn) + return await self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index 4db8949da7..2aac64901b 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -1,3 +1,5 @@ +from typing import Optional + from synapse.storage._base import SQLBaseStore @@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore): desc="insert_open_id_token", ) - def get_user_id_for_open_id_token(self, token, ts_now_ms): + async def get_user_id_for_open_id_token( + self, token: str, ts_now_ms: int + ) -> Optional[str]: def get_user_id_for_token_txn(txn): sql = ( "SELECT user_id FROM open_id_tokens" @@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_id_for_token", get_user_id_for_token_txn ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 301875a672..d2e0685e9e 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore): desc="delete_remote_profile_cache", ) - def get_remote_profile_cache_entries_that_expire(self, last_checked): + async def get_remote_profile_cache_entries_that_expire( + self, last_checked: int + ) -> Dict[str, str]: """Get all users who haven't been checked since `last_checked` """ @@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 2fb5b02d7d..0de802a86b 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,8 +18,6 @@ import abc import logging from typing import List, Tuple, Union -from twisted.internet import defer - from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -149,9 +147,11 @@ class PushRulesWorkerStore( ) return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - def have_push_rules_changed_for_user(self, user_id, last_id): + async def have_push_rules_changed_for_user( + self, user_id: str, last_id: int + ) -> bool: if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) + return False else: def have_push_rules_changed_txn(txn): @@ -163,7 +163,7 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index a92641c339..717df97301 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore): allow_none=True, ) - def get_room_with_stats(self, room_id: str): + async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve room with statistics. Args: @@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore): res["public"] = bool(res["public"]) return res - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id ) @@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore): desc="get_public_room_ids", ) - def count_public_rooms(self, network_tuple, ignore_non_federatable): + async def count_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + ignore_non_federatable: bool, + ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. Args: - network_tuple (ThirdPartyInstanceID|None) - ignore_non_federatable (bool): If true filters out non-federatable rooms + network_tuple + ignore_non_federatable: If true filters out non-federatable rooms """ def _count_public_rooms_txn(txn): @@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn ) @@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore): return row - def get_media_mxcs_in_room(self, room_id): + async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room Args: - room_id (str) + room_id Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. + The local and remote media as a lists of the media IDs. """ def _get_media_mxcs_in_room_txn(txn): @@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_media_ids_in_room", _get_media_mxcs_in_room_txn ) - def quarantine_media_ids_in_room(self, room_id, quarantined_by): + async def quarantine_media_ids_in_room( + self, room_id: str, quarantined_by: str + ) -> int: """For a room loops through all events with media and quarantines the associated media """ @@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - def quarantine_media_by_id( + async def quarantine_media_by_id( self, server_name: str, media_id: str, quarantined_by: str, - ): + ) -> int: """quarantines a single local or remote media id Args: @@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_id_txn ) - def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + async def quarantine_media_ids_by_user( + self, user_id: str, quarantined_by: str + ) -> int: """quarantines all local media associated with a single user Args: @@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore): local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_user_txn ) @@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - def get_room_count(self): - """Retrieve a list of all rooms + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. """ def f(txn): @@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.db_pool.runInteraction("get_rooms", f) + return await self.db_pool.runInteraction("get_rooms", f) async def add_event_report( self, diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index be191dd870..c8c67953e4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Tuple + from unpaddedbase64 import encode_base64 from synapse.storage._base import SQLBaseStore +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 ) - def get_event_reference_hashes(self, event_ids): + async def get_event_reference_hashes( + self, event_ids: Iterable[str] + ) -> Dict[str, Dict[str, bytes]]: + """Get all hashes for given events. + + Args: + event_ids: The event IDs to get hashes for. + + Returns: + A mapping of event ID to a mapping of algorithm to hash. + """ + def f(txn): return { event_id: self._get_event_reference_hashes_txn(txn, event_id) for event_id in event_ids } - return self.db_pool.runInteraction("get_event_reference_hashes", f) + return await self.db_pool.runInteraction("get_event_reference_hashes", f) - async def add_event_hashes(self, event_ids): + async def add_event_hashes( + self, event_ids: Iterable[str] + ) -> List[Tuple[str, Dict[str, str]]]: + """ + + Args: + event_ids: The event IDs + + Returns: + A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. + """ hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} @@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore): return list(hashes.items()) - def _get_event_reference_hashes_txn(self, txn, event_id): + def _get_event_reference_hashes_txn( + self, txn: Cursor, event_id: str + ) -> Dict[str, bytes]: """Get all the hashes for a given PDU. Args: - txn (cursor): - event_id (str): Id for the Event. + txn: + event_id: Id for the Event. Returns: - A dict[unicode, bytes] of algorithm -> hash. + A mapping of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 9eef8e57c5..b89668d561 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore): class UIAuthStore(UIAuthWorkerStore): - def delete_old_ui_auth_sessions(self, expiration_time: int): + async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore): This is an epoch time in milliseconds. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index e3547e53b3..2f7c95fc74 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore): class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id: str) -> None: + async def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: @@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_erased", f) + await self.db_pool.runInteraction("mark_user_erased", f) - def mark_user_not_erased(self, user_id: str) -> None: + async def mark_user_not_erased(self, user_id: str) -> None: """Indicate that user_id is no longer erased. Args: @@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_not_erased", f) + await self.db_pool.runInteraction("mark_user_not_erased", f) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8522c6fc09..fb1ca90336 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,14 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Collection """ Utility functions for poking events into the storage of the server under test. @@ -58,7 +57,7 @@ async def inject_member_event( async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> EventBase: """Inject a generic event into a room @@ -80,7 +79,7 @@ async def inject_event( async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: if room_version is None: -- cgit 1.5.1 From aec294ee0d0f2fa4ccef57085d670b8939de3669 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 14 Sep 2020 12:50:06 -0400 Subject: Use slots in attrs classes where possible (#8296) slots use less memory (and attribute access is faster) while slightly limiting the flexibility of the class attributes. This focuses on objects which are instantiated "often" and for short periods of time. --- changelog.d/8296.misc | 1 + synapse/handlers/acme_issuing_service.py | 2 +- synapse/handlers/auth.py | 2 +- synapse/handlers/e2e_keys.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/saml_handler.py | 2 +- synapse/handlers/sync.py | 34 +++++++--------------- synapse/http/federation/well_known_resolver.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/logging/context.py | 4 +-- synapse/metrics/__init__.py | 4 +-- synapse/notifier.py | 4 +-- synapse/replication/tcp/streams/_base.py | 4 +-- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 2 +- .../storage/databases/main/event_push_actions.py | 2 +- synapse/storage/databases/main/ui_auth.py | 2 +- synapse/storage/prepare_database.py | 2 +- synapse/storage/relations.py | 2 +- synapse/util/__init__.py | 2 +- synapse/util/caches/__init__.py | 2 +- 22 files changed, 33 insertions(+), 50 deletions(-) create mode 100644 changelog.d/8296.misc (limited to 'synapse/storage/databases/main/ui_auth.py') diff --git a/changelog.d/8296.misc b/changelog.d/8296.misc new file mode 100644 index 0000000000..f593a5b347 --- /dev/null +++ b/changelog.d/8296.misc @@ -0,0 +1 @@ +Use slotted classes where possible. diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py index 69650ff221..7294649d71 100644 --- a/synapse/handlers/acme_issuing_service.py +++ b/synapse/handlers/acme_issuing_service.py @@ -76,7 +76,7 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou ) -@attr.s +@attr.s(slots=True) @implementer(ICertificateStore) class ErsatzStore: """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 90189869cc..4e658d9a48 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1235,7 +1235,7 @@ class AuthHandler(BaseHandler): return urllib.parse.urlunparse(url_parts) -@attr.s +@attr.s(slots=True) class MacaroonGenerator: hs = attr.ib() diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d629c7c16c..dd40fd1299 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1201,7 +1201,7 @@ def _one_time_keys_match(old_key_json, new_key): return old_key == new_key_copy -@attr.s +@attr.s(slots=True) class SignatureListItem: """An item in the signature list as used by upload_signatures_for_device_keys. """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a5734bebab..262901363f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -86,7 +86,7 @@ from synapse.visibility import filter_events_for_server logger = logging.getLogger(__name__) -@attr.s +@attr.s(slots=True) class _NewEventInfo: """Holds information about a received event, ready for passing to _handle_new_events diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 8715abd4d1..285c481a96 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -46,7 +46,7 @@ class MappingException(Exception): """Used to catch errors when mapping the SAML2 response to a user.""" -@attr.s +@attr.s(slots=True) class Saml2SessionData: """Data we track about SAML2 sessions""" diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a615c7c2f0..9b3a4f638b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -89,14 +89,12 @@ class TimelineBatch: events = attr.ib(type=List[EventBase]) limited = attr.ib(bool) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ return bool(self.events) - __bool__ = __nonzero__ # python3 - # We can't freeze this class, because we need to update it after it's instantiated to # update its unread count. This is because we calculate the unread count for a room only @@ -114,7 +112,7 @@ class JoinedSyncResult: summary = attr.ib(type=Optional[JsonDict]) unread_count = attr.ib(type=int) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ @@ -127,8 +125,6 @@ class JoinedSyncResult: # else in the result, we don't need to send it. ) - __bool__ = __nonzero__ # python3 - @attr.s(slots=True, frozen=True) class ArchivedSyncResult: @@ -137,26 +133,22 @@ class ArchivedSyncResult: state = attr.ib(type=StateMap[EventBase]) account_data = attr.ib(type=List[JsonDict]) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ return bool(self.timeline or self.state or self.account_data) - __bool__ = __nonzero__ # python3 - @attr.s(slots=True, frozen=True) class InvitedSyncResult: room_id = attr.ib(type=str) invite = attr.ib(type=EventBase) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Invited rooms should always be reported to the client""" return True - __bool__ = __nonzero__ # python3 - @attr.s(slots=True, frozen=True) class GroupsSyncResult: @@ -164,11 +156,9 @@ class GroupsSyncResult: invite = attr.ib(type=JsonDict) leave = attr.ib(type=JsonDict) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: return bool(self.join or self.invite or self.leave) - __bool__ = __nonzero__ # python3 - @attr.s(slots=True, frozen=True) class DeviceLists: @@ -181,13 +171,11 @@ class DeviceLists: changed = attr.ib(type=Collection[str]) left = attr.ib(type=Collection[str]) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: return bool(self.changed or self.left) - __bool__ = __nonzero__ # python3 - -@attr.s +@attr.s(slots=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined and left room IDs since last sync. @@ -227,7 +215,7 @@ class SyncResult: device_one_time_keys_count = attr.ib(type=JsonDict) groups = attr.ib(type=Optional[GroupsSyncResult]) - def __nonzero__(self) -> bool: + def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if the notifier needs to wait for more events when polling for events. @@ -243,8 +231,6 @@ class SyncResult: or self.groups ) - __bool__ = __nonzero__ # python3 - class SyncHandler: def __init__(self, hs: "HomeServer"): @@ -2038,7 +2024,7 @@ def _calculate_state( return {event_id_to_key[e]: e for e in state_ids} -@attr.s +@attr.s(slots=True) class SyncResultBuilder: """Used to help build up a new SyncResult for a user @@ -2074,7 +2060,7 @@ class SyncResultBuilder: to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) -@attr.s +@attr.s(slots=True) class RoomSyncResultBuilder: """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index e6f067ca29..a306faa267 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -311,7 +311,7 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: return cache_controls -@attr.s() +@attr.s(slots=True) class _FetchWellKnownFailure(Exception): # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be # a temporary failure. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 5eaf3151ce..3c86cbc546 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -76,7 +76,7 @@ MAXINT = sys.maxsize _next_id = 1 -@attr.s(frozen=True) +@attr.s(slots=True, frozen=True) class MatrixFederationRequest: method = attr.ib() """HTTP method diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 22598e02d2..2e282d9d67 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -217,11 +217,9 @@ class _Sentinel: def record_event_fetch(self, event_count): pass - def __nonzero__(self): + def __bool__(self): return False - __bool__ = __nonzero__ # python3 - SENTINEL_CONTEXT = _Sentinel() diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 2643380d9e..a1f7ca3449 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -59,7 +59,7 @@ class RegistryProxy: yield metric -@attr.s(hash=True) +@attr.s(slots=True, hash=True) class LaterGauge: name = attr.ib(type=str) @@ -205,7 +205,7 @@ class InFlightGauge: all_gauges[self.name] = self -@attr.s(hash=True) +@attr.s(slots=True, hash=True) class BucketCollector: """ Like a Histogram, but allows buckets to be point-in-time instead of diff --git a/synapse/notifier.py b/synapse/notifier.py index 12cd84b27b..a8fd3ef886 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -164,11 +164,9 @@ class _NotifierUserStream: class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): - def __nonzero__(self): + def __bool__(self): return bool(self.events) - __bool__ = __nonzero__ # python3 - class Notifier: """ This class is responsible for notifying any listeners when there are diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 682d47f402..1f609f158c 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -383,7 +383,7 @@ class CachesStream(Stream): the cache on the workers """ - @attr.s + @attr.s(slots=True) class CachesStreamRow: """Stream to inform workers they should invalidate their cache. @@ -441,7 +441,7 @@ class DeviceListsStream(Stream): told about a device update. """ - @attr.s + @attr.s(slots=True) class DeviceListsStreamRow: entity = attr.ib(type=str) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index cd8c246594..987765e877 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -102,7 +102,7 @@ for endpoint, globs in _oembed_globs.items(): _oembed_patterns[re.compile(pattern)] = endpoint -@attr.s +@attr.s(slots=True) class OEmbedResult: # Either HTML content or URL must be provided. html = attr.ib(type=Optional[str]) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index c7e3015b5d..56d6afb863 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -678,7 +678,7 @@ def resolve_events_with_store( ) -@attr.s +@attr.s(slots=True) class StateResolutionStore: """Interface that allows state resolution algorithms to access the database in well defined way. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index fba3098ea2..c8df0bcb3f 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from synapse.handlers.e2e_keys import SignatureListItem -@attr.s +@attr.s(slots=True) class DeviceKeyLookupResult: """The type returned by get_e2e_device_keys_and_signatures""" diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 5233ed83e2..7805fb814e 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -969,7 +969,7 @@ def _action_has_highlight(actions): return False -@attr.s +@attr.s(slots=True) class _EventPushSummary: """Summary of pending event push actions for a given user in a given room. Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index b89668d561..3b9211a6d2 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -23,7 +23,7 @@ from synapse.types import JsonDict from synapse.util import json_encoder, stringutils -@attr.s +@attr.s(slots=True) class UIAuthSessionData: session_id = attr.ib(type=str) # The dictionary from the client root level, not the 'auth' key. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index a7f2dfb850..4957e77f4c 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -638,7 +638,7 @@ def _get_or_create_schema_state(txn, database_engine): return None -@attr.s() +@attr.s(slots=True) class _DirectoryListing: """Helper class to store schema file name and the absolute path to it. diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index d30e3f11e7..cec96ad6a7 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -22,7 +22,7 @@ from synapse.api.errors import SynapseError logger = logging.getLogger(__name__) -@attr.s +@attr.s(slots=True) class PaginationChunk: """Returned by relation pagination APIs. diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 60ecc498ab..d55b93d763 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -45,7 +45,7 @@ def unwrapFirstError(failure): return failure.value.subFailure -@attr.s +@attr.s(slots=True) class Clock: """ A Clock wraps a Twisted reactor and provides utilities on top of it. diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 237f588658..8fc05be278 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -42,7 +42,7 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -@attr.s +@attr.s(slots=True) class CacheMetric: _cache = attr.ib() -- cgit 1.5.1