summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/http/client.py10
-rw-r--r--synapse/storage/databases/main/client_ips.py41
3 files changed, 37 insertions, 16 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f24c648ac7..cb202bda44 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -107,6 +107,7 @@ from synapse.rest.client.v2_alpha.account_data import (
     AccountDataServlet,
     RoomAccountDataServlet,
 )
+from synapse.rest.client.v2_alpha.devices import DevicesRestServlet
 from synapse.rest.client.v2_alpha.keys import (
     KeyChangesServlet,
     KeyQueryServlet,
@@ -509,6 +510,7 @@ class GenericWorkerServer(HomeServer):
                     RegisterRestServlet(self).register(resource)
                     LoginRestServlet(self).register(resource)
                     ThreepidRestServlet(self).register(resource)
+                    DevicesRestServlet(self).register(resource)
                     KeyQueryServlet(self).register(resource)
                     OneTimeKeyServlet(self).register(resource)
                     KeyChangesServlet(self).register(resource)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 5f74ee1149..dc4b81ca60 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -32,7 +32,7 @@ from typing import (
 
 import treq
 from canonicaljson import encode_canonical_json
-from netaddr import IPAddress, IPSet
+from netaddr import AddrFormatError, IPAddress, IPSet
 from prometheus_client import Counter
 from zope.interface import implementer, provider
 
@@ -261,16 +261,16 @@ class BlacklistingAgentWrapper(Agent):
 
         try:
             ip_address = IPAddress(h.hostname)
-
+        except AddrFormatError:
+            # Not an IP
+            pass
+        else:
             if check_against_blacklist(
                 ip_address, self._ip_whitelist, self._ip_blacklist
             ):
                 logger.info("Blocking access to %s due to blacklist" % (ip_address,))
                 e = SynapseError(403, "IP address blocked by IP blacklist entry")
                 return defer.fail(Failure(e))
-        except Exception:
-            # Not an IP
-            pass
 
         return self._agent.request(
             method, uri, headers=headers, bodyProducer=bodyProducer
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index c53c836337..ea1e8fb580 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -407,6 +407,34 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             "_prune_old_user_ips", _prune_old_user_ips_txn
         )
 
+    async def get_last_client_ip_by_device(
+        self, user_id: str, device_id: Optional[str]
+    ) -> Dict[Tuple[str, str], dict]:
+        """For each device_id listed, give the user_ip it was last seen on.
+
+        The result might be slightly out of date as client IPs are inserted in batches.
+
+        Args:
+            user_id: The user to fetch devices for.
+            device_id: If None fetches all devices for the user
+
+        Returns:
+            A dictionary mapping a tuple of (user_id, device_id) to dicts, with
+            keys giving the column names from the devices table.
+        """
+
+        keyvalues = {"user_id": user_id}
+        if device_id is not None:
+            keyvalues["device_id"] = device_id
+
+        res = await self.db_pool.simple_select_list(
+            table="devices",
+            keyvalues=keyvalues,
+            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        )
+
+        return {(d["user_id"], d["device_id"]): d for d in res}
+
 
 class ClientIpStore(ClientIpWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -512,18 +540,9 @@ class ClientIpStore(ClientIpWorkerStore):
             A dictionary mapping a tuple of (user_id, device_id) to dicts, with
             keys giving the column names from the devices table.
         """
+        ret = await super().get_last_client_ip_by_device(user_id, device_id)
 
-        keyvalues = {"user_id": user_id}
-        if device_id is not None:
-            keyvalues["device_id"] = device_id
-
-        res = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues=keyvalues,
-            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
-        )
-
-        ret = {(d["user_id"], d["device_id"]): d for d in res}
+        # Update what is retrieved from the database with data which is pending insertion.
         for key in self._batch_row_update:
             uid, access_token, ip = key
             if uid == user_id: