diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index ab7ef8f950..d37252b6b3 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -46,6 +46,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client.login import LoginResponse
+from synapse.storage import DataStore
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
@@ -61,6 +62,7 @@ from synapse.util import Clock
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
+ from synapse.app.generic_worker import GenericWorkerSlavedStore
from synapse.server import HomeServer
"""
@@ -111,7 +113,9 @@ class ModuleApi:
def __init__(self, hs: "HomeServer", auth_handler):
self._hs = hs
- self._store = hs.get_datastore()
+ # TODO: Fix this type hint once the types for the data stores have been ironed
+ # out.
+ self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index b81d9218ce..1dc7f0ebe3 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -478,6 +478,58 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
return {(d["user_id"], d["device_id"]): d for d in res}
+ async def get_user_ip_and_agents(
+ self, user: UserID, since_ts: int = 0
+ ) -> List[LastConnectionInfo]:
+ """Fetch the IPs and user agents for a user since the given timestamp.
+
+ The result might be slightly out of date as client IPs are inserted in batches.
+
+ Args:
+ user: The user for which to fetch IP addresses and user agents.
+ since_ts: The timestamp after which to fetch IP addresses and user agents,
+ in milliseconds.
+
+ Returns:
+ A list of dictionaries, each containing:
+ * `access_token`: The access token used.
+ * `ip`: The IP address used.
+ * `user_agent`: The last user agent seen for this access token and IP
+ address combination.
+ * `last_seen`: The timestamp at which this access token and IP address
+ combination was last seen, in milliseconds.
+
+ Only the latest user agent for each access token and IP address combination
+ is available.
+ """
+ user_id = user.to_string()
+
+ def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
+ txn.execute(
+ """
+ SELECT access_token, ip, user_agent, last_seen FROM user_ips
+ WHERE last_seen >= ? AND user_id = ?
+ ORDER BY last_seen
+ DESC
+ """,
+ (since_ts, user_id),
+ )
+ return cast(List[Tuple[str, str, str, int]], txn.fetchall())
+
+ rows = await self.db_pool.runInteraction(
+ desc="get_user_ip_and_agents", func=get_recent
+ )
+
+ return [
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for access_token, ip, user_agent, last_seen in rows
+ ]
+
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
@@ -622,49 +674,43 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
async def get_user_ip_and_agents(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
+ """Fetch the IPs and user agents for a user since the given timestamp.
+
+ Args:
+ user: The user for which to fetch IP addresses and user agents.
+ since_ts: The timestamp after which to fetch IP addresses and user agents,
+ in milliseconds.
+
+ Returns:
+ A list of dictionaries, each containing:
+ * `access_token`: The access token used.
+ * `ip`: The IP address used.
+ * `user_agent`: The last user agent seen for this access token and IP
+ address combination.
+ * `last_seen`: The timestamp at which this access token and IP address
+ combination was last seen, in milliseconds.
+
+ Only the latest user agent for each access token and IP address combination
+ is available.
"""
- Fetch IP/User Agent connection since a given timestamp.
- """
- user_id = user.to_string()
- results: Dict[Tuple[str, str], Tuple[str, int]] = {}
+ results: Dict[Tuple[str, str], LastConnectionInfo] = {
+ (connection["access_token"], connection["ip"]): connection
+ for connection in await super().get_user_ip_and_agents(user, since_ts)
+ }
+ # Overlay data that is pending insertion on top of the results from the
+ # database.
+ user_id = user.to_string()
for key in self._batch_row_update:
- (
- uid,
- access_token,
- ip,
- ) = key
+ uid, access_token, ip = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
if last_seen >= since_ts:
- results[(access_token, ip)] = (user_agent, last_seen)
-
- def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
- txn.execute(
- """
- SELECT access_token, ip, user_agent, last_seen FROM user_ips
- WHERE last_seen >= ? AND user_id = ?
- ORDER BY last_seen
- DESC
- """,
- (since_ts, user_id),
- )
- return cast(List[Tuple[str, str, str, int]], txn.fetchall())
-
- rows = await self.db_pool.runInteraction(
- desc="get_user_ip_and_agents", func=get_recent
- )
+ results[(access_token, ip)] = {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
- results.update(
- ((access_token, ip), (user_agent, last_seen))
- for access_token, ip, user_agent, last_seen in rows
- )
- return [
- {
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "last_seen": last_seen,
- }
- for (access_token, ip), (user_agent, last_seen) in results.items()
- ]
+ return list(results.values())
|