summary refs log tree commit diff
path: root/synapse/storage/databases/main/client_ips.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/client_ips.py')
-rw-r--r--synapse/storage/databases/main/client_ips.py95
1 files changed, 53 insertions, 42 deletions
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py

index c5468c7b0d..8a65eb6e16 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py
@@ -407,6 +407,34 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): "_prune_old_user_ips", _prune_old_user_ips_txn ) + async def get_last_client_ip_by_device( + self, user_id: str, device_id: Optional[str] + ) -> Dict[Tuple[str, str], dict]: + """For each device_id listed, give the user_ip it was last seen on. + + The result might be slightly out of date as client IPs are inserted in batches. + + Args: + user_id: The user to fetch devices for. + device_id: If None fetches all devices for the user + + Returns: + A dictionary mapping a tuple of (user_id, device_id) to dicts, with + keys giving the column names from the devices table. + """ + + keyvalues = {"user_id": user_id} + if device_id is not None: + keyvalues["device_id"] = device_id + + res = await self.db_pool.simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ) + + return {(d["user_id"], d["device_id"]): d for d in res} + class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -470,43 +498,35 @@ class ClientIpStore(ClientIpWorkerStore): for entry in to_update.items(): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry - try: - self.db_pool.simple_upsert_txn( + self.db_pool.simple_upsert_txn( + txn, + table="user_ips", + keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip}, + values={ + "user_agent": user_agent, + "device_id": device_id, + "last_seen": last_seen, + }, + lock=False, + ) + + # Technically an access token might not be associated with + # a device so we need to check. + if device_id: + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. + self.db_pool.simple_update_txn( txn, - table="user_ips", - keyvalues={ - "user_id": user_id, - "access_token": access_token, - "ip": ip, - }, - values={ + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues={ "user_agent": user_agent, - "device_id": device_id, "last_seen": last_seen, + "ip": ip, }, - lock=False, ) - # Technically an access token might not be associated with - # a device so we need to check. - if device_id: - # this is always an update rather than an upsert: the row should - # already exist, and if it doesn't, that may be because it has been - # deleted, and we don't want to re-create it. - self.db_pool.simple_update_txn( - txn, - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - updatevalues={ - "user_agent": user_agent, - "last_seen": last_seen, - "ip": ip, - }, - ) - except Exception as e: - # Failed to upsert, log and continue - logger.error("Failed to insert client IP %r: %r", entry, e) - async def get_last_client_ip_by_device( self, user_id: str, device_id: Optional[str] ) -> Dict[Tuple[str, str], dict]: @@ -520,18 +540,9 @@ class ClientIpStore(ClientIpWorkerStore): A dictionary mapping a tuple of (user_id, device_id) to dicts, with keys giving the column names from the devices table. """ + ret = await super().get_last_client_ip_by_device(user_id, device_id) - keyvalues = {"user_id": user_id} - if device_id is not None: - keyvalues["device_id"] = device_id - - res = await self.db_pool.simple_select_list( - table="devices", - keyvalues=keyvalues, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), - ) - - ret = {(d["user_id"], d["device_id"]): d for d in res} + # Update what is retrieved from the database with data which is pending insertion. for key in self._batch_row_update: uid, access_token, ip = key if uid == user_id: