summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6098.feature1
-rw-r--r--docs/sample_config.yaml6
-rw-r--r--synapse/config/server.py13
-rw-r--r--synapse/metrics/background_process_metrics.py33
-rw-r--r--synapse/storage/background_updates.py22
-rw-r--r--synapse/storage/client_ips.py62
-rw-r--r--tests/storage/test_client_ips.py71
7 files changed, 197 insertions, 11 deletions
diff --git a/changelog.d/6098.feature b/changelog.d/6098.feature
new file mode 100644
index 0000000000..f3c693c06b
--- /dev/null
+++ b/changelog.d/6098.feature
@@ -0,0 +1 @@
+Add support for pruning old rows in `user_ips` table.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index da31728037..8f801daf35 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -316,6 +316,12 @@ listeners:
 #
 redaction_retention_period: 7d
 
+# How long to track users' last seen time and IPs in the database.
+#
+# Defaults to `28d`. Set to `null` to disable clearing out of old rows.
+#
+#user_ips_max_age: 14d
+
 
 ## TLS ##
 
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 3a7a49bc91..9d3f1b5bfc 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -172,6 +172,13 @@ class ServerConfig(Config):
         else:
             self.redaction_retention_period = None
 
+        # How long to keep entries in the `users_ips` table.
+        user_ips_max_age = config.get("user_ips_max_age", "28d")
+        if user_ips_max_age is not None:
+            self.user_ips_max_age = self.parse_duration(user_ips_max_age)
+        else:
+            self.user_ips_max_age = None
+
         # Options to disable HS
         self.hs_disabled = config.get("hs_disabled", False)
         self.hs_disabled_message = config.get("hs_disabled_message", "")
@@ -736,6 +743,12 @@ class ServerConfig(Config):
         # Defaults to `7d`. Set to `null` to disable.
         #
         redaction_retention_period: 7d
+
+        # How long to track users' last seen time and IPs in the database.
+        #
+        # Defaults to `28d`. Set to `null` to disable clearing out of old rows.
+        #
+        #user_ips_max_age: 14d
         """
             % locals()
         )
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index edd6b42db3..c53d2a0d40 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -15,6 +15,8 @@
 
 import logging
 import threading
+from asyncio import iscoroutine
+from functools import wraps
 
 import six
 
@@ -173,7 +175,7 @@ def run_as_background_process(desc, func, *args, **kwargs):
 
     Args:
         desc (str): a description for this background process type
-        func: a function, which may return a Deferred
+        func: a function, which may return a Deferred or a coroutine
         args: positional args for func
         kwargs: keyword args for func
 
@@ -197,7 +199,17 @@ def run_as_background_process(desc, func, *args, **kwargs):
                 _background_processes.setdefault(desc, set()).add(proc)
 
             try:
-                yield func(*args, **kwargs)
+                result = func(*args, **kwargs)
+
+                # We probably don't have an ensureDeferred in our call stack to handle
+                # coroutine results, so we need to ensureDeferred here.
+                #
+                # But we need this check because ensureDeferred doesn't like being
+                # called on immediate values (as opposed to Deferreds or coroutines).
+                if iscoroutine(result):
+                    result = defer.ensureDeferred(result)
+
+                return (yield result)
             except Exception:
                 logger.exception("Background process '%s' threw an exception", desc)
             finally:
@@ -208,3 +220,20 @@ def run_as_background_process(desc, func, *args, **kwargs):
 
     with PreserveLoggingContext():
         return run()
+
+
+def wrap_as_background_process(desc):
+    """Decorator that wraps a function that gets called as a background
+    process.
+
+    Equivalent of calling the function with `run_as_background_process`
+    """
+
+    def wrap_as_background_process_inner(func):
+        @wraps(func)
+        def wrap_as_background_process_inner_2(*args, **kwargs):
+            return run_as_background_process(desc, func, *args, **kwargs)
+
+        return wrap_as_background_process_inner_2
+
+    return wrap_as_background_process_inner
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 9522acd972..80b57a948c 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -140,7 +140,7 @@ class BackgroundUpdateStore(SQLBaseStore):
             "background_updates",
             keyvalues=None,
             retcol="1",
-            desc="check_background_updates",
+            desc="has_completed_background_updates",
         )
         if not updates:
             self._all_done = True
@@ -148,6 +148,26 @@ class BackgroundUpdateStore(SQLBaseStore):
 
         return False
 
+    async def has_completed_background_update(self, update_name) -> bool:
+        """Check if the given background update has finished running.
+        """
+
+        if self._all_done:
+            return True
+
+        if update_name in self._background_update_queue:
+            return False
+
+        update_exists = await self._simple_select_one_onecol(
+            "background_updates",
+            keyvalues={"update_name": update_name},
+            retcol="1",
+            desc="has_completed_background_update",
+            allow_none=True,
+        )
+
+        return not update_exists
+
     @defer.inlineCallbacks
     def do_next_background_update(self, desired_duration_ms):
         """Does some amount of work on the next queued background update
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 8996689744..539584288d 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -19,7 +19,7 @@ from six import iteritems
 
 from twisted.internet import defer
 
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.util.caches import CACHE_SIZE_FACTOR
 
 from . import background_updates
@@ -42,6 +42,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
         super(ClientIpStore, self).__init__(db_conn, hs)
 
+        self.user_ips_max_age = hs.config.user_ips_max_age
+
         self.register_background_index_update(
             "user_ips_device_index",
             index_name="user_ips_device_id",
@@ -100,6 +102,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             "before", "shutdown", self._update_client_ips_batch
         )
 
+        if self.user_ips_max_age:
+            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
     @defer.inlineCallbacks
     def _remove_user_ip_nonunique(self, progress, batch_size):
         def f(conn):
@@ -319,20 +324,19 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
         self._batch_row_update[key] = (user_agent, device_id, now)
 
+    @wrap_as_background_process("update_client_ips")
     def _update_client_ips_batch(self):
 
         # If the DB pool has already terminated, don't try updating
         if not self.hs.get_db_pool().running:
             return
 
-        def update():
-            to_update = self._batch_row_update
-            self._batch_row_update = {}
-            return self.runInteraction(
-                "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
-            )
+        to_update = self._batch_row_update
+        self._batch_row_update = {}
 
-        return run_as_background_process("update_client_ips", update)
+        return self.runInteraction(
+            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+        )
 
     def _update_client_ips_batch_txn(self, txn, to_update):
         if "user_ips" in self._unsafe_to_upsert_tables or (
@@ -496,3 +500,45 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             yield self._end_background_update("devices_last_seen")
 
         return updated
+
+    @wrap_as_background_process("prune_old_user_ips")
+    async def _prune_old_user_ips(self):
+        """Removes entries in user IPs older than the configured period.
+        """
+
+        if self.user_ips_max_age is None:
+            # Nothing to do
+            return
+
+        if not await self.has_completed_background_update("devices_last_seen"):
+            # Only start pruning if we have finished populating the devices
+            # last seen info.
+            return
+
+        # We do a slightly funky SQL delete to ensure we don't try and delete
+        # too much at once (as the table may be very large from before we
+        # started pruning).
+        #
+        # This works by finding the max last_seen that is less than the given
+        # time, but has no more than N rows before it, deleting all rows with
+        # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+        # returns exactly one row).
+        sql = """
+            DELETE FROM user_ips
+            WHERE last_seen <= (
+                SELECT COALESCE(MAX(last_seen), -1)
+                FROM (
+                    SELECT last_seen FROM user_ips
+                    WHERE last_seen <= ?
+                    ORDER BY last_seen ASC
+                    LIMIT 5000
+                ) AS u
+            )
+        """
+
+        timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+        def _prune_old_user_ips_txn(txn):
+            txn.execute(sql, (timestamp,))
+
+        await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 76fe65b59e..afac5dec7f 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -279,6 +279,77 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             r,
         )
 
+    def test_old_user_ips_pruned(self):
+        # First make sure we have completed all updates.
+        while not self.get_success(self.store.has_completed_background_updates()):
+            self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+        # Insert a user IP
+        user_id = "@user:id"
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", "device_id"
+            )
+        )
+
+        # Force persisting to disk
+        self.reactor.advance(200)
+
+        # We should see that in the DB
+        result = self.get_success(
+            self.store._simple_select_list(
+                table="user_ips",
+                keyvalues={"user_id": user_id},
+                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+                desc="get_user_ip_and_agents",
+            )
+        )
+
+        self.assertEqual(
+            result,
+            [
+                {
+                    "access_token": "access_token",
+                    "ip": "ip",
+                    "user_agent": "user_agent",
+                    "device_id": "device_id",
+                    "last_seen": 0,
+                }
+            ],
+        )
+
+        # Now advance by a couple of months
+        self.reactor.advance(60 * 24 * 60 * 60)
+
+        # We should get no results.
+        result = self.get_success(
+            self.store._simple_select_list(
+                table="user_ips",
+                keyvalues={"user_id": user_id},
+                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+                desc="get_user_ip_and_agents",
+            )
+        )
+
+        self.assertEqual(result, [])
+
+        # But we should still get the correct values for the device
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(user_id, "device_id")
+        )
+
+        r = result[(user_id, "device_id")]
+        self.assertDictContainsSubset(
+            {
+                "user_id": user_id,
+                "device_id": "device_id",
+                "ip": "ip",
+                "user_agent": "user_agent",
+                "last_seen": 0,
+            },
+            r,
+        )
+
 
 class ClientIpAuthTestCase(unittest.HomeserverTestCase):