diff options
author | Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> | 2019-11-14 14:22:58 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-11-14 14:22:58 +0000 |
commit | 28578e75683c7f75c4be71c530a10efe84f2671f (patch) | |
tree | 2bf30d155d24db7e05c05bae80fcb41b0e61e393 /synapse | |
parent | Create configurable ratelimiter for 3pid invites (#11) (diff) | |
download | synapse-28578e75683c7f75c4be71c530a10efe84f2671f.tar.xz |
Add a /user/:user_id/info servlet to give user deactivated/expired information (#12)
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/rest/client/v2_alpha/user_directory.py | 89 | ||||
-rw-r--r-- | synapse/storage/_base.py | 9 |
2 files changed, 92 insertions, 6 deletions
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index e3603f2998..b6f4d8b3f4 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -21,6 +21,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import UserID from ._base import client_patterns @@ -93,5 +94,93 @@ class UserDirectorySearchRestServlet(RestServlet): defer.returnValue((200, results)) +class UserInfoServlet(RestServlet): + """ + GET /user/{user_id}/info HTTP/1.1 + """ + PATTERNS = client_patterns( + "/user/(?P<user_id>[^/]*)/info$" + ) + + def __init__(self, hs): + super(UserInfoServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.clock = hs.get_clock() + self.transport_layer = hs.get_federation_transport_client() + registry = hs.get_federation_registry() + + if not registry.query_handlers.get("user_info"): + registry.register_query_handler( + "user_info", self._on_federation_query + ) + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + # Ensure the user is authenticated + yield self.auth.get_user_by_req(request, allow_guest=False) + + user = UserID.from_string(user_id) + if not self.hs.is_mine(user): + # Attempt to make a federation request to the server that owns this user + args = {"user_id": user_id} + res = yield self.transport_layer.make_query( + user.domain, "user_info", args, retry_on_dns_fail=True, + ) + defer.returnValue((200, res)) + + res = yield self._get_user_info(user_id) + defer.returnValue((200, res)) + + @defer.inlineCallbacks + def _on_federation_query(self, args): + """Called when a request for user information appears over federation + + Args: + args (dict): Dictionary of query arguments provided by the request + + Returns: + Deferred[dict]: Deactivation and expiration information for a given user + """ + user_id = args.get("user_id") + if not user_id: + raise SynapseError(400, "user_id not provided") + + user = UserID.from_string(user_id) + if not self.hs.is_mine(user): + raise SynapseError(400, "User is not hosted on this homeserver") + + res = yield self._get_user_info(user_id) + defer.returnValue(res) + + @defer.inlineCallbacks + def _get_user_info(self, user_id): + """Retrieve information about a given user + + Args: + user_id (str): The User ID of a given user on this homeserver + + Returns: + Deferred[dict]: Deactivation and expiration information for a given user + """ + # Check whether user is deactivated + is_deactivated = yield self.store.get_user_deactivated_status(user_id) + + # Check whether user is expired + expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) + is_expired = ( + expiration_ts is not None and self.clock.time_msec() >= expiration_ts + ) + + res = { + "expired": is_expired, + "deactivated": is_deactivated, + } + defer.returnValue(res) + + def register_servlets(hs, http_server): UserDirectorySearchRestServlet(hs).register(http_server) + UserInfoServlet(hs).register(http_server) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 941c07fce5..537696547c 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -341,14 +341,11 @@ class SQLBaseStore(object): expiration_ts, ) - self._simple_insert_txn( + self._simple_upsert_txn( txn, "account_validity", - values={ - "user_id": user_id, - "expiration_ts_ms": expiration_ts, - "email_sent": False, - }, + keyvalues={"user_id": user_id, }, + values={"expiration_ts_ms": expiration_ts, "email_sent": False, }, ) def start_profiling(self): |