summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorAndrew Morgan <1342360+anoadragon453@users.noreply.github.com>2019-11-14 14:22:58 +0000
committerGitHub <noreply@github.com>2019-11-14 14:22:58 +0000
commit28578e75683c7f75c4be71c530a10efe84f2671f (patch)
tree2bf30d155d24db7e05c05bae80fcb41b0e61e393 /synapse
parentCreate configurable ratelimiter for 3pid invites (#11) (diff)
downloadsynapse-28578e75683c7f75c4be71c530a10efe84f2671f.tar.xz
Add a /user/:user_id/info servlet to give user deactivated/expired information (#12)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/rest/client/v2_alpha/user_directory.py89
-rw-r--r--synapse/storage/_base.py9
2 files changed, 92 insertions, 6 deletions
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index e3603f2998..b6f4d8b3f4 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
 
 from ._base import client_patterns
 
@@ -93,5 +94,93 @@ class UserDirectorySearchRestServlet(RestServlet):
         defer.returnValue((200, results))
 
 
+class UserInfoServlet(RestServlet):
+    """
+    GET /user/{user_id}/info HTTP/1.1
+    """
+    PATTERNS = client_patterns(
+        "/user/(?P<user_id>[^/]*)/info$"
+    )
+
+    def __init__(self, hs):
+        super(UserInfoServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
+        self.clock = hs.get_clock()
+        self.transport_layer = hs.get_federation_transport_client()
+        registry = hs.get_federation_registry()
+
+        if not registry.query_handlers.get("user_info"):
+            registry.register_query_handler(
+                "user_info", self._on_federation_query
+            )
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id):
+        # Ensure the user is authenticated
+        yield self.auth.get_user_by_req(request, allow_guest=False)
+
+        user = UserID.from_string(user_id)
+        if not self.hs.is_mine(user):
+            # Attempt to make a federation request to the server that owns this user
+            args = {"user_id": user_id}
+            res = yield self.transport_layer.make_query(
+                user.domain, "user_info", args, retry_on_dns_fail=True,
+            )
+            defer.returnValue((200, res))
+
+        res = yield self._get_user_info(user_id)
+        defer.returnValue((200, res))
+
+    @defer.inlineCallbacks
+    def _on_federation_query(self, args):
+        """Called when a request for user information appears over federation
+
+        Args:
+            args (dict): Dictionary of query arguments provided by the request
+
+        Returns:
+            Deferred[dict]: Deactivation and expiration information for a given user
+        """
+        user_id = args.get("user_id")
+        if not user_id:
+            raise SynapseError(400, "user_id not provided")
+
+        user = UserID.from_string(user_id)
+        if not self.hs.is_mine(user):
+            raise SynapseError(400, "User is not hosted on this homeserver")
+
+        res = yield self._get_user_info(user_id)
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def _get_user_info(self, user_id):
+        """Retrieve information about a given user
+
+        Args:
+            user_id (str): The User ID of a given user on this homeserver
+
+        Returns:
+            Deferred[dict]: Deactivation and expiration information for a given user
+        """
+        # Check whether user is deactivated
+        is_deactivated = yield self.store.get_user_deactivated_status(user_id)
+
+        # Check whether user is expired
+        expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
+        is_expired = (
+            expiration_ts is not None and self.clock.time_msec() >= expiration_ts
+        )
+
+        res = {
+            "expired": is_expired,
+            "deactivated": is_deactivated,
+        }
+        defer.returnValue(res)
+
+
 def register_servlets(hs, http_server):
     UserDirectorySearchRestServlet(hs).register(http_server)
+    UserInfoServlet(hs).register(http_server)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 941c07fce5..537696547c 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -341,14 +341,11 @@ class SQLBaseStore(object):
                 expiration_ts,
             )
 
-        self._simple_insert_txn(
+        self._simple_upsert_txn(
             txn,
             "account_validity",
-            values={
-                "user_id": user_id,
-                "expiration_ts_ms": expiration_ts,
-                "email_sent": False,
-            },
+            keyvalues={"user_id": user_id, },
+            values={"expiration_ts_ms": expiration_ts, "email_sent": False, },
         )
 
     def start_profiling(self):