diff --git a/changelog.d/6098.feature b/changelog.d/6098.feature
new file mode 100644
index 0000000000..f3c693c06b
--- /dev/null
+++ b/changelog.d/6098.feature
@@ -0,0 +1 @@
+Add support for pruning old rows in `user_ips` table.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index da31728037..8f801daf35 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -316,6 +316,12 @@ listeners:
#
redaction_retention_period: 7d
+# How long to track users' last seen time and IPs in the database.
+#
+# Defaults to `28d`. Set to `null` to disable clearing out of old rows.
+#
+#user_ips_max_age: 14d
+
## TLS ##
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 3a7a49bc91..9d3f1b5bfc 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -172,6 +172,13 @@ class ServerConfig(Config):
else:
self.redaction_retention_period = None
+ # How long to keep entries in the `user_ips` table.
+ user_ips_max_age = config.get("user_ips_max_age", "28d")
+ if user_ips_max_age is not None:
+ self.user_ips_max_age = self.parse_duration(user_ips_max_age)
+ else:
+ self.user_ips_max_age = None
+
# Options to disable HS
self.hs_disabled = config.get("hs_disabled", False)
self.hs_disabled_message = config.get("hs_disabled_message", "")
@@ -736,6 +743,12 @@ class ServerConfig(Config):
# Defaults to `7d`. Set to `null` to disable.
#
redaction_retention_period: 7d
+
+ # How long to track users' last seen time and IPs in the database.
+ #
+ # Defaults to `28d`. Set to `null` to disable clearing out of old rows.
+ #
+ #user_ips_max_age: 14d
"""
% locals()
)
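
Note: `parse_duration` turns a string such as `28d` into a number of
milliseconds, which is why `self.user_ips_max_age` can later be subtracted
from `time_msec()` directly. A rough standalone sketch of that conversion
(the unit table here is an illustration, not necessarily Synapse's exact
implementation):

    UNITS = {
        "s": 1000,
        "m": 60 * 1000,
        "h": 60 * 60 * 1000,
        "d": 24 * 60 * 60 * 1000,
        "w": 7 * 24 * 60 * 60 * 1000,
        "y": 365 * 24 * 60 * 60 * 1000,
    }

    def parse_duration(value):
        """Parse a "28d"-style string (or a bare number of ms) into ms."""
        if isinstance(value, int):
            return value
        if value[-1] in UNITS:
            return int(value[:-1]) * UNITS[value[-1]]
        return int(value)

    assert parse_duration("28d") == 28 * 24 * 60 * 60 * 1000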
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index edd6b42db3..c53d2a0d40 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -15,6 +15,8 @@
import logging
import threading
+from asyncio import iscoroutine
+from functools import wraps
import six
@@ -173,7 +175,7 @@ def run_as_background_process(desc, func, *args, **kwargs):
Args:
desc (str): a description for this background process type
- func: a function, which may return a Deferred
+ func: a function, which may return a Deferred or a coroutine
args: positional args for func
kwargs: keyword args for func
@@ -197,7 +199,17 @@ def run_as_background_process(desc, func, *args, **kwargs):
_background_processes.setdefault(desc, set()).add(proc)
try:
- yield func(*args, **kwargs)
+ result = func(*args, **kwargs)
+
+ # We probably don't have an ensureDeferred in our call stack to handle
+ # coroutine results, so we need to call ensureDeferred here.
+ #
+ # But we need this check because ensureDeferred doesn't like being
+ # called on immediate values (as opposed to Deferreds or coroutines).
+ if iscoroutine(result):
+ result = defer.ensureDeferred(result)
+
+ return (yield result)
except Exception:
logger.exception("Background process '%s' threw an exception", desc)
finally:
@@ -208,3 +220,20 @@ def run_as_background_process(desc, func, *args, **kwargs):
with PreserveLoggingContext():
return run()
+
+
+def wrap_as_background_process(desc):
+ """Decorator that wraps a function that gets called as a background
+ process.
+
+ Equivalent of calling the function with `run_as_background_process`
+ """
+
+ def wrap_as_background_process_inner(func):
+ @wraps(func)
+ def wrap_as_background_process_inner_2(*args, **kwargs):
+ return run_as_background_process(desc, func, *args, **kwargs)
+
+ return wrap_as_background_process_inner_2
+
+ return wrap_as_background_process_inner
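
Note: the `iscoroutine` guard above exists because Twisted's `ensureDeferred`
accepts only coroutines and Deferreds, and raises `ValueError` on plain
values. A small standalone sketch of the behaviour being relied on:

    from asyncio import iscoroutine

    from twisted.internet import defer

    async def work():
        return 42

    result = work()
    assert iscoroutine(result)

    # ensureDeferred starts the coroutine and wraps it in a Deferred; since
    # nothing is awaited here, the Deferred fires synchronously.
    d = defer.ensureDeferred(result)
    d.addCallback(print)  # prints 42

    # Plain values are rejected, hence the guard before calling it.
    try:
        defer.ensureDeferred(42)
    except ValueError:
        pass

The new `wrap_as_background_process` decorator is then just sugar over
`run_as_background_process`; `client_ips.py` below applies it to both the
batch flush and the new pruning loop.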
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 9522acd972..80b57a948c 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -140,7 +140,7 @@ class BackgroundUpdateStore(SQLBaseStore):
"background_updates",
keyvalues=None,
retcol="1",
- desc="check_background_updates",
+ desc="has_completed_background_updates",
)
if not updates:
self._all_done = True
@@ -148,6 +148,26 @@ class BackgroundUpdateStore(SQLBaseStore):
return False
+ async def has_completed_background_update(self, update_name) -> bool:
+ """Check if the given background update has finished running.
+ """
+
+ if self._all_done:
+ return True
+
+ if update_name in self._background_update_queue:
+ return False
+
+ update_exists = await self._simple_select_one_onecol(
+ "background_updates",
+ keyvalues={"update_name": update_name},
+ retcol="1",
+ desc="has_completed_background_update",
+ allow_none=True,
+ )
+
+ return not update_exists
+
@defer.inlineCallbacks
def do_next_background_update(self, desired_duration_ms):
"""Does some amount of work on the next queued background update
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 8996689744..539584288d 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -19,7 +19,7 @@ from six import iteritems
from twisted.internet import defer
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.util.caches import CACHE_SIZE_FACTOR
from . import background_updates
@@ -42,6 +42,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
super(ClientIpStore, self).__init__(db_conn, hs)
+ self.user_ips_max_age = hs.config.user_ips_max_age
+
self.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
@@ -100,6 +102,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch
)
+ if self.user_ips_max_age:
+ self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
@defer.inlineCallbacks
def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
@@ -319,20 +324,19 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
+ @wrap_as_background_process("update_client_ips")
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
if not self.hs.get_db_pool().running:
return
- def update():
- to_update = self._batch_row_update
- self._batch_row_update = {}
- return self.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
- )
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
- return run_as_background_process("update_client_ips", update)
+ return self.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+ )
def _update_client_ips_batch_txn(self, txn, to_update):
if "user_ips" in self._unsafe_to_upsert_tables or (
@@ -496,3 +500,45 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
yield self._end_background_update("devices_last_seen")
return updated
+
+ @wrap_as_background_process("prune_old_user_ips")
+ async def _prune_old_user_ips(self):
+ """Removes entries in user IPs older than the configured period.
+ """
+
+ if self.user_ips_max_age is None:
+ # Nothing to do
+ return
+
+ if not await self.has_completed_background_update("devices_last_seen"):
+ # Only start pruning if we have finished populating the devices
+ # last seen info.
+ return
+
+ # We do a slightly funky SQL delete to ensure we don't try to delete
+ # too much at once (as the table may be very large from before we
+ # started pruning).
+ #
+ # This works by finding the max last_seen that is less than the given
+ # time, but has no more than N rows before it, deleting all rows with
+ # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+ # returns exactly one row).
+ sql = """
+ DELETE FROM user_ips
+ WHERE last_seen <= (
+ SELECT COALESCE(MAX(last_seen), -1)
+ FROM (
+ SELECT last_seen FROM user_ips
+ WHERE last_seen <= ?
+ ORDER BY last_seen ASC
+ LIMIT 5000
+ ) AS u
+ )
+ """
+
+ timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+ def _prune_old_user_ips_txn(txn):
+ txn.execute(sql, (timestamp,))
+
+ await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
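
Note: the nested SELECT is what bounds each pass: it picks out at most 5000
of the oldest qualifying rows, and the outer DELETE removes everything up to
the newest of those (ties on `last_seen` at the boundary can push a pass
slightly over 5000). The `COALESCE(..., -1)` makes the subquery yield -1 when
nothing qualifies, so the DELETE then matches no rows. A self-contained
demonstration against sqlite3, with the table reduced to the two relevant
columns:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE user_ips (user_id TEXT, last_seen BIGINT)")
    conn.executemany(
        "INSERT INTO user_ips VALUES (?, ?)",
        [("@user:id", ts) for ts in range(20000)],
    )

    # Same query shape as _prune_old_user_ips, with a cutoff of 15000.
    conn.execute(
        """
        DELETE FROM user_ips
        WHERE last_seen <= (
            SELECT COALESCE(MAX(last_seen), -1)
            FROM (
                SELECT last_seen FROM user_ips
                WHERE last_seen <= ?
                ORDER BY last_seen ASC
                LIMIT 5000
            ) AS u
        )
        """,
        (15000,),
    )

    # 15001 rows were older than the cutoff, but only the oldest 5000 are
    # removed in one pass; repeated runs of the looping call clear the rest.
    print(conn.execute("SELECT COUNT(*) FROM user_ips").fetchone()[0])  # 15000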
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 76fe65b59e..afac5dec7f 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -279,6 +279,77 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
+ def test_old_user_ips_pruned(self):
+ # First make sure we have completed all updates.
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ # Insert a user IP
+ user_id = "@user:id"
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
+ )
+
+ # Force persisting to disk
+ self.reactor.advance(200)
+
+ # We should see that in the DB
+ result = self.get_success(
+ self.store._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+ desc="get_user_ip_and_agents",
+ )
+ )
+
+ self.assertEqual(
+ result,
+ [
+ {
+ "access_token": "access_token",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "device_id": "device_id",
+ "last_seen": 0,
+ }
+ ],
+ )
+
+ # Now advance by a couple of months
+ self.reactor.advance(60 * 24 * 60 * 60)
+
+ # We should get no results.
+ result = self.get_success(
+ self.store._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+ desc="get_user_ip_and_agents",
+ )
+ )
+
+ self.assertEqual(result, [])
+
+ # But we should still get the correct values for the device
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, "device_id")
+ )
+
+ r = result[(user_id, "device_id")]
+ self.assertDictContainsSubset(
+ {
+ "user_id": user_id,
+ "device_id": "device_id",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 0,
+ },
+ r,
+ )
+
class ClientIpAuthTestCase(unittest.HomeserverTestCase):