diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 014ab635b7..3c95e90eca 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,12 +15,13 @@
import logging
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from ._base import Cache
from . import background_updates
-import os
+from synapse.util.caches import CACHE_SIZE_FACTOR
+
logger = logging.getLogger(__name__)
@@ -30,9 +31,6 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
self.client_ip_last_seen = Cache(
@@ -50,10 +48,19 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"],
)
- @defer.inlineCallbacks
- def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
- now = int(self._clock.time_msec())
- key = (user.to_string(), access_token, ip)
+ # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
+ self._batch_row_update = {}
+
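+        # Flush the batched rows to the database on a timer (every 5s here)
+        # rather than on every request, and run a final flush before
+        # shutdown so pending updates are not lost.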
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+        reactor.addSystemEventTrigger(
+            "before", "shutdown", self._update_client_ips_batch
+        )
+
+ def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
+ now=None):
+ if not now:
+ now = int(self._clock.time_msec())
+ key = (user_id, access_token, ip)
try:
last_seen = self.client_ip_last_seen.get(key)
@@ -62,34 +69,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
- defer.returnValue(None)
+ return
self.client_ip_last_seen.prefill(key, now)
- # It's safe not to lock here: a) no unique constraint,
- # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
- yield self._simple_upsert(
- "user_ips",
- keyvalues={
- "user_id": user.to_string(),
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": device_id,
- },
- values={
- "last_seen": now,
- },
- desc="insert_client_ip",
- lock=False,
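+        # Queue the row in memory; it will be written out by the next
+        # periodic batch flush.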
+ self._batch_row_update[key] = (user_agent, device_id, now)
+
+ def _update_client_ips_batch(self):
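+        # Swap out the pending rows so that inserts arriving while the
+        # transaction runs are collected for the next flush.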
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
+ return self.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
+ def _update_client_ips_batch_txn(self, txn, to_update):
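+        # Take one table lock up front for the whole batch, so each
+        # individual upsert below can skip its own locking (lock=False).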
+ self.database_engine.lock_table(txn, "user_ips")
+
+ for entry in to_update.iteritems():
+ (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
+
+ self._simple_upsert_txn(
+ txn,
+ table="user_ips",
+ keyvalues={
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": device_id,
+ },
+ values={
+ "last_seen": last_seen,
+ },
+ lock=False,
+ )
+
@defer.inlineCallbacks
- def get_last_client_ip_by_device(self, devices):
+ def get_last_client_ip_by_device(self, user_id, device_id):
"""For each device_id listed, give the user_ip it was last seen on
Args:
- devices (iterable[(str, str)]): list of (user_id, device_id) pairs
+            user_id (str): The user whose last-seen IPs to fetch.
+            device_id (str|None): If None, fetches all devices for the user.
Returns:
defer.Deferred: resolves to a dict, where the keys
@@ -100,6 +121,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
+ user_id, device_id,
retcols=(
"user_id",
"access_token",
@@ -108,23 +130,34 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id",
"last_seen",
),
- devices=devices
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
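+        # Also overlay any rows that are still sitting in the in-memory batch
+        # and have not yet been flushed to the database.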
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, did, last_seen = self._batch_row_update[key]
+ if not device_id or did == device_id:
+                    ret[(user_id, did)] = {
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": did,
+ "last_seen": last_seen,
+ }
defer.returnValue(ret)
@classmethod
- def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+ def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
where_clauses = []
bindings = []
- for (user_id, device_id) in devices:
- if device_id is None:
- where_clauses.append("(user_id = ? AND device_id IS NULL)")
- bindings.extend((user_id, ))
- else:
- where_clauses.append("(user_id = ? AND device_id = ?)")
- bindings.extend((user_id, device_id))
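+        # A device_id of None means "all of this user's devices", so in that
+        # case filter on user_id alone.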
+ if device_id is None:
+ where_clauses.append("user_id = ?")
+ bindings.extend((user_id, ))
+ else:
+ where_clauses.append("(user_id = ? AND device_id = ?)")
+ bindings.extend((user_id, device_id))
if not where_clauses:
return []
@@ -152,3 +185,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute(sql, bindings)
return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def get_user_ip_and_agents(self, user):
+ user_id = user.to_string()
+ results = {}
+
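+        # Pick up rows still queued in the in-memory batch as well as those
+        # already persisted to the database.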
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, _, last_seen = self._batch_row_update[key]
+ results[(access_token, ip)] = (user_agent, last_seen)
+
+ rows = yield self._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token", "ip", "user_agent", "last_seen"
+ ],
+ desc="get_user_ip_and_agents",
+ )
+
+ results.update(
+ ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
+ for row in rows
+ )
+ defer.returnValue(list(
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+ ))