From b8b905c4ea8a0d922d34d469f7d220f53def1b53 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 12 Oct 2021 11:24:05 +0100 Subject: Fix inconsistent behavior of `get_last_client_by_ip` (#10970) Make `get_last_client_by_ip` return the same dictionary structure regardless of whether the data has been persisted to the database. This change will allow slightly cleaner type hints to be applied later on. --- synapse/storage/databases/main/client_ips.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'synapse/storage/databases/main/client_ips.py') diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c77acc7c84..6c1ef09049 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -538,15 +538,20 @@ class ClientIpStore(ClientIpWorkerStore): """ ret = await super().get_last_client_ip_by_device(user_id, device_id) - # Update what is retrieved from the database with data which is pending insertion. + # Update what is retrieved from the database with data which is pending + # insertion, as if it has already been stored in the database. for key in self._batch_row_update: - uid, access_token, ip = key + uid, _access_token, ip = key if uid == user_id: user_agent, did, last_seen = self._batch_row_update[key] + + if did is None: + # These updates don't make it to the `devices` table + continue + if not device_id or did == device_id: - ret[(user_id, device_id)] = { + ret[(user_id, did)] = { "user_id": user_id, - "access_token": access_token, "ip": ip, "user_agent": user_agent, "device_id": did, -- cgit 1.5.1 From 36224e056a0ba91b4541607c5ad5cd5152d0e672 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 12 Oct 2021 13:50:34 +0100 Subject: Add type hints to `synapse.storage.databases.main.client_ips` (#10972) --- changelog.d/10972.misc | 1 + mypy.ini | 4 + synapse/handlers/device.py | 15 ++- synapse/module_api/__init__.py | 6 +- synapse/storage/databases/main/client_ips.py | 140 +++++++++++++++++++-------- 5 files changed, 121 insertions(+), 45 deletions(-) create mode 100644 changelog.d/10972.misc (limited to 'synapse/storage/databases/main/client_ips.py') diff --git a/changelog.d/10972.misc b/changelog.d/10972.misc new file mode 100644 index 0000000000..f66a7beaf0 --- /dev/null +++ b/changelog.d/10972.misc @@ -0,0 +1 @@ +Add type hints to `synapse.storage.databases.main.client_ips`. diff --git a/mypy.ini b/mypy.ini index a7019e2bd4..174a6edae6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -53,6 +53,7 @@ files = synapse/storage/_base.py, synapse/storage/background_updates.py, synapse/storage/databases/main/appservice.py, + synapse/storage/databases/main/client_ips.py, synapse/storage/databases/main/events.py, synapse/storage/databases/main/keys.py, synapse/storage/databases/main/pusher.py, @@ -108,6 +109,9 @@ disallow_untyped_defs = True [mypy-synapse.state.*] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.client_ips] +disallow_untyped_defs = True + [mypy-synapse.storage.util.*] disallow_untyped_defs = True diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 75e6019760..6eafbea25d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,7 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, +) from synapse.api import errors from synapse.api.constants import EventTypes @@ -595,7 +606,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips( - device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict] + device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 8ae21bc43c..b2a228c231 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -773,9 +773,9 @@ class ModuleApi: # Sanitize some of the data. We don't want to return tokens. return [ UserIpAndAgent( - ip=str(data["ip"]), - user_agent=str(data["user_agent"]), - last_seen=int(data["last_seen"]), + ip=data["ip"], + user_agent=data["user_agent"], + last_seen=data["last_seen"], ) for data in raw_data ] diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 6c1ef09049..b81d9218ce 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -13,14 +13,26 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast + +from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.types import UserID +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore +from synapse.storage.types import Connection +from synapse.types import JsonDict, UserID from synapse.util.caches.lrucache import LruCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller @@ -29,8 +41,31 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 +class DeviceLastConnectionInfo(TypedDict): + """Metadata for the last connection seen for a user and device combination""" + + # These types must match the columns in the `devices` table + user_id: str + device_id: str + + ip: Optional[str] + user_agent: Optional[str] + last_seen: Optional[int] + + +class LastConnectionInfo(TypedDict): + """Metadata for the last connection seen for an access token and IP combination""" + + # These types must match the columns in the `user_ips` table + access_token: str + ip: str + + user_agent: str + last_seen: int + + class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): "devices_last_seen", self._devices_last_seen_update ) - async def _remove_user_ip_nonunique(self, progress, batch_size): - def f(conn): + async def _remove_user_ip_nonunique( + self, progress: JsonDict, batch_size: int + ) -> int: + def f(conn: LoggingDatabaseConnection) -> None: txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() @@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): ) return 1 - async def _analyze_user_ip(self, progress, batch_size): + async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int: # Background update to analyze user_ips table before we run the # deduplication background update. The table may not have been analyzed # for ages due to the table locks. # # This will lock out the naive upserts to user_ips while it happens, but # the analyze should be quick (28GB table takes ~10s) - def user_ips_analyze(txn): + def user_ips_analyze(txn: LoggingTransaction) -> None: txn.execute("ANALYZE user_ips") await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) @@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return 1 - async def _remove_user_ip_dupes(self, progress, batch_size): + async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int: # This works function works by scanning the user_ips table in batches # based on `last_seen`. For each row in a batch it searches the rest of # the table to see if there are any duplicates, if there are then they # are removed and replaced with a suitable row. # Fetch the start of the batch - begin_last_seen = progress.get("last_seen", 0) + begin_last_seen: int = progress.get("last_seen", 0) - def get_last_seen(txn): + def get_last_seen(txn: LoggingTransaction) -> Optional[int]: txn.execute( """ SELECT last_seen FROM user_ips @@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): """, (begin_last_seen, batch_size), ) - row = txn.fetchone() + row = cast(Optional[Tuple[int]], txn.fetchone()) if row: return row[0] else: @@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): end_last_seen, ) - def remove(txn): + def remove(txn: LoggingTransaction) -> None: # This works by looking at all entries in the given time span, and # then for each (user_id, access_token, ip) tuple in that range # checking for any duplicates in the rest of the table (via a join). @@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # Define the search space, which requires handling the last batch in # a different way + args: Tuple[int, ...] if last: clause = "? <= last_seen" args = (begin_last_seen,) else: + assert end_last_seen is not None clause = "? <= last_seen AND last_seen < ?" args = (begin_last_seen, end_last_seen) @@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): ), args, ) - res = txn.fetchall() + res = cast( + List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall() + ) # We've got some duplicates for i in res: @@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return batch_size - async def _devices_last_seen_update(self, progress, batch_size): + async def _devices_last_seen_update( + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to insert last seen info into devices table""" - last_user_id = progress.get("last_user_id", "") - last_device_id = progress.get("last_device_id", "") + last_user_id: str = progress.get("last_user_id", "") + last_device_id: str = progress.get("last_device_id", "") - def _devices_last_seen_update_txn(txn): + def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int: # This consists of two queries: # # 1. The sub-query searches for the next N devices and joins @@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # we'll just end up updating the same device row multiple # times, which is fine. + where_args: List[Union[str, int]] where_clause, where_args = make_tuple_comparison_clause( [("user_id", last_user_id), ("device_id", last_device_id)], ) @@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): } txn.execute(sql, where_args + [batch_size]) - rows = txn.fetchall() + rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) if not rows: return 0 @@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.server.user_ips_max_age @@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): + async def _prune_old_user_ips(self) -> None: """Removes entries in user IPs older than the configured period.""" if self.user_ips_max_age is None: @@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): ) """ - timestamp = self.clock.time_msec() - self.user_ips_max_age + timestamp = self._clock.time_msec() - self.user_ips_max_age - def _prune_old_user_ips_txn(txn): + def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None: txn.execute(sql, (timestamp,)) await self.db_pool.runInteraction( @@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): async def get_last_client_ip_by_device( self, user_id: str, device_id: Optional[str] - ) -> Dict[Tuple[str, str], dict]: + ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]: """For each device_id listed, give the user_ip it was last seen on. The result might be slightly out of date as client IPs are inserted in batches. @@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - res = await self.db_pool.simple_select_list( - table="devices", - keyvalues=keyvalues, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + res = cast( + List[DeviceLastConnectionInfo], + await self.db_pool.simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), ) return {(d["user_id"], d["device_id"]): d for d in res} -class ClientIpStore(ClientIpWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): +class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): - self.client_ip_last_seen = LruCache( + # (user_id, access_token, ip,) -> last_seen + self.client_ip_last_seen = LruCache[Tuple[str, str, str], int]( cache_name="client_ip_last_seen", max_size=50000 ) super().__init__(database, db_conn, hs) # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) - self._batch_row_update = {} + self._batch_row_update: Dict[ + Tuple[str, str, str], Tuple[str, Optional[str], int] + ] = {} self._client_ip_looper = self._clock.looping_call( self._update_client_ips_batch, 5 * 1000 @@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore): ) async def insert_client_ip( - self, user_id, access_token, ip, user_agent, device_id, now=None - ): + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: Optional[str], + now: Optional[int] = None, + ) -> None: if not now: now = int(self._clock.time_msec()) key = (user_id, access_token, ip) @@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore): "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) - def _update_client_ips_batch_txn(self, txn, to_update): + def _update_client_ips_batch_txn( + self, + txn: LoggingTransaction, + to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]], + ) -> None: if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): @@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore): async def get_last_client_ip_by_device( self, user_id: str, device_id: Optional[str] - ) -> Dict[Tuple[str, str], dict]: + ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]: """For each device_id listed, give the user_ip it was last seen on Args: @@ -561,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore): async def get_user_ip_and_agents( self, user: UserID, since_ts: int = 0 - ) -> List[Dict[str, Union[str, int]]]: + ) -> List[LastConnectionInfo]: """ Fetch IP/User Agent connection since a given timestamp. """ user_id = user.to_string() - results = {} + results: Dict[Tuple[str, str], Tuple[str, int]] = {} for key in self._batch_row_update: ( @@ -579,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore): if last_seen >= since_ts: results[(access_token, ip)] = (user_agent, last_seen) - def get_recent(txn): + def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: txn.execute( """ SELECT access_token, ip, user_agent, last_seen FROM user_ips @@ -589,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore): """, (since_ts, user_id), ) - return txn.fetchall() + return cast(List[Tuple[str, str, str, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( desc="get_user_ip_and_agents", func=get_recent -- cgit 1.5.1 From 85a09f8b8ba7c8023c0d28a526d32111fc704197 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 25 Oct 2021 13:01:04 +0100 Subject: Fix module API's `get_user_ip_and_agents` function when run on workers (#11112) --- changelog.d/11112.bugfix | 1 + synapse/module_api/__init__.py | 6 +- synapse/storage/databases/main/client_ips.py | 124 ++++++++++++++++++--------- 3 files changed, 91 insertions(+), 40 deletions(-) create mode 100644 changelog.d/11112.bugfix (limited to 'synapse/storage/databases/main/client_ips.py') diff --git a/changelog.d/11112.bugfix b/changelog.d/11112.bugfix new file mode 100644 index 0000000000..c8e22da8cf --- /dev/null +++ b/changelog.d/11112.bugfix @@ -0,0 +1 @@ +Fix a bug which caused the module API's `get_user_ip_and_agents` function to always fail on workers. `get_user_ip_and_agents` was introduced in 1.44.0 and did not function correctly on worker processes at the time. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index ab7ef8f950..d37252b6b3 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -46,6 +46,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client.login import LoginResponse +from synapse.storage import DataStore from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -61,6 +62,7 @@ from synapse.util import Clock from synapse.util.caches.descriptors import cached if TYPE_CHECKING: + from synapse.app.generic_worker import GenericWorkerSlavedStore from synapse.server import HomeServer """ @@ -111,7 +113,9 @@ class ModuleApi: def __init__(self, hs: "HomeServer", auth_handler): self._hs = hs - self._store = hs.get_datastore() + # TODO: Fix this type hint once the types for the data stores have been ironed + # out. + self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore() self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index b81d9218ce..1dc7f0ebe3 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -478,6 +478,58 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): return {(d["user_id"], d["device_id"]): d for d in res} + async def get_user_ip_and_agents( + self, user: UserID, since_ts: int = 0 + ) -> List[LastConnectionInfo]: + """Fetch the IPs and user agents for a user since the given timestamp. + + The result might be slightly out of date as client IPs are inserted in batches. + + Args: + user: The user for which to fetch IP addresses and user agents. + since_ts: The timestamp after which to fetch IP addresses and user agents, + in milliseconds. + + Returns: + A list of dictionaries, each containing: + * `access_token`: The access token used. + * `ip`: The IP address used. + * `user_agent`: The last user agent seen for this access token and IP + address combination. + * `last_seen`: The timestamp at which this access token and IP address + combination was last seen, in milliseconds. + + Only the latest user agent for each access token and IP address combination + is available. + """ + user_id = user.to_string() + + def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: + txn.execute( + """ + SELECT access_token, ip, user_agent, last_seen FROM user_ips + WHERE last_seen >= ? AND user_id = ? + ORDER BY last_seen + DESC + """, + (since_ts, user_id), + ) + return cast(List[Tuple[str, str, str, int]], txn.fetchall()) + + rows = await self.db_pool.runInteraction( + desc="get_user_ip_and_agents", func=get_recent + ) + + return [ + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for access_token, ip, user_agent, last_seen in rows + ] + class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): @@ -622,49 +674,43 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): async def get_user_ip_and_agents( self, user: UserID, since_ts: int = 0 ) -> List[LastConnectionInfo]: + """Fetch the IPs and user agents for a user since the given timestamp. + + Args: + user: The user for which to fetch IP addresses and user agents. + since_ts: The timestamp after which to fetch IP addresses and user agents, + in milliseconds. + + Returns: + A list of dictionaries, each containing: + * `access_token`: The access token used. + * `ip`: The IP address used. + * `user_agent`: The last user agent seen for this access token and IP + address combination. + * `last_seen`: The timestamp at which this access token and IP address + combination was last seen, in milliseconds. + + Only the latest user agent for each access token and IP address combination + is available. """ - Fetch IP/User Agent connection since a given timestamp. - """ - user_id = user.to_string() - results: Dict[Tuple[str, str], Tuple[str, int]] = {} + results: Dict[Tuple[str, str], LastConnectionInfo] = { + (connection["access_token"], connection["ip"]): connection + for connection in await super().get_user_ip_and_agents(user, since_ts) + } + # Overlay data that is pending insertion on top of the results from the + # database. + user_id = user.to_string() for key in self._batch_row_update: - ( - uid, - access_token, - ip, - ) = key + uid, access_token, ip = key if uid == user_id: user_agent, _, last_seen = self._batch_row_update[key] if last_seen >= since_ts: - results[(access_token, ip)] = (user_agent, last_seen) - - def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: - txn.execute( - """ - SELECT access_token, ip, user_agent, last_seen FROM user_ips - WHERE last_seen >= ? AND user_id = ? - ORDER BY last_seen - DESC - """, - (since_ts, user_id), - ) - return cast(List[Tuple[str, str, str, int]], txn.fetchall()) - - rows = await self.db_pool.runInteraction( - desc="get_user_ip_and_agents", func=get_recent - ) + results[(access_token, ip)] = { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } - results.update( - ((access_token, ip), (user_agent, last_seen)) - for access_token, ip, user_agent, last_seen in rows - ) - return [ - { - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "last_seen": last_seen, - } - for (access_token, ip), (user_agent, last_seen) in results.items() - ] + return list(results.values()) -- cgit 1.5.1