summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/module_api/__init__.py6
-rw-r--r--synapse/storage/databases/main/client_ips.py124
2 files changed, 90 insertions, 40 deletions
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index ab7ef8f950..d37252b6b3 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -46,6 +46,7 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.client.login import LoginResponse
+from synapse.storage import DataStore
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
@@ -61,6 +62,7 @@ from synapse.util import Clock
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
+    from synapse.app.generic_worker import GenericWorkerSlavedStore
     from synapse.server import HomeServer
 
 """
@@ -111,7 +113,9 @@ class ModuleApi:
     def __init__(self, hs: "HomeServer", auth_handler):
         self._hs = hs
 
-        self._store = hs.get_datastore()
+        # TODO: Fix this type hint once the types for the data stores have been ironed
+        #       out.
+        self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore()
         self._auth = hs.get_auth()
         self._auth_handler = auth_handler
         self._server_name = hs.hostname
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index b81d9218ce..1dc7f0ebe3 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -478,6 +478,58 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
         return {(d["user_id"], d["device_id"]): d for d in res}
 
+    async def get_user_ip_and_agents(
+        self, user: UserID, since_ts: int = 0
+    ) -> List[LastConnectionInfo]:
+        """Fetch the IPs and user agents for a user since the given timestamp.
+
+        The result might be slightly out of date as client IPs are inserted in batches.
+
+        Args:
+            user: The user for which to fetch IP addresses and user agents.
+            since_ts: The timestamp after which to fetch IP addresses and user agents,
+                in milliseconds.
+
+        Returns:
+            A list of dictionaries, each containing:
+             * `access_token`: The access token used.
+             * `ip`: The IP address used.
+             * `user_agent`: The last user agent seen for this access token and IP
+               address combination.
+             * `last_seen`: The timestamp at which this access token and IP address
+               combination was last seen, in milliseconds.
+
+            Only the latest user agent for each access token and IP address combination
+            is available.
+        """
+        user_id = user.to_string()
+
+        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
+            txn.execute(
+                """
+                SELECT access_token, ip, user_agent, last_seen FROM user_ips
+                WHERE last_seen >= ? AND user_id = ?
+                ORDER BY last_seen
+                DESC
+                """,
+                (since_ts, user_id),
+            )
+            return cast(List[Tuple[str, str, str, int]], txn.fetchall())
+
+        rows = await self.db_pool.runInteraction(
+            desc="get_user_ip_and_agents", func=get_recent
+        )
+
+        return [
+            {
+                "access_token": access_token,
+                "ip": ip,
+                "user_agent": user_agent,
+                "last_seen": last_seen,
+            }
+            for access_token, ip, user_agent, last_seen in rows
+        ]
+
 
 class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
     def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
@@ -622,49 +674,43 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
     async def get_user_ip_and_agents(
         self, user: UserID, since_ts: int = 0
     ) -> List[LastConnectionInfo]:
+        """Fetch the IPs and user agents for a user since the given timestamp.
+
+        Args:
+            user: The user for which to fetch IP addresses and user agents.
+            since_ts: The timestamp after which to fetch IP addresses and user agents,
+                in milliseconds.
+
+        Returns:
+            A list of dictionaries, each containing:
+             * `access_token`: The access token used.
+             * `ip`: The IP address used.
+             * `user_agent`: The last user agent seen for this access token and IP
+               address combination.
+             * `last_seen`: The timestamp at which this access token and IP address
+               combination was last seen, in milliseconds.
+
+            Only the latest user agent for each access token and IP address combination
+            is available.
         """
-        Fetch IP/User Agent connection since a given timestamp.
-        """
-        user_id = user.to_string()
-        results: Dict[Tuple[str, str], Tuple[str, int]] = {}
+        results: Dict[Tuple[str, str], LastConnectionInfo] = {
+            (connection["access_token"], connection["ip"]): connection
+            for connection in await super().get_user_ip_and_agents(user, since_ts)
+        }
 
+        # Overlay data that is pending insertion on top of the results from the
+        # database.
+        user_id = user.to_string()
         for key in self._batch_row_update:
-            (
-                uid,
-                access_token,
-                ip,
-            ) = key
+            uid, access_token, ip = key
             if uid == user_id:
                 user_agent, _, last_seen = self._batch_row_update[key]
                 if last_seen >= since_ts:
-                    results[(access_token, ip)] = (user_agent, last_seen)
-
-        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
-            txn.execute(
-                """
-                SELECT access_token, ip, user_agent, last_seen FROM user_ips
-                WHERE last_seen >= ? AND user_id = ?
-                ORDER BY last_seen
-                DESC
-                """,
-                (since_ts, user_id),
-            )
-            return cast(List[Tuple[str, str, str, int]], txn.fetchall())
-
-        rows = await self.db_pool.runInteraction(
-            desc="get_user_ip_and_agents", func=get_recent
-        )
+                    results[(access_token, ip)] = {
+                        "access_token": access_token,
+                        "ip": ip,
+                        "user_agent": user_agent,
+                        "last_seen": last_seen,
+                    }
 
-        results.update(
-            ((access_token, ip), (user_agent, last_seen))
-            for access_token, ip, user_agent, last_seen in rows
-        )
-        return [
-            {
-                "access_token": access_token,
-                "ip": ip,
-                "user_agent": user_agent,
-                "last_seen": last_seen,
-            }
-            for (access_token, ip), (user_agent, last_seen) in results.items()
-        ]
+        return list(results.values())