summary refs log tree commit diff
path: root/synapse/storage/client_ips.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2017-06-27 14:46:12 +0100
committerErik Johnston <erik@matrix.org>2017-06-27 14:46:12 +0100
commita0a561ae852c6a031f8611859d95e2ad563770ca (patch)
treec6e4bd6bd58526f2fcf21d96965c977aacee4d73 /synapse/storage/client_ips.py
parentBatch upsert user ips (diff)
downloadsynapse-a0a561ae852c6a031f8611859d95e2ad563770ca.tar.xz
Fix up client ips to read from pending data
Diffstat (limited to 'synapse/storage/client_ips.py')
-rw-r--r--synapse/storage/client_ips.py70
1 files changed, 59 insertions, 11 deletions
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 68955e740b..88a5eb232f 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -50,6 +50,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             columns=["user_id", "device_id", "last_seen"],
         )
 
+        # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
         self._batch_row_update = {}
 
         self._client_ip_looper = self._clock.looping_call(
@@ -104,11 +105,12 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             )
 
     @defer.inlineCallbacks
-    def get_last_client_ip_by_device(self, devices):
+    def get_last_client_ip_by_device(self, user_id, device_id):
         """For each device_id listed, give the user_ip it was last seen on
 
         Args:
-            devices (iterable[(str, str)]):  list of (user_id, device_id) pairs
+            user_id (str)
+            device_id (str): If None fetches all devices for the user
 
         Returns:
             defer.Deferred: resolves to a dict, where the keys
@@ -119,6 +121,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         res = yield self.runInteraction(
             "get_last_client_ip_by_device",
             self._get_last_client_ip_by_device_txn,
+            user_id, device_id,
             retcols=(
                 "user_id",
                 "access_token",
@@ -127,23 +130,34 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 "device_id",
                 "last_seen",
             ),
-            devices=devices
         )
 
         ret = {(d["user_id"], d["device_id"]): d for d in res}
+        for key in self._batch_row_update:
+            uid, access_token, ip = key
+            if uid == user_id:
+                user_agent, did, last_seen = self._batch_row_update[key]
+                if not device_id or did == device_id:
+                    ret[(user_id, device_id)] = {
+                        "user_id": user_id,
+                        "access_token": access_token,
+                        "ip": ip,
+                        "user_agent": user_agent,
+                        "device_id": did,
+                        "last_seen": last_seen,
+                    }
         defer.returnValue(ret)
 
     @classmethod
-    def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+    def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
         where_clauses = []
         bindings = []
-        for (user_id, device_id) in devices:
-            if device_id is None:
-                where_clauses.append("(user_id = ? AND device_id IS NULL)")
-                bindings.extend((user_id, ))
-            else:
-                where_clauses.append("(user_id = ? AND device_id = ?)")
-                bindings.extend((user_id, device_id))
+        if device_id is None:
+            where_clauses.append("user_id = ?")
+            bindings.extend((user_id, ))
+        else:
+            where_clauses.append("(user_id = ? AND device_id = ?)")
+            bindings.extend((user_id, device_id))
 
         if not where_clauses:
             return []
@@ -171,3 +185,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
         txn.execute(sql, bindings)
         return cls.cursor_to_dict(txn)
+
+    @defer.inlineCallbacks
+    def get_user_ip_and_agents(self, user):
+        user_id = user.to_string()
+        results = {}
+
+        for key in self._batch_row_update:
+            uid, access_token, ip = key
+            if uid == user_id:
+                user_agent, _, last_seen = self._batch_row_update[key]
+                results[(access_token, ip)] = (user_agent, last_seen)
+
+        rows = yield self._simple_select_list(
+            table="user_ips",
+            keyvalues={"user_id": user_id},
+            retcols=[
+                "access_token", "ip", "user_agent", "last_seen"
+            ],
+            desc="get_user_ip_and_agents",
+        )
+
+        results.update(
+            ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
+            for row in rows
+        )
+        defer.returnValue(list(
+            {
+                "access_token": access_token,
+                "ip": ip,
+                "user_agent": user_agent,
+                "last_seen": last_seen,
+            }
+            for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+        ))