diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 747d2df622..fc468ea185 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,11 +15,14 @@
import logging
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from ._base import Cache
from . import background_updates
+from synapse.util.caches import CACHE_SIZE_FACTOR
+
+
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -33,7 +36,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
- max_entries=5000,
+ max_entries=50000 * CACHE_SIZE_FACTOR,
)
super(ClientIpStore, self).__init__(hs)
@@ -45,7 +48,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"],
)
- @defer.inlineCallbacks
+ # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
+ self._batch_row_update = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
+
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user.to_string(), access_token, ip)
@@ -57,34 +67,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
- defer.returnValue(None)
+ return
self.client_ip_last_seen.prefill(key, now)
- # It's safe not to lock here: a) no unique constraint,
- # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
- yield self._simple_upsert(
- "user_ips",
- keyvalues={
- "user_id": user.to_string(),
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": device_id,
- },
- values={
- "last_seen": now,
- },
- desc="insert_client_ip",
- lock=False,
+ self._batch_row_update[key] = (user_agent, device_id, now)
+
+ def _update_client_ips_batch(self):
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
+ return self.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
+ def _update_client_ips_batch_txn(self, txn, to_update):
+ self.database_engine.lock_table(txn, "user_ips")
+
+ for entry in to_update.iteritems():
+ (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
+
+ self._simple_upsert_txn(
+ txn,
+ table="user_ips",
+ keyvalues={
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": device_id,
+ },
+ values={
+ "last_seen": last_seen,
+ },
+ lock=False,
+ )
+
@defer.inlineCallbacks
- def get_last_client_ip_by_device(self, devices):
+ def get_last_client_ip_by_device(self, user_id, device_id):
"""For each device_id listed, give the user_ip it was last seen on
Args:
- devices (iterable[(str, str)]): list of (user_id, device_id) pairs
+ user_id (str)
+ device_id (str): If None fetches all devices for the user
Returns:
defer.Deferred: resolves to a dict, where the keys
@@ -95,6 +119,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
+ user_id, device_id,
retcols=(
"user_id",
"access_token",
@@ -103,23 +128,34 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id",
"last_seen",
),
- devices=devices
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, did, last_seen = self._batch_row_update[key]
+ if not device_id or did == device_id:
+ ret[(user_id, device_id)] = {
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": did,
+ "last_seen": last_seen,
+ }
defer.returnValue(ret)
@classmethod
- def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+ def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
where_clauses = []
bindings = []
- for (user_id, device_id) in devices:
- if device_id is None:
- where_clauses.append("(user_id = ? AND device_id IS NULL)")
- bindings.extend((user_id, ))
- else:
- where_clauses.append("(user_id = ? AND device_id = ?)")
- bindings.extend((user_id, device_id))
+ if device_id is None:
+ where_clauses.append("user_id = ?")
+ bindings.extend((user_id, ))
+ else:
+ where_clauses.append("(user_id = ? AND device_id = ?)")
+ bindings.extend((user_id, device_id))
if not where_clauses:
return []
@@ -147,3 +183,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute(sql, bindings)
return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def get_user_ip_and_agents(self, user):
+ user_id = user.to_string()
+ results = {}
+
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, _, last_seen = self._batch_row_update[key]
+ results[(access_token, ip)] = (user_agent, last_seen)
+
+ rows = yield self._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token", "ip", "user_agent", "last_seen"
+ ],
+ desc="get_user_ip_and_agents",
+ )
+
+ results.update(
+ ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
+ for row in rows
+ )
+ defer.returnValue(list(
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+ ))
|