diff --git a/changelog.d/9092.feature b/changelog.d/9092.feature
new file mode 100644
index 0000000000..64843a6a95
--- /dev/null
+++ b/changelog.d/9092.feature
@@ -0,0 +1 @@
+ Add experimental support for handling `/devices` API on worker processes.
diff --git a/docs/workers.md b/docs/workers.md
index 298adf8695..7fb651bba4 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -214,6 +214,7 @@ expressions:
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
+ ^/_matrix/client/(api/v1|r0|unstable)/devices$
^/_matrix/client/(api/v1|r0|unstable)/keys/query$
^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
^/_matrix/client/versions$
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f24c648ac7..cb202bda44 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -107,6 +107,7 @@ from synapse.rest.client.v2_alpha.account_data import (
AccountDataServlet,
RoomAccountDataServlet,
)
+from synapse.rest.client.v2_alpha.devices import DevicesRestServlet
from synapse.rest.client.v2_alpha.keys import (
KeyChangesServlet,
KeyQueryServlet,
@@ -509,6 +510,7 @@ class GenericWorkerServer(HomeServer):
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource)
+ DevicesRestServlet(self).register(resource)
KeyQueryServlet(self).register(resource)
OneTimeKeyServlet(self).register(resource)
KeyChangesServlet(self).register(resource)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index c53c836337..ea1e8fb580 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -407,6 +407,34 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
"_prune_old_user_ips", _prune_old_user_ips_txn
)
+ async def get_last_client_ip_by_device(
+ self, user_id: str, device_id: Optional[str]
+ ) -> Dict[Tuple[str, str], dict]:
+ """For each device_id listed, give the user_ip it was last seen on.
+
+ The result might be slightly out of date as client IPs are inserted in batches.
+
+ Args:
+ user_id: The user to fetch devices for.
+ device_id: If None fetches all devices for the user
+
+ Returns:
+ A dictionary mapping a tuple of (user_id, device_id) to dicts, with
+ keys giving the column names from the devices table.
+ """
+
+ keyvalues = {"user_id": user_id}
+ if device_id is not None:
+ keyvalues["device_id"] = device_id
+
+ res = await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ )
+
+ return {(d["user_id"], d["device_id"]): d for d in res}
+
class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -512,18 +540,9 @@ class ClientIpStore(ClientIpWorkerStore):
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""
+ ret = await super().get_last_client_ip_by_device(user_id, device_id)
- keyvalues = {"user_id": user_id}
- if device_id is not None:
- keyvalues["device_id"] = device_id
-
- res = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues=keyvalues,
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
- )
-
- ret = {(d["user_id"], d["device_id"]): d for d in res}
+ # Update what is retrieved from the database with data which is pending insertion.
for key in self._batch_row_update:
uid, access_token, ip = key
if uid == user_id:
|