author     Olivier Wilkinson (reivilibre) <oliverw@matrix.org>  2022-03-18 14:24:14 +0000
committer  Olivier Wilkinson (reivilibre) <oliverw@matrix.org>  2022-03-18 14:24:14 +0000
commit     023036838c02dd79bcfce2e8d29220af13378148
tree       9ce2ebb063dca2f1c5f92fe807db785b7fdd8c69
parent     Make the background worker handle USER_IP replication commands
download   synapse-github/rei/update_client_ips_bgw_de1.tar.xz

-rw-r--r--  synapse/replication/slave/storage/client_ips.py  | 166
-rw-r--r--  synapse/storage/databases/main/client_ips.py     | 116
2 files changed, 159 insertions(+), 123 deletions(-)
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 14706a0817..4bfeaeb54b 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -11,27 +11,173 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
+import logging
+from typing import TYPE_CHECKING, Dict, Mapping, Optional, Tuple
 
-from typing import TYPE_CHECKING
-
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.lrucache import LruCache
 
-from ._base import BaseSlavedStore
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
+
+class AbstractClientIpStrategy(abc.ABC):
+    """
+    Abstract interface for the operations that a store should be able to provide
+    for dealing with client IPs.
+
+    See `DatabaseWritingClientIpStrategy` (the single writer)
+    and `ReplicationStreamingClientIpStrategy` (the
+    """
+
+    async def insert_client_ip(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: Optional[str],
+    ) -> None:
+        """
+        Record that the given user used this IP with the given access token
+        and user agent.
+
+        Implementations may batch and rate-limit the underlying writes.
+        """
+        ...
+
+
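+# A minimal sketch (assumed wiring, not established by this diff) of how a
+# worker might choose and use a strategy:
+#
+#     if hs.config.worker.run_background_tasks:
+#         strategy: AbstractClientIpStrategy = DatabaseWritingClientIpStrategy(
+#             db_pool, hs
+#         )
+#     else:
+#         strategy = ReplicationStreamingClientIpStrategy(hs)
+#
+#     await strategy.insert_client_ip(user_id, access_token, ip, user_agent, device_id)
+
+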
+class DatabaseWritingClientIpStrategy(AbstractClientIpStrategy):
+    """
+    Strategy for writing client IPs by direct database access.
+    This is intended to be used on a single designated Synapse worker
+    (the background worker).
+    """
 
-class SlavedClientIpStore(BaseSlavedStore):
     def __init__(
         self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
+        db_pool: DatabasePool,
         hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
+    ) -> None:
+        assert (
+            hs.config.worker.run_background_tasks
+        ), "This worker is not designated to update client IPs"
+
+        self._clock = hs.get_clock()
+        self._store = hs.get_datastores().main
+        self._db_pool = db_pool
+
+        # This is the designated worker that can write to the client IP
+        # tables.
+
+        # (user_id, access_token, ip,) -> last_seen
+        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
+            cache_name="client_ip_last_seen", max_size=50000
+        )
+
+        # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+        self._batch_row_update: Dict[
+            Tuple[str, str, str], Tuple[str, Optional[str], int]
+        ] = {}
+
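+        # Flush the batched rows every 5 seconds, and flush once more at
+        # shutdown so buffered updates are not lost.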
+        self._client_ip_looper = self._clock.looping_call(
+            self._update_client_ips_batch, 5 * 1000
+        )
+        hs.get_reactor().addSystemEventTrigger(
+            "before", "shutdown", self._update_client_ips_batch
+        )
+
+    async def insert_client_ip(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: Optional[str],
+        now: Optional[int] = None,
+    ) -> None:
+        if not now:
+            now = int(self._clock.time_msec())
+        key = (user_id, access_token, ip)
+
+        try:
+            last_seen = self.client_ip_last_seen.get(key)
+        except KeyError:
+            last_seen = None
+        await self._store.populate_monthly_active_users(user_id)
+        # Rate-limited inserts: skip the write if this (user, access_token, IP)
+        # combination was seen within the last LAST_SEEN_GRANULARITY ms.
+        if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
+            return
+
+        self.client_ip_last_seen.set(key, now)
+
+        self._batch_row_update[key] = (user_agent, device_id, now)
+
+    @wrap_as_background_process("update_client_ips")
+    async def _update_client_ips_batch(self) -> None:
+        # If the DB pool has already terminated, don't try updating
+        if not self._db_pool.is_running():
+            return
+
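+        # Swap in a fresh batch: rows that arrive while this write is in
+        # flight will be picked up by the next run.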
+        to_update = self._batch_row_update
+        self._batch_row_update = {}
+
+        await self._db_pool.runInteraction(
+            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+        )
+
+    def _update_client_ips_batch_txn(
+        self,
+        txn: LoggingTransaction,
+        to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
+    ) -> None:
+        db_pool = self._db_pool
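+        # If we cannot rely on a native upsert for user_ips (either the engine
+        # lacks one, or the background update that deduplicates the table has
+        # not yet finished), lock the table so the emulated upsert cannot race.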
+        if "user_ips" in db_pool._unsafe_to_upsert_tables or (
+            not db_pool.engine.can_native_upsert
+        ):
+            db_pool.engine.lock_table(txn, "user_ips")
+
+        for entry in to_update.items():
+            (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
+
+            db_pool.simple_upsert_txn(
+                txn,
+                table="user_ips",
+                keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
+                values={
+                    "user_agent": user_agent,
+                    "device_id": device_id,
+                    "last_seen": last_seen,
+                },
+                lock=False,
+            )
+
+            # Technically an access token might not be associated with
+            # a device so we need to check.
+            if device_id:
+                # this is always an update rather than an upsert: the row should
+                # already exist, and if it doesn't, that may be because it has been
+                # deleted, and we don't want to re-create it.
+                db_pool.simple_update_txn(
+                    txn,
+                    table="devices",
+                    keyvalues={"user_id": user_id, "device_id": device_id},
+                    updatevalues={
+                        "user_agent": user_agent,
+                        "last_seen": last_seen,
+                        "ip": ip,
+                    },
+                )
+
+
+class ReplicationStreamingClientIpStrategy(AbstractClientIpStrategy):
+    """
+    Strategy for writing client IPs by streaming them over replication to
+    a designated writer worker.
+    """
+
+    def __init__(self, hs: "HomeServer") -> None:
+        self.hs = hs
+        self._clock = hs.get_clock()
 
         self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
             cache_name="client_ip_last_seen", max_size=50000
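
The hunk above is cut off by the diff context, so the body of
ReplicationStreamingClientIpStrategy.insert_client_ip is not visible. A hedged
sketch of what it presumably looks like, modelled on the pre-refactor
SlavedClientIpStore (the send_user_ip replication helper and its exact
argument order are assumptions):

    async def insert_client_ip(
        self,
        user_id: str,
        access_token: str,
        ip: str,
        user_agent: str,
        device_id: Optional[str],
    ) -> None:
        now = int(self._clock.time_msec())
        key = (user_id, access_token, ip)
        try:
            last_seen = self.client_ip_last_seen.get(key)
        except KeyError:
            last_seen = None
        # Rate-limit: stream at most one update per key per LAST_SEEN_GRANULARITY.
        if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
            return
        self.client_ip_last_seen.set(key, now)
        # Hand the row to the designated writer as a USER_IP replication command.
        self.hs.get_tcp_replication().send_user_ip(
            user_id, access_token, ip, user_agent, device_id, now
        )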
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 71867da01e..f75f813f53 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 from typing_extensions import TypedDict
 
@@ -29,7 +29,6 @@ from synapse.storage.databases.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
 from synapse.types import JsonDict, UserID
-from synapse.util.caches.lrucache import LruCache
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -416,25 +415,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
 
         if self._update_on_this_worker:
-            # This is the designated worker that can write to the client IP
-            # tables.
-
-            # (user_id, access_token, ip,) -> last_seen
-            self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
-                cache_name="client_ip_last_seen", max_size=50000
-            )
-
-            # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-            self._batch_row_update: Dict[
-                Tuple[str, str, str], Tuple[str, Optional[str], int]
-            ] = {}
-
-            self._client_ip_looper = self._clock.looping_call(
-                self._update_client_ips_batch, 5 * 1000
-            )
-            self.hs.get_reactor().addSystemEventTrigger(
-                "before", "shutdown", self._update_client_ips_batch
-            )
+            ...  # TODO
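+            # (Presumably this will become a DatabaseWritingClientIpStrategy
+            # once the refactor is complete.)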
 
     @wrap_as_background_process("prune_old_user_ips")
     async def _prune_old_user_ips(self) -> None:
@@ -564,98 +545,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
             for access_token, ip, user_agent, last_seen in rows
         ]
 
-    async def insert_client_ip(
-        self,
-        user_id: str,
-        access_token: str,
-        ip: str,
-        user_agent: str,
-        device_id: Optional[str],
-        now: Optional[int] = None,
-    ) -> None:
-        assert (
-            self._update_on_this_worker
-        ), "This worker is not designated to update client IPs"
-
-        if not now:
-            now = int(self._clock.time_msec())
-        key = (user_id, access_token, ip)
-
-        try:
-            last_seen = self.client_ip_last_seen.get(key)
-        except KeyError:
-            last_seen = None
-        await self.populate_monthly_active_users(user_id)
-        # Rate-limited inserts
-        if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
-            return
-
-        self.client_ip_last_seen.set(key, now)
-
-        self._batch_row_update[key] = (user_agent, device_id, now)
-
-    @wrap_as_background_process("update_client_ips")
-    async def _update_client_ips_batch(self) -> None:
-        assert (
-            self._update_on_this_worker
-        ), "This worker is not designated to update client IPs"
-
-        # If the DB pool has already terminated, don't try updating
-        if not self.db_pool.is_running():
-            return
-
-        to_update = self._batch_row_update
-        self._batch_row_update = {}
-
-        await self.db_pool.runInteraction(
-            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
-        )
-
-    def _update_client_ips_batch_txn(
-        self,
-        txn: LoggingTransaction,
-        to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
-    ) -> None:
-        assert (
-            self._update_on_this_worker
-        ), "This worker is not designated to update client IPs"
-
-        if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
-            not self.database_engine.can_native_upsert
-        ):
-            self.database_engine.lock_table(txn, "user_ips")
-
-        for entry in to_update.items():
-            (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
-
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="user_ips",
-                keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
-                values={
-                    "user_agent": user_agent,
-                    "device_id": device_id,
-                    "last_seen": last_seen,
-                },
-                lock=False,
-            )
-
-            # Technically an access token might not be associated with
-            # a device so we need to check.
-            if device_id:
-                # this is always an update rather than an upsert: the row should
-                # already exist, and if it doesn't, that may be because it has been
-                # deleted, and we don't want to re-create it.
-                self.db_pool.simple_update_txn(
-                    txn,
-                    table="devices",
-                    keyvalues={"user_id": user_id, "device_id": device_id},
-                    updatevalues={
-                        "user_agent": user_agent,
-                        "last_seen": last_seen,
-                        "ip": ip,
-                    },
-                )
+    # TODO here
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]