diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 982cda3edf..ed60d494ff 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -106,7 +106,7 @@ class DeviceHandler(BaseHandler):
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
- devices=((user_id, device_id) for device_id in device_map.keys())
+ user_id, device_id=None
)
devices = device_map.values()
@@ -133,7 +133,7 @@ class DeviceHandler(BaseHandler):
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
- devices=((user_id, device_id),)
+ user_id, device_id,
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f119c5a758..b92472df33 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -304,16 +304,6 @@ class DataStore(RoomMemberStore, RoomStore,
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
- def get_user_ip_and_agents(self, user):
- return self._simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user.to_string()},
- retcols=[
- "access_token", "ip", "user_agent", "last_seen"
- ],
- desc="get_user_ip_and_agents",
- )
-
def get_users(self):
"""Function to reterive a list of users in users table.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 68955e740b..88a5eb232f 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -50,6 +50,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"],
)
+ # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
@@ -104,11 +105,12 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
@defer.inlineCallbacks
- def get_last_client_ip_by_device(self, devices):
+ def get_last_client_ip_by_device(self, user_id, device_id):
"""For each device_id listed, give the user_ip it was last seen on
Args:
- devices (iterable[(str, str)]): list of (user_id, device_id) pairs
+ user_id (str)
+ device_id (str): If None fetches all devices for the user
Returns:
defer.Deferred: resolves to a dict, where the keys
@@ -119,6 +121,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
+ user_id, device_id,
retcols=(
"user_id",
"access_token",
@@ -127,23 +130,34 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id",
"last_seen",
),
- devices=devices
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, did, last_seen = self._batch_row_update[key]
+ if not device_id or did == device_id:
+ ret[(user_id, device_id)] = {
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": did,
+ "last_seen": last_seen,
+ }
defer.returnValue(ret)
@classmethod
- def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+ def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
where_clauses = []
bindings = []
- for (user_id, device_id) in devices:
- if device_id is None:
- where_clauses.append("(user_id = ? AND device_id IS NULL)")
- bindings.extend((user_id, ))
- else:
- where_clauses.append("(user_id = ? AND device_id = ?)")
- bindings.extend((user_id, device_id))
+ if device_id is None:
+ where_clauses.append("user_id = ?")
+ bindings.extend((user_id, ))
+ else:
+ where_clauses.append("(user_id = ? AND device_id = ?)")
+ bindings.extend((user_id, device_id))
if not where_clauses:
return []
@@ -171,3 +185,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute(sql, bindings)
return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def get_user_ip_and_agents(self, user):
+ user_id = user.to_string()
+ results = {}
+
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, _, last_seen = self._batch_row_update[key]
+ results[(access_token, ip)] = (user_agent, last_seen)
+
+ rows = yield self._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token", "ip", "user_agent", "last_seen"
+ ],
+ desc="get_user_ip_and_agents",
+ )
+
+ results.update(
+ ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
+ for row in rows
+ )
+ defer.returnValue(list(
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+ ))
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 1f0c0e7c37..03df697575 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -43,10 +43,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
"access_token", "ip", "user_agent", "device_id",
)
- # deliberately use an iterable here to make sure that the lookup
- # method doesn't iterate it twice
- device_list = iter(((user_id, "device_id"),))
- result = yield self.store.get_last_client_ip_by_device(device_list)
+ result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
r = result[(user_id, "device_id")]
self.assertDictContainsSubset(
|