summary refs log tree commit diff
path: root/synapse/storage/client_ips.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/client_ips.py')
-rw-r--r--synapse/storage/client_ips.py85
1 files changed, 81 insertions, 4 deletions
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index a90990e006..71e5ea112f 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, Cache
+import logging
 
 from twisted.internet import defer
 
+from ._base import Cache
+from . import background_updates
+
+logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
 # times give more inserts into the database even for readonly API hits
@@ -24,8 +28,7 @@ from twisted.internet import defer
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
-class ClientIpStore(SQLBaseStore):
-
+class ClientIpStore(background_updates.BackgroundUpdateStore):
     def __init__(self, hs):
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
@@ -34,8 +37,15 @@ class ClientIpStore(SQLBaseStore):
 
         super(ClientIpStore, self).__init__(hs)
 
+        self.register_background_index_update(
+            "user_ips_device_index",
+            index_name="user_ips_device_id",
+            table="user_ips",
+            columns=["user_id", "device_id", "last_seen"],
+        )
+
     @defer.inlineCallbacks
-    def insert_client_ip(self, user, access_token, ip, user_agent):
+    def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
         now = int(self._clock.time_msec())
         key = (user.to_string(), access_token, ip)
 
@@ -59,6 +69,7 @@ class ClientIpStore(SQLBaseStore):
                 "access_token": access_token,
                 "ip": ip,
                 "user_agent": user_agent,
+                "device_id": device_id,
             },
             values={
                 "last_seen": now,
@@ -66,3 +77,69 @@ class ClientIpStore(SQLBaseStore):
             desc="insert_client_ip",
             lock=False,
         )
+
+    @defer.inlineCallbacks
+    def get_last_client_ip_by_device(self, devices):
+        """For each device_id listed, give the user_ip it was last seen on
+
+        Args:
+            devices (iterable[(str, str)]):  list of (user_id, device_id) pairs
+
+        Returns:
+            defer.Deferred: resolves to a dict, where the keys
+            are (user_id, device_id) tuples. The values are also dicts, with
+            keys giving the column names
+        """
+
+        res = yield self.runInteraction(
+            "get_last_client_ip_by_device",
+            self._get_last_client_ip_by_device_txn,
+            retcols=(
+                "user_id",
+                "access_token",
+                "ip",
+                "user_agent",
+                "device_id",
+                "last_seen",
+            ),
+            devices=devices
+        )
+
+        ret = {(d["user_id"], d["device_id"]): d for d in res}
+        defer.returnValue(ret)
+
+    @classmethod
+    def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+        where_clauses = []
+        bindings = []
+        for (user_id, device_id) in devices:
+            if device_id is None:
+                where_clauses.append("(user_id = ? AND device_id IS NULL)")
+                bindings.extend((user_id, ))
+            else:
+                where_clauses.append("(user_id = ? AND device_id = ?)")
+                bindings.extend((user_id, device_id))
+
+        inner_select = (
+            "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
+            "WHERE %(where)s "
+            "GROUP BY user_id, device_id"
+        ) % {
+            "where": " OR ".join(where_clauses),
+        }
+
+        sql = (
+            "SELECT %(retcols)s FROM user_ips "
+            "JOIN (%(inner_select)s) ips ON"
+            "    user_ips.last_seen = ips.mls AND"
+            "    user_ips.user_id = ips.user_id AND"
+            "    (user_ips.device_id = ips.device_id OR"
+            "         (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
+            "    )"
+        ) % {
+            "retcols": ",".join("user_ips." + c for c in retcols),
+            "inner_select": inner_select,
+        }
+
+        txn.execute(sql, bindings)
+        return cls.cursor_to_dict(txn)