summary refs log tree commit diff
path: root/synapse/storage/client_ips.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/client_ips.py')
-rw-r--r--synapse/storage/client_ips.py253
1 files changed, 181 insertions, 72 deletions
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 6db8c54077..067820a5da 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -19,7 +19,7 @@ from six import iteritems
 
 from twisted.internet import defer
 
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.util.caches import CACHE_SIZE_FACTOR
 
 from . import background_updates
@@ -33,14 +33,9 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
-class ClientIpStore(background_updates.BackgroundUpdateStore):
+class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
     def __init__(self, db_conn, hs):
-
-        self.client_ip_last_seen = Cache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
-        )
-
-        super(ClientIpStore, self).__init__(db_conn, hs)
+        super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
 
         self.register_background_index_update(
             "user_ips_device_index",
@@ -85,14 +80,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
         )
 
-        # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update = {}
-
-        self._client_ip_looper = self._clock.looping_call(
-            self._update_client_ips_batch, 5 * 1000
-        )
-        self.hs.get_reactor().addSystemEventTrigger(
-            "before", "shutdown", self._update_client_ips_batch
+        # Update the last seen info in devices.
+        self.register_background_update_handler(
+            "devices_last_seen", self._devices_last_seen_update
         )
 
     @defer.inlineCallbacks
@@ -294,6 +284,110 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         return batch_size
 
     @defer.inlineCallbacks
+    def _devices_last_seen_update(self, progress, batch_size):
+        """Background update to insert last seen info into devices table
+        """
+
+        last_user_id = progress.get("last_user_id", "")
+        last_device_id = progress.get("last_device_id", "")
+
+        def _devices_last_seen_update_txn(txn):
+            # This consists of two queries:
+            #
+            #   1. The sub-query searches for the next N devices and joins
+            #      against user_ips to find the max last_seen associated with
+            #      that device.
+            #   2. The outer query then joins again against user_ips on
+            #      user/device/last_seen. This *should* hopefully only
+            #      return one row, but if it does return more than one then
+            #      we'll just end up updating the same device row multiple
+            #      times, which is fine.
+
+            if self.database_engine.supports_tuple_comparison:
+                where_clause = "(user_id, device_id) > (?, ?)"
+                where_args = [last_user_id, last_device_id]
+            else:
+                # We explicitly do a `user_id >= ? AND (...)` here to ensure
+                # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
+                # makes it hard for query optimiser to tell that it can use the
+                # index on user_id
+                where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
+                where_args = [last_user_id, last_user_id, last_device_id]
+
+            sql = """
+                SELECT
+                    last_seen, ip, user_agent, user_id, device_id
+                FROM (
+                    SELECT
+                        user_id, device_id, MAX(u.last_seen) AS last_seen
+                    FROM devices
+                    INNER JOIN user_ips AS u USING (user_id, device_id)
+                    WHERE %(where_clause)s
+                    GROUP BY user_id, device_id
+                    ORDER BY user_id ASC, device_id ASC
+                    LIMIT ?
+                ) c
+                INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
+            """ % {
+                "where_clause": where_clause
+            }
+            txn.execute(sql, where_args + [batch_size])
+
+            rows = txn.fetchall()
+            if not rows:
+                return 0
+
+            sql = """
+                UPDATE devices
+                SET last_seen = ?, ip = ?, user_agent = ?
+                WHERE user_id = ? AND device_id = ?
+            """
+            txn.execute_batch(sql, rows)
+
+            _, _, _, user_id, device_id = rows[-1]
+            self._background_update_progress_txn(
+                txn,
+                "devices_last_seen",
+                {"last_user_id": user_id, "last_device_id": device_id},
+            )
+
+            return len(rows)
+
+        updated = yield self.runInteraction(
+            "_devices_last_seen_update", _devices_last_seen_update_txn
+        )
+
+        if not updated:
+            yield self._end_background_update("devices_last_seen")
+
+        return updated
+
+
+class ClientIpStore(ClientIpBackgroundUpdateStore):
+    def __init__(self, db_conn, hs):
+
+        self.client_ip_last_seen = Cache(
+            name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+        )
+
+        super(ClientIpStore, self).__init__(db_conn, hs)
+
+        self.user_ips_max_age = hs.config.user_ips_max_age
+
+        # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+        self._batch_row_update = {}
+
+        self._client_ip_looper = self._clock.looping_call(
+            self._update_client_ips_batch, 5 * 1000
+        )
+        self.hs.get_reactor().addSystemEventTrigger(
+            "before", "shutdown", self._update_client_ips_batch
+        )
+
+        if self.user_ips_max_age:
+            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+    @defer.inlineCallbacks
     def insert_client_ip(
         self, user_id, access_token, ip, user_agent, device_id, now=None
     ):
@@ -314,20 +408,19 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
         self._batch_row_update[key] = (user_agent, device_id, now)
 
+    @wrap_as_background_process("update_client_ips")
     def _update_client_ips_batch(self):
 
         # If the DB pool has already terminated, don't try updating
         if not self.hs.get_db_pool().running:
             return
 
-        def update():
-            to_update = self._batch_row_update
-            self._batch_row_update = {}
-            return self.runInteraction(
-                "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
-            )
+        to_update = self._batch_row_update
+        self._batch_row_update = {}
 
-        return run_as_background_process("update_client_ips", update)
+        return self.runInteraction(
+            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+        )
 
     def _update_client_ips_batch_txn(self, txn, to_update):
         if "user_ips" in self._unsafe_to_upsert_tables or (
@@ -354,6 +447,21 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                     },
                     lock=False,
                 )
+
+                # Technically an access token might not be associated with
+                # a device so we need to check.
+                if device_id:
+                    self._simple_upsert_txn(
+                        txn,
+                        table="devices",
+                        keyvalues={"user_id": user_id, "device_id": device_id},
+                        values={
+                            "user_agent": user_agent,
+                            "last_seen": last_seen,
+                            "ip": ip,
+                        },
+                        lock=False,
+                    )
             except Exception as e:
                 # Failed to upsert, log and continue
                 logger.error("Failed to insert client IP %r: %r", entry, e)
@@ -372,19 +480,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             keys giving the column names
         """
 
-        res = yield self.runInteraction(
-            "get_last_client_ip_by_device",
-            self._get_last_client_ip_by_device_txn,
-            user_id,
-            device_id,
-            retcols=(
-                "user_id",
-                "access_token",
-                "ip",
-                "user_agent",
-                "device_id",
-                "last_seen",
-            ),
+        keyvalues = {"user_id": user_id}
+        if device_id is not None:
+            keyvalues["device_id"] = device_id
+
+        res = yield self._simple_select_list(
+            table="devices",
+            keyvalues=keyvalues,
+            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
         )
 
         ret = {(d["user_id"], d["device_id"]): d for d in res}
@@ -403,42 +506,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                     }
         return ret
 
-    @classmethod
-    def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
-        where_clauses = []
-        bindings = []
-        if device_id is None:
-            where_clauses.append("user_id = ?")
-            bindings.extend((user_id,))
-        else:
-            where_clauses.append("(user_id = ? AND device_id = ?)")
-            bindings.extend((user_id, device_id))
-
-        if not where_clauses:
-            return []
-
-        inner_select = (
-            "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
-            "WHERE %(where)s "
-            "GROUP BY user_id, device_id"
-        ) % {"where": " OR ".join(where_clauses)}
-
-        sql = (
-            "SELECT %(retcols)s FROM user_ips "
-            "JOIN (%(inner_select)s) ips ON"
-            "    user_ips.last_seen = ips.mls AND"
-            "    user_ips.user_id = ips.user_id AND"
-            "    (user_ips.device_id = ips.device_id OR"
-            "         (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
-            "    )"
-        ) % {
-            "retcols": ",".join("user_ips." + c for c in retcols),
-            "inner_select": inner_select,
-        }
-
-        txn.execute(sql, bindings)
-        return cls.cursor_to_dict(txn)
-
     @defer.inlineCallbacks
     def get_user_ip_and_agents(self, user):
         user_id = user.to_string()
@@ -470,3 +537,45 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             }
             for (access_token, ip), (user_agent, last_seen) in iteritems(results)
         )
+
+    @wrap_as_background_process("prune_old_user_ips")
+    async def _prune_old_user_ips(self):
+        """Removes entries in user IPs older than the configured period.
+        """
+
+        if self.user_ips_max_age is None:
+            # Nothing to do
+            return
+
+        if not await self.has_completed_background_update("devices_last_seen"):
+            # Only start pruning if we have finished populating the devices
+            # last seen info.
+            return
+
+        # We do a slightly funky SQL delete to ensure we don't try and delete
+        # too much at once (as the table may be very large from before we
+        # started pruning).
+        #
+        # This works by finding the max last_seen that is less than the given
+        # time, but has no more than N rows before it, deleting all rows with
+        # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+        # returns exactly one row).
+        sql = """
+            DELETE FROM user_ips
+            WHERE last_seen <= (
+                SELECT COALESCE(MAX(last_seen), -1)
+                FROM (
+                    SELECT last_seen FROM user_ips
+                    WHERE last_seen <= ?
+                    ORDER BY last_seen ASC
+                    LIMIT 5000
+                ) AS u
+            )
+        """
+
+        timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+        def _prune_old_user_ips_txn(txn):
+            txn.execute(sql, (timestamp,))
+
+        await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)