summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10833.misc1
-rw-r--r--synapse/module_api/__init__.py82
-rw-r--r--synapse/storage/databases/main/client_ips.py27
-rw-r--r--tests/module_api/test_api.py72
4 files changed, 174 insertions, 8 deletions
diff --git a/changelog.d/10833.misc b/changelog.d/10833.misc
new file mode 100644
index 0000000000..f23c0a1a02
--- /dev/null
+++ b/changelog.d/10833.misc
@@ -0,0 +1 @@
+Extend the ModuleApi to let plug-ins check whether an ID is local and to access IP + User Agent data.
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 3196c2bec6..174e6934a8 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -24,8 +24,10 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Union,
 )
 
+import attr
 import jinja2
 
 from twisted.internet import defer
@@ -46,7 +48,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester
+from synapse.types import (
+    DomainSpecificString,
+    JsonDict,
+    Requester,
+    UserID,
+    UserInfo,
+    create_requester,
+)
 from synapse.util import Clock
 from synapse.util.caches.descriptors import cached
 
@@ -79,6 +88,18 @@ __all__ = [
 logger = logging.getLogger(__name__)
 
 
+@attr.s(auto_attribs=True)
+class UserIpAndAgent:
+    """
+    An IP address and user agent used by a user to connect to this homeserver.
+    """
+
+    ip: str
+    user_agent: str
+    # The time at which this user agent/ip was last seen.
+    last_seen: int
+
+
 class ModuleApi:
     """A proxy object that gets passed to various plugin modules so they
     can register new users etc if necessary.
@@ -700,6 +721,65 @@ class ModuleApi:
             (td for td in (self.custom_template_dir, custom_template_directory) if td),
         )
 
+    def is_mine(self, id: Union[str, DomainSpecificString]) -> bool:
+        """
+        Checks whether an ID (user id, room, ...) comes from this homeserver.
+
+        Args:
+            id: any Matrix id (e.g. user id, room id, ...), either as a raw id,
+                e.g. string "@user:example.com" or as a parsed UserID, RoomID, ...
+        Returns:
+            True if id comes from this homeserver, False otherwise.
+
+        Added in Synapse v1.44.0.
+        """
+        if isinstance(id, DomainSpecificString):
+            return self._hs.is_mine(id)
+        else:
+            return self._hs.is_mine_id(id)
+
+    async def get_user_ip_and_agents(
+        self, user_id: str, since_ts: int = 0
+    ) -> List[UserIpAndAgent]:
+        """
+        Return the list of user IPs and agents for a user.
+
+        Args:
+            user_id: the id of a user, local or remote
+            since_ts: a timestamp in seconds since the epoch,
+                or the epoch itself if not specified.
+        Returns:
+            The list of all UserIpAndAgent that the user has
+            used to connect to this homeserver since `since_ts`.
+            If the user is remote, this list is empty.
+
+        Added in Synapse v1.44.0.
+        """
+        # Don't hit the db if this is not a local user.
+        is_mine = False
+        try:
+            # Let's be defensive against ill-formed strings.
+            if self.is_mine(user_id):
+                is_mine = True
+        except Exception:
+            pass
+
+        if is_mine:
+            raw_data = await self._store.get_user_ip_and_agents(
+                UserID.from_string(user_id), since_ts
+            )
+            # Sanitize some of the data. We don't want to return tokens.
+            return [
+                UserIpAndAgent(
+                    ip=str(data["ip"]),
+                    user_agent=str(data["user_agent"]),
+                    last_seen=int(data["last_seen"]),
+                )
+                for data in raw_data
+            ]
+        else:
+            return []
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 7a98275d92..7e33ae578c 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -555,8 +555,11 @@ class ClientIpStore(ClientIpWorkerStore):
         return ret
 
     async def get_user_ip_and_agents(
-        self, user: UserID
+        self, user: UserID, since_ts: int = 0
     ) -> List[Dict[str, Union[str, int]]]:
+        """
+        Fetch IP/User Agent connection since a given timestamp.
+        """
         user_id = user.to_string()
         results = {}
 
@@ -568,13 +571,23 @@ class ClientIpStore(ClientIpWorkerStore):
             ) = key
             if uid == user_id:
                 user_agent, _, last_seen = self._batch_row_update[key]
-                results[(access_token, ip)] = (user_agent, last_seen)
+                if last_seen >= since_ts:
+                    results[(access_token, ip)] = (user_agent, last_seen)
 
-        rows = await self.db_pool.simple_select_list(
-            table="user_ips",
-            keyvalues={"user_id": user_id},
-            retcols=["access_token", "ip", "user_agent", "last_seen"],
-            desc="get_user_ip_and_agents",
+        def get_recent(txn):
+            txn.execute(
+                """
+                SELECT access_token, ip, user_agent, last_seen FROM user_ips
+                WHERE last_seen >= ? AND user_id = ?
+                ORDER BY last_seen
+                DESC
+                """,
+                (since_ts, user_id),
+            )
+            return txn.fetchall()
+
+        rows = await self.db_pool.runInteraction(
+            desc="get_user_ip_and_agents", func=get_recent
         )
 
         results.update(
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 7dd519cd44..9d38974fba 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -43,6 +43,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.module_api = homeserver.get_module_api()
         self.event_creation_handler = homeserver.get_event_creation_handler()
         self.sync_handler = homeserver.get_sync_handler()
+        self.auth_handler = homeserver.get_auth_handler()
 
     def make_homeserver(self, reactor, clock):
         return self.setup_test_homeserver(
@@ -89,6 +90,77 @@ class ModuleApiTestCase(HomeserverTestCase):
         found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
         self.assertIsNone(found_user)
 
+    def test_get_user_ip_and_agents(self):
+        user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
+
+        # Initially, we should have no ip/agent for our user.
+        info = self.get_success(self.module_api.get_user_ip_and_agents(user_id))
+        self.assertEqual(info, [])
+
+        # Insert a first ip, agent. We should be able to retrieve it.
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip_1", "user_agent_1", "device_1", None
+            )
+        )
+        info = self.get_success(self.module_api.get_user_ip_and_agents(user_id))
+
+        self.assertEqual(len(info), 1)
+        last_seen_1 = info[0].last_seen
+
+        # Insert a second ip, agent at a later date. We should be able to retrieve it.
+        last_seen_2 = last_seen_1 + 10000
+        print("%s => %s" % (last_seen_1, last_seen_2))
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip_2", "user_agent_2", "device_2", last_seen_2
+            )
+        )
+        info = self.get_success(self.module_api.get_user_ip_and_agents(user_id))
+
+        self.assertEqual(len(info), 2)
+        ip_1_seen = False
+        ip_2_seen = False
+
+        for i in info:
+            if i.ip == "ip_1":
+                ip_1_seen = True
+                self.assertEqual(i.user_agent, "user_agent_1")
+                self.assertEqual(i.last_seen, last_seen_1)
+            elif i.ip == "ip_2":
+                ip_2_seen = True
+                self.assertEqual(i.user_agent, "user_agent_2")
+                self.assertEqual(i.last_seen, last_seen_2)
+        self.assertTrue(ip_1_seen)
+        self.assertTrue(ip_2_seen)
+
+        # If we fetch from a midpoint between last_seen_1 and last_seen_2,
+        # we should only find the second ip, agent.
+        info = self.get_success(
+            self.module_api.get_user_ip_and_agents(
+                user_id, (last_seen_1 + last_seen_2) / 2
+            )
+        )
+        self.assertEqual(len(info), 1)
+        self.assertEqual(info[0].ip, "ip_2")
+        self.assertEqual(info[0].user_agent, "user_agent_2")
+        self.assertEqual(info[0].last_seen, last_seen_2)
+
+        # If we fetch from a point later than last_seen_2, we shouldn't
+        # find anything.
+        info = self.get_success(
+            self.module_api.get_user_ip_and_agents(user_id, last_seen_2 + 10000)
+        )
+        self.assertEqual(info, [])
+
+    def test_get_user_ip_and_agents__no_user_found(self):
+        info = self.get_success(
+            self.module_api.get_user_ip_and_agents(
+                "@test_get_user_ip_and_agents_user_nonexistent:example.com"
+            )
+        )
+        self.assertEqual(info, [])
+
     def test_sending_events_into_room(self):
         """Tests that a module can send events into a room"""
         # Mock out create_and_send_nonmember_event to check whether events are being sent