diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/api/auth.py | 3 | ||||
-rw-r--r-- | synapse/handlers/device.py | 4 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 10 | ||||
-rw-r--r-- | synapse/storage/client_ips.py | 127 | ||||
-rw-r--r-- | tests/storage/test_client_ips.py | 5 |
5 files changed, 101 insertions, 48 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0c297cb022..10f4972369 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -23,7 +23,6 @@ from synapse import event_auth from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes from synapse.types import UserID -from synapse.util import logcontext from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -200,7 +199,7 @@ class Auth(object): default=[""] )[0] if user and access_token and ip_addr: - logcontext.preserve_fn(self.store.insert_client_ip)( + self.store.insert_client_ip( user=user, access_token=access_token, ip=ip_addr, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 982cda3edf..ed60d494ff 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -106,7 +106,7 @@ class DeviceHandler(BaseHandler): device_map = yield self.store.get_devices_by_user(user_id) ips = yield self.store.get_last_client_ip_by_device( - devices=((user_id, device_id) for device_id in device_map.keys()) + user_id, device_id=None ) devices = device_map.values() @@ -133,7 +133,7 @@ class DeviceHandler(BaseHandler): except errors.StoreError: raise errors.NotFoundError ips = yield self.store.get_last_client_ip_by_device( - devices=((user_id, device_id),) + user_id, device_id, ) _update_device_from_client_ips(device, ips) defer.returnValue(device) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f119c5a758..b92472df33 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -304,16 +304,6 @@ class DataStore(RoomMemberStore, RoomStore, ret = yield self.runInteraction("count_users", _count_users) defer.returnValue(ret) - def get_user_ip_and_agents(self, user): - return self._simple_select_list( - table="user_ips", - keyvalues={"user_id": user.to_string()}, - retcols=[ - "access_token", "ip", "user_agent", "last_seen" - ], - desc="get_user_ip_and_agents", - ) - def get_users(self): """Function to reterive a list of users in users table. diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 014ab635b7..88a5eb232f 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -15,7 +15,7 @@ import logging -from twisted.internet import defer +from twisted.internet import defer, reactor from ._base import Cache from . import background_updates @@ -50,7 +50,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id", "last_seen"], ) - @defer.inlineCallbacks + # (user_id, access_token, ip) -> (user_agent, device_id, last_seen) + self._batch_row_update = {} + + self._client_ip_looper = self._clock.looping_call( + self._update_client_ips_batch, 5 * 1000 + ) + reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch) + def insert_client_ip(self, user, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user.to_string(), access_token, ip) @@ -62,34 +69,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: - defer.returnValue(None) + return self.client_ip_last_seen.prefill(key, now) - # It's safe not to lock here: a) no unique constraint, - # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely - yield self._simple_upsert( - "user_ips", - keyvalues={ - "user_id": user.to_string(), - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "device_id": device_id, - }, - values={ - "last_seen": now, - }, - desc="insert_client_ip", - lock=False, + self._batch_row_update[key] = (user_agent, device_id, now) + + def _update_client_ips_batch(self): + to_update = self._batch_row_update + self._batch_row_update = {} + return self.runInteraction( + "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) + def _update_client_ips_batch_txn(self, txn, to_update): + self.database_engine.lock_table(txn, "user_ips") + + for entry in to_update.iteritems(): + (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry + + self._simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": device_id, + }, + values={ + "last_seen": last_seen, + }, + lock=False, + ) + @defer.inlineCallbacks - def get_last_client_ip_by_device(self, devices): + def get_last_client_ip_by_device(self, user_id, device_id): """For each device_id listed, give the user_ip it was last seen on Args: - devices (iterable[(str, str)]): list of (user_id, device_id) pairs + user_id (str) + device_id (str): If None fetches all devices for the user Returns: defer.Deferred: resolves to a dict, where the keys @@ -100,6 +121,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): res = yield self.runInteraction( "get_last_client_ip_by_device", self._get_last_client_ip_by_device_txn, + user_id, device_id, retcols=( "user_id", "access_token", @@ -108,23 +130,34 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "device_id", "last_seen", ), - devices=devices ) ret = {(d["user_id"], d["device_id"]): d for d in res} + for key in self._batch_row_update: + uid, access_token, ip = key + if uid == user_id: + user_agent, did, last_seen = self._batch_row_update[key] + if not device_id or did == device_id: + ret[(user_id, device_id)] = { + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": did, + "last_seen": last_seen, + } defer.returnValue(ret) @classmethod - def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): + def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols): where_clauses = [] bindings = [] - for (user_id, device_id) in devices: - if device_id is None: - where_clauses.append("(user_id = ? AND device_id IS NULL)") - bindings.extend((user_id, )) - else: - where_clauses.append("(user_id = ? AND device_id = ?)") - bindings.extend((user_id, device_id)) + if device_id is None: + where_clauses.append("user_id = ?") + bindings.extend((user_id, )) + else: + where_clauses.append("(user_id = ? AND device_id = ?)") + bindings.extend((user_id, device_id)) if not where_clauses: return [] @@ -152,3 +185,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): txn.execute(sql, bindings) return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def get_user_ip_and_agents(self, user): + user_id = user.to_string() + results = {} + + for key in self._batch_row_update: + uid, access_token, ip = key + if uid == user_id: + user_agent, _, last_seen = self._batch_row_update[key] + results[(access_token, ip)] = (user_agent, last_seen) + + rows = yield self._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", "ip", "user_agent", "last_seen" + ], + desc="get_user_ip_and_agents", + ) + + results.update( + ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) + for row in rows + ) + defer.returnValue(list( + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for (access_token, ip), (user_agent, last_seen) in results.iteritems() + )) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 1f0c0e7c37..03df697575 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -43,10 +43,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): "access_token", "ip", "user_agent", "device_id", ) - # deliberately use an iterable here to make sure that the lookup - # method doesn't iterate it twice - device_list = iter(((user_id, "device_id"),)) - result = yield self.store.get_last_client_ip_by_device(device_list) + result = yield self.store.get_last_client_ip_by_device(user_id, "device_id") r = result[(user_id, "device_id")] self.assertDictContainsSubset( |