diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/client_ips.py | 153 | ||||
-rw-r--r-- | synapse/storage/state.py | 172 |
2 files changed, 280 insertions, 45 deletions
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c77acc7c84..b81d9218ce 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -13,14 +13,26 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast + +from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.types import UserID +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore +from synapse.storage.types import Connection +from synapse.types import JsonDict, UserID from synapse.util.caches.lrucache import LruCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller @@ -29,8 +41,31 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 +class DeviceLastConnectionInfo(TypedDict): + """Metadata for the last connection seen for a user and device combination""" + + # These types must match the columns in the `devices` table + user_id: str + device_id: str + + ip: Optional[str] + user_agent: Optional[str] + last_seen: Optional[int] + + +class LastConnectionInfo(TypedDict): + """Metadata for the last connection seen for an access token and IP combination""" + + # These types must match the columns in the `user_ips` table + access_token: str + ip: str + + user_agent: str + last_seen: int + + class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): "devices_last_seen", self._devices_last_seen_update ) - async def _remove_user_ip_nonunique(self, progress, batch_size): - def f(conn): + async def _remove_user_ip_nonunique( + self, progress: JsonDict, batch_size: int + ) -> int: + def f(conn: LoggingDatabaseConnection) -> None: txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() @@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): ) return 1 - async def _analyze_user_ip(self, progress, batch_size): + async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int: # Background update to analyze user_ips table before we run the # deduplication background update. The table may not have been analyzed # for ages due to the table locks. # # This will lock out the naive upserts to user_ips while it happens, but # the analyze should be quick (28GB table takes ~10s) - def user_ips_analyze(txn): + def user_ips_analyze(txn: LoggingTransaction) -> None: txn.execute("ANALYZE user_ips") await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) @@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return 1 - async def _remove_user_ip_dupes(self, progress, batch_size): + async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int: # This works function works by scanning the user_ips table in batches # based on `last_seen`. For each row in a batch it searches the rest of # the table to see if there are any duplicates, if there are then they # are removed and replaced with a suitable row. # Fetch the start of the batch - begin_last_seen = progress.get("last_seen", 0) + begin_last_seen: int = progress.get("last_seen", 0) - def get_last_seen(txn): + def get_last_seen(txn: LoggingTransaction) -> Optional[int]: txn.execute( """ SELECT last_seen FROM user_ips @@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): """, (begin_last_seen, batch_size), ) - row = txn.fetchone() + row = cast(Optional[Tuple[int]], txn.fetchone()) if row: return row[0] else: @@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): end_last_seen, ) - def remove(txn): + def remove(txn: LoggingTransaction) -> None: # This works by looking at all entries in the given time span, and # then for each (user_id, access_token, ip) tuple in that range # checking for any duplicates in the rest of the table (via a join). @@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # Define the search space, which requires handling the last batch in # a different way + args: Tuple[int, ...] if last: clause = "? <= last_seen" args = (begin_last_seen,) else: + assert end_last_seen is not None clause = "? <= last_seen AND last_seen < ?" args = (begin_last_seen, end_last_seen) @@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): ), args, ) - res = txn.fetchall() + res = cast( + List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall() + ) # We've got some duplicates for i in res: @@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return batch_size - async def _devices_last_seen_update(self, progress, batch_size): + async def _devices_last_seen_update( + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to insert last seen info into devices table""" - last_user_id = progress.get("last_user_id", "") - last_device_id = progress.get("last_device_id", "") + last_user_id: str = progress.get("last_user_id", "") + last_device_id: str = progress.get("last_device_id", "") - def _devices_last_seen_update_txn(txn): + def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int: # This consists of two queries: # # 1. The sub-query searches for the next N devices and joins @@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # we'll just end up updating the same device row multiple # times, which is fine. + where_args: List[Union[str, int]] where_clause, where_args = make_tuple_comparison_clause( [("user_id", last_user_id), ("device_id", last_device_id)], ) @@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): } txn.execute(sql, where_args + [batch_size]) - rows = txn.fetchall() + rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) if not rows: return 0 @@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.server.user_ips_max_age @@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): + async def _prune_old_user_ips(self) -> None: """Removes entries in user IPs older than the configured period.""" if self.user_ips_max_age is None: @@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): ) """ - timestamp = self.clock.time_msec() - self.user_ips_max_age + timestamp = self._clock.time_msec() - self.user_ips_max_age - def _prune_old_user_ips_txn(txn): + def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None: txn.execute(sql, (timestamp,)) await self.db_pool.runInteraction( @@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): async def get_last_client_ip_by_device( self, user_id: str, device_id: Optional[str] - ) -> Dict[Tuple[str, str], dict]: + ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]: """For each device_id listed, give the user_ip it was last seen on. The result might be slightly out of date as client IPs are inserted in batches. @@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - res = await self.db_pool.simple_select_list( - table="devices", - keyvalues=keyvalues, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + res = cast( + List[DeviceLastConnectionInfo], + await self.db_pool.simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), ) return {(d["user_id"], d["device_id"]): d for d in res} -class ClientIpStore(ClientIpWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): +class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): - self.client_ip_last_seen = LruCache( + # (user_id, access_token, ip,) -> last_seen + self.client_ip_last_seen = LruCache[Tuple[str, str, str], int]( cache_name="client_ip_last_seen", max_size=50000 ) super().__init__(database, db_conn, hs) # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) - self._batch_row_update = {} + self._batch_row_update: Dict[ + Tuple[str, str, str], Tuple[str, Optional[str], int] + ] = {} self._client_ip_looper = self._clock.looping_call( self._update_client_ips_batch, 5 * 1000 @@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore): ) async def insert_client_ip( - self, user_id, access_token, ip, user_agent, device_id, now=None - ): + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: Optional[str], + now: Optional[int] = None, + ) -> None: if not now: now = int(self._clock.time_msec()) key = (user_id, access_token, ip) @@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore): "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) - def _update_client_ips_batch_txn(self, txn, to_update): + def _update_client_ips_batch_txn( + self, + txn: LoggingTransaction, + to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]], + ) -> None: if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): @@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore): async def get_last_client_ip_by_device( self, user_id: str, device_id: Optional[str] - ) -> Dict[Tuple[str, str], dict]: + ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]: """For each device_id listed, give the user_ip it was last seen on Args: @@ -538,15 +598,20 @@ class ClientIpStore(ClientIpWorkerStore): """ ret = await super().get_last_client_ip_by_device(user_id, device_id) - # Update what is retrieved from the database with data which is pending insertion. + # Update what is retrieved from the database with data which is pending + # insertion, as if it has already been stored in the database. for key in self._batch_row_update: - uid, access_token, ip = key + uid, _access_token, ip = key if uid == user_id: user_agent, did, last_seen = self._batch_row_update[key] + + if did is None: + # These updates don't make it to the `devices` table + continue + if not device_id or did == device_id: - ret[(user_id, device_id)] = { + ret[(user_id, did)] = { "user_id": user_id, - "access_token": access_token, "ip": ip, "user_agent": user_agent, "device_id": did, @@ -556,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore): async def get_user_ip_and_agents( self, user: UserID, since_ts: int = 0 - ) -> List[Dict[str, Union[str, int]]]: + ) -> List[LastConnectionInfo]: """ Fetch IP/User Agent connection since a given timestamp. """ user_id = user.to_string() - results = {} + results: Dict[Tuple[str, str], Tuple[str, int]] = {} for key in self._batch_row_update: ( @@ -574,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore): if last_seen >= since_ts: results[(access_token, ip)] = (user_agent, last_seen) - def get_recent(txn): + def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: txn.execute( """ SELECT access_token, ip, user_agent, last_seen FROM user_ips @@ -584,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore): """, (since_ts, user_id), ) - return txn.fetchall() + return cast(List[Tuple[str, str, str, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( desc="get_user_ip_and_agents", func=get_recent diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5e86befde4..b5ba1560d1 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -15,9 +15,11 @@ import logging from typing import ( TYPE_CHECKING, Awaitable, + Collection, Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -29,7 +31,7 @@ from frozendict import frozendict from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateKey, StateMap if TYPE_CHECKING: from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad @@ -134,6 +136,23 @@ class StateFilter: include_others=True, ) + @staticmethod + def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool): + """ + Returns a (frozen) StateFilter with the same contents as the parameters + specified here, which can be made of mutable types. + """ + types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} + for state_types, state_keys in types.items(): + if state_keys is not None: + types_with_frozen_values[state_types] = frozenset(state_keys) + else: + types_with_frozen_values[state_types] = None + + return StateFilter( + frozendict(types_with_frozen_values), include_others=include_others + ) + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the @@ -356,6 +375,157 @@ class StateFilter: return member_filter, non_member_filter + def _decompose_into_four_parts( + self, + ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: + """ + Decomposes this state filter into 4 constituent parts, which can be + thought of as this: + all? - minus_wildcards + plus_wildcards + plus_state_keys + + where + * all represents ALL state + * minus_wildcards represents entire state types to remove + * plus_wildcards represents entire state types to add + * plus_state_keys represents individual state keys to add + + See `recompose_from_four_parts` for the other direction of this + correspondence. + """ + is_all = self.include_others + excluded_types: Set[str] = {t for t in self.types if is_all} + wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} + concrete_keys: Set[StateKey] = set(self.concrete_types()) + + return (is_all, excluded_types), (wildcard_types, concrete_keys) + + @staticmethod + def _recompose_from_four_parts( + all_part: bool, + minus_wildcards: Set[str], + plus_wildcards: Set[str], + plus_state_keys: Set[StateKey], + ) -> "StateFilter": + """ + Recomposes a state filter from 4 parts. + + See `decompose_into_four_parts` (the other direction of this + correspondence) for descriptions on each of the parts. + """ + + # {state type -> set of state keys OR None for wildcard} + # (The same structure as that of a StateFilter.) + new_types: Dict[str, Optional[Set[str]]] = {} + + # if we start with all, insert the excluded statetypes as empty sets + # to prevent them from being included + if all_part: + new_types.update({state_type: set() for state_type in minus_wildcards}) + + # insert the plus wildcards + new_types.update({state_type: None for state_type in plus_wildcards}) + + # insert the specific state keys + for state_type, state_key in plus_state_keys: + if state_type in new_types: + entry = new_types[state_type] + if entry is not None: + entry.add(state_key) + elif not all_part: + # don't insert if the entire type is already included by + # include_others as this would actually shrink the state allowed + # by this filter. + new_types[state_type] = {state_key} + + return StateFilter.freeze(new_types, include_others=all_part) + + def approx_difference(self, other: "StateFilter") -> "StateFilter": + """ + Returns a state filter which represents `self - other`. + + This is useful for determining what state remains to be pulled out of the + database if we want the state included by `self` but already have the state + included by `other`. + + The returned state filter + - MUST include all state events that are included by this filter (`self`) + unless they are included by `other`; + - MUST NOT include state events not included by this filter (`self`); and + - MAY be an over-approximation: the returned state filter + MAY additionally include some state events from `other`. + + This implementation attempts to return the narrowest such state filter. + In the case that `self` contains wildcards for state types where + `other` contains specific state keys, an approximation must be made: + the returned state filter keeps the wildcard, as state filters are not + able to express 'all state keys except some given examples'. + e.g. + StateFilter(m.room.member -> None (wildcard)) + minus + StateFilter(m.room.member -> {'@wombat:example.org'}) + is approximated as + StateFilter(m.room.member -> None (wildcard)) + """ + + # We first transform self and other into an alternative representation: + # - whether or not they include all events to begin with ('all') + # - if so, which event types are excluded? ('excludes') + # - which entire event types to include ('wildcards') + # - which concrete state keys to include ('concrete state keys') + (self_all, self_excludes), ( + self_wildcards, + self_concrete_keys, + ) = self._decompose_into_four_parts() + (other_all, other_excludes), ( + other_wildcards, + other_concrete_keys, + ) = other._decompose_into_four_parts() + + # Start with an estimate of the difference based on self + new_all = self_all + # Wildcards from the other can be added to the exclusion filter + new_excludes = self_excludes | other_wildcards + # We remove wildcards that appeared as wildcards in the other + new_wildcards = self_wildcards - other_wildcards + # We filter out the concrete state keys that appear in the other + # as wildcards or concrete state keys. + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in self_concrete_keys + if state_type not in other_wildcards + } - other_concrete_keys + + if other_all: + if self_all: + # If self starts with all, then we add as wildcards any + # types which appear in the other's exclusion filter (but + # aren't in the self exclusion filter). This is as the other + # filter will return everything BUT the types in its exclusion, so + # we need to add those excluded types that also match the self + # filter as wildcard types in the new filter. + new_wildcards |= other_excludes.difference(self_excludes) + + # If other is an `include_others` then the difference isn't. + new_all = False + # (We have no need for excludes when we don't start with all, as there + # is nothing to exclude.) + new_excludes = set() + + # We also filter out all state types that aren't in the exclusion + # list of the other. + new_wildcards &= other_excludes + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in new_concrete_keys + if state_type in other_excludes + } + + # Transform our newly-constructed state filter from the alternative + # representation back into the normal StateFilter representation. + return StateFilter._recompose_from_four_parts( + new_all, new_excludes, new_wildcards, new_concrete_keys + ) + class StateGroupStorage: """High level interface to fetching state for event.""" |