summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@jki.re>2016-10-11 11:20:54 +0100
committerGitHub <noreply@github.com>2016-10-11 11:20:54 +0100
commita2f2516199dc0c4bb94b26a870da0d6531d57723 (patch)
tree3f4ad96f70e2ad39e8881d9f65e0380b547d3582
parentMerge pull request #1150 from Rugvip/state_key (diff)
parentrest/client/v1/register: use the correct requester in createUser (diff)
downloadsynapse-a2f2516199dc0c4bb94b26a870da0d6531d57723.tar.xz
Merge pull request #1157 from Rugvip/nolimit
Remove rate limiting from app service senders and fix get_or_create_user requester
-rw-r--r--synapse/api/auth.py7
-rw-r--r--synapse/handlers/_base.py8
-rw-r--r--synapse/handlers/appservice.py20
-rw-r--r--synapse/handlers/directory.py11
-rw-r--r--synapse/handlers/profile.py8
-rw-r--r--synapse/handlers/register.py11
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/rest/client/v1/register.py11
-rw-r--r--synapse/storage/appservice.py12
-rw-r--r--tests/handlers/test_register.py8
-rw-r--r--tests/rest/client/v1/test_register.py30
-rw-r--r--tests/rest/client/v2_alpha/test_register.py2
-rw-r--r--tests/storage/test_appservice.py9
14 files changed, 64 insertions, 77 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b994f07de4..1b3b55d517 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -653,7 +653,7 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def _get_appservice_user_id(self, request):
-        app_service = yield self.store.get_app_service_by_token(
+        app_service = self.store.get_app_service_by_token(
             get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
             )
@@ -855,13 +855,12 @@ class Auth(object):
         }
         defer.returnValue(user_info)
 
-    @defer.inlineCallbacks
     def get_appservice_by_req(self, request):
         try:
             token = get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
             )
-            service = yield self.store.get_app_service_by_token(token)
+            service = self.store.get_app_service_by_token(token)
             if not service:
                 logger.warn("Unrecognised appservice access token: %s" % (token,))
                 raise AuthError(
@@ -870,7 +869,7 @@ class Auth(object):
                     errcode=Codes.UNKNOWN_TOKEN
                 )
             request.authenticated_entity = service.sender
-            defer.returnValue(service)
+            return defer.succeed(service)
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index e58735294e..4981643166 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -55,8 +55,14 @@ class BaseHandler(object):
 
     def ratelimit(self, requester):
         time_now = self.clock.time()
+        user_id = requester.user.to_string()
+
+        app_service = self.store.get_app_service_by_user_id(user_id)
+        if app_service is not None:
+            return  # do not ratelimit app service senders
+
         allowed, time_allowed = self.ratelimiter.send_message(
-            requester.user.to_string(), time_now,
+            user_id, time_now,
             msg_rate_hz=self.hs.config.rc_messages_per_second,
             burst_count=self.hs.config.rc_message_burst_count,
         )
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 88fa0bb2e4..05af54d31b 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -59,7 +59,7 @@ class ApplicationServicesHandler(object):
         Args:
             current_id(int): The current maximum ID.
         """
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         if not services or not self.notify_appservices:
             return
 
@@ -142,7 +142,7 @@ class ApplicationServicesHandler(object):
             association can be found.
         """
         room_alias_str = room_alias.to_string()
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         alias_query_services = [
             s for s in services if (
                 s.is_interested_in_alias(room_alias_str)
@@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
 
     @defer.inlineCallbacks
     def get_3pe_protocols(self, only_protocol=None):
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         protocols = {}
 
         # Collect up all the individual protocol responses out of the ASes
@@ -224,7 +224,7 @@ class ApplicationServicesHandler(object):
             list<ApplicationService>: A list of services interested in this
             event based on the service regex.
         """
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_list = [
             s for s in services if (
                 yield s.is_interested(event, self.store)
@@ -232,23 +232,21 @@ class ApplicationServicesHandler(object):
         ]
         defer.returnValue(interested_list)
 
-    @defer.inlineCallbacks
     def _get_services_for_user(self, user_id):
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_list = [
             s for s in services if (
                 s.is_interested_in_user(user_id)
             )
         ]
-        defer.returnValue(interested_list)
+        return defer.succeed(interested_list)
 
-    @defer.inlineCallbacks
     def _get_services_for_3pn(self, protocol):
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_list = [
             s for s in services if s.is_interested_in_protocol(protocol)
         ]
-        defer.returnValue(interested_list)
+        return defer.succeed(interested_list)
 
     @defer.inlineCallbacks
     def _is_unknown_user(self, user_id):
@@ -264,7 +262,7 @@ class ApplicationServicesHandler(object):
             return
 
         # user not found; could be the AS though, so check.
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         service_list = [s for s in services if s.sender == user_id]
         defer.returnValue(len(service_list) == 0)
 
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 14352985e2..c00274afc3 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -288,13 +288,12 @@ class DirectoryHandler(BaseHandler):
             result = yield as_handler.query_room_alias_exists(room_alias)
         defer.returnValue(result)
 
-    @defer.inlineCallbacks
     def can_modify_alias(self, alias, user_id=None):
         # Any application service "interested" in an alias they are regexing on
         # can modify the alias.
         # Users can only modify the alias if ALL the interested services have
         # non-exclusive locks on the alias (or there are no interested services)
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_services = [
             s for s in services if s.is_interested_in_alias(alias.to_string())
         ]
@@ -302,14 +301,12 @@ class DirectoryHandler(BaseHandler):
         for service in interested_services:
             if user_id == service.sender:
                 # this user IS the app service so they can do whatever they like
-                defer.returnValue(True)
-                return
+                return defer.succeed(True)
             elif service.is_exclusive_alias(alias.to_string()):
                 # another service has an exclusive lock on this alias.
-                defer.returnValue(False)
-                return
+                return defer.succeed(False)
         # either no interested services, or no service with an exclusive lock
-        defer.returnValue(True)
+        return defer.succeed(True)
 
     @defer.inlineCallbacks
     def _user_can_delete_alias(self, alias, user_id):
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index d9ac09078d..87f74dfb8e 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -65,13 +65,13 @@ class ProfileHandler(BaseHandler):
                 defer.returnValue(result["displayname"])
 
     @defer.inlineCallbacks
-    def set_displayname(self, target_user, requester, new_displayname):
+    def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
         """target_user is the user whose displayname is to be changed;
         auth_user is the user attempting to make this change."""
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this Home Server")
 
-        if target_user != requester.user:
+        if not by_admin and target_user != requester.user:
             raise AuthError(400, "Cannot set another user's displayname")
 
         if new_displayname == '':
@@ -111,13 +111,13 @@ class ProfileHandler(BaseHandler):
             defer.returnValue(result["avatar_url"])
 
     @defer.inlineCallbacks
-    def set_avatar_url(self, target_user, requester, new_avatar_url):
+    def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
         """target_user is the user whose avatar_url is to be changed;
         auth_user is the user attempting to make this change."""
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this Home Server")
 
-        if target_user != requester.user:
+        if not by_admin and target_user != requester.user:
             raise AuthError(400, "Cannot set another user's avatar_url")
 
         yield self.store.set_profile_avatar_url(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index dd75c4fecf..7e119f13b1 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -19,7 +19,6 @@ import urllib
 
 from twisted.internet import defer
 
-import synapse.types
 from synapse.api.errors import (
     AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
 )
@@ -194,7 +193,7 @@ class RegistrationHandler(BaseHandler):
     def appservice_register(self, user_localpart, as_token):
         user = UserID(user_localpart, self.hs.hostname)
         user_id = user.to_string()
-        service = yield self.store.get_app_service_by_token(as_token)
+        service = self.store.get_app_service_by_token(as_token)
         if not service:
             raise AuthError(403, "Invalid application service token.")
         if not service.is_interested_in_user(user_id):
@@ -305,11 +304,10 @@ class RegistrationHandler(BaseHandler):
             # XXX: This should be a deferred list, shouldn't it?
             yield identity_handler.bind_threepid(c, user_id)
 
-    @defer.inlineCallbacks
     def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
         # valid user IDs must not clash with any user ID namespaces claimed by
         # application services.
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         interested_services = [
             s for s in services
             if s.is_interested_in_user(user_id)
@@ -371,7 +369,7 @@ class RegistrationHandler(BaseHandler):
         defer.returnValue(data)
 
     @defer.inlineCallbacks
-    def get_or_create_user(self, localpart, displayname, duration_in_ms,
+    def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
                            password_hash=None):
         """Creates a new user if the user does not exist,
         else revokes all previous access tokens and generates a new one.
@@ -418,9 +416,8 @@ class RegistrationHandler(BaseHandler):
         if displayname is not None:
             logger.info("setting user display name: %s -> %s", user_id, displayname)
             profile_handler = self.hs.get_handlers().profile_handler
-            requester = synapse.types.create_requester(user)
             yield profile_handler.set_displayname(
-                user, requester, displayname
+                user, requester, displayname, by_admin=True,
             )
 
         defer.returnValue((user_id, token))
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index cbd26f8f95..a7f533f7be 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -437,7 +437,7 @@ class RoomEventSource(object):
             logger.warn("Stream has topological part!!!! %r", from_key)
             from_key = "s%s" % (from_token.stream,)
 
-        app_service = yield self.store.get_app_service_by_user_id(
+        app_service = self.store.get_app_service_by_user_id(
             user.to_string()
         )
         if app_service:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b5962f4f5a..1f910ff814 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -788,7 +788,7 @@ class SyncHandler(object):
 
         assert since_token
 
-        app_service = yield self.store.get_app_service_by_user_id(user_id)
+        app_service = self.store.get_app_service_by_user_id(user_id)
         if app_service:
             rooms = yield self.store.get_app_service_rooms(app_service)
             joined_room_ids = set(r.room_id for r in rooms)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 3046da7aec..b5a76fefac 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request
 from .base import ClientV1RestServlet, client_path_patterns
 import synapse.util.stringutils as stringutils
 from synapse.http.servlet import parse_json_object_from_request
+from synapse.types import create_requester
 
 from synapse.util.async import run_on_reactor
 
@@ -391,15 +392,16 @@ class CreateUserRestServlet(ClientV1RestServlet):
         user_json = parse_json_object_from_request(request)
 
         access_token = get_access_token_from_request(request)
-        app_service = yield self.store.get_app_service_by_token(
+        app_service = self.store.get_app_service_by_token(
             access_token
         )
         if not app_service:
             raise SynapseError(403, "Invalid application service token.")
 
-        logger.debug("creating user: %s", user_json)
+        requester = create_requester(app_service.sender)
 
-        response = yield self._do_create(user_json)
+        logger.debug("creating user: %s", user_json)
+        response = yield self._do_create(requester, user_json)
 
         defer.returnValue((200, response))
 
@@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
         return 403, {}
 
     @defer.inlineCallbacks
-    def _do_create(self, user_json):
+    def _do_create(self, requester, user_json):
         yield run_on_reactor()
 
         if "localpart" not in user_json:
@@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
 
         handler = self.handlers.registration_handler
         user_id, token = yield handler.get_or_create_user(
+            requester=requester,
             localpart=localpart,
             displayname=displayname,
             duration_in_ms=(duration_seconds * 1000),
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index a854a87eab..3d5994a580 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -37,7 +37,7 @@ class ApplicationServiceStore(SQLBaseStore):
         )
 
     def get_app_services(self):
-        return defer.succeed(self.services_cache)
+        return self.services_cache
 
     def get_app_service_by_user_id(self, user_id):
         """Retrieve an application service from their user ID.
@@ -54,8 +54,8 @@ class ApplicationServiceStore(SQLBaseStore):
         """
         for service in self.services_cache:
             if service.sender == user_id:
-                return defer.succeed(service)
-        return defer.succeed(None)
+                return service
+        return None
 
     def get_app_service_by_token(self, token):
         """Get the application service with the given appservice token.
@@ -67,8 +67,8 @@ class ApplicationServiceStore(SQLBaseStore):
         """
         for service in self.services_cache:
             if service.token == token:
-                return defer.succeed(service)
-        return defer.succeed(None)
+                return service
+        return None
 
     def get_app_service_rooms(self, service):
         """Get a list of RoomsForUser for this application service.
@@ -163,7 +163,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
             ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
-        as_list = yield self.get_app_services()
+        as_list = self.get_app_services()
         services = []
 
         for res in results:
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index a7de3c7c17..9c9d144690 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
 from .. import unittest
 
 from synapse.handlers.register import RegistrationHandler
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
 
 from tests.utils import setup_test_homeserver
 
@@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase):
         local_part = "someone"
         display_name = "someone"
         user_id = "@someone:test"
+        requester = create_requester("@as:test")
         result_user_id, result_token = yield self.handler.get_or_create_user(
-            local_part, display_name, duration_ms)
+            requester, local_part, display_name, duration_ms)
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
 
@@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase):
         local_part = "frank"
         display_name = "Frank"
         user_id = "@frank:test"
+        requester = create_requester("@as:test")
         result_user_id, result_token = yield self.handler.get_or_create_user(
-            local_part, display_name, duration_ms)
+            requester, local_part, display_name, duration_ms)
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index 4a898a034f..44ba9ff58f 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -31,33 +31,21 @@ class CreateUserServletTestCase(unittest.TestCase):
         )
         self.request.args = {}
 
-        self.appservice = None
-        self.auth = Mock(get_appservice_by_req=Mock(
-            side_effect=lambda x: defer.succeed(self.appservice))
-        )
+        self.registration_handler = Mock()
 
-        self.auth_result = (False, None, None, None)
-        self.auth_handler = Mock(
-            check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
-            get_session_data=Mock(return_value=None)
+        self.appservice = Mock(sender="@as:test")
+        self.datastore = Mock(
+            get_app_service_by_token=Mock(return_value=self.appservice)
         )
-        self.registration_handler = Mock()
-        self.identity_handler = Mock()
-        self.login_handler = Mock()
 
-        # do the dance to hook it up to the hs global
-        self.handlers = Mock(
-            auth_handler=self.auth_handler,
+        # do the dance to hook things up to the hs global
+        handlers = Mock(
             registration_handler=self.registration_handler,
-            identity_handler=self.identity_handler,
-            login_handler=self.login_handler
         )
         self.hs = Mock()
-        self.hs.hostname = "supergbig~testing~thing.com"
-        self.hs.get_auth = Mock(return_value=self.auth)
-        self.hs.get_handlers = Mock(return_value=self.handlers)
-        self.hs.config.enable_registration = True
-        # init the thing we're testing
+        self.hs.hostname = "superbig~testing~thing.com"
+        self.hs.get_datastore = Mock(return_value=self.datastore)
+        self.hs.get_handlers = Mock(return_value=handlers)
         self.servlet = CreateUserRestServlet(self.hs)
 
     @defer.inlineCallbacks
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 8ac56a1fb2..e9cb416e4b 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,7 +19,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
 
         self.appservice = None
         self.auth = Mock(get_appservice_by_req=Mock(
-            side_effect=lambda x: defer.succeed(self.appservice))
+            side_effect=lambda x: self.appservice)
         )
 
         self.auth_result = (False, None, None, None)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 3e2862daae..f3df8302da 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -71,14 +71,12 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
             outfile.write(yaml.dump(as_yaml))
             self.as_yaml_files.append(as_token)
 
-    @defer.inlineCallbacks
     def test_retrieve_unknown_service_token(self):
-        service = yield self.store.get_app_service_by_token("invalid_token")
+        service = self.store.get_app_service_by_token("invalid_token")
         self.assertEquals(service, None)
 
-    @defer.inlineCallbacks
     def test_retrieval_of_service(self):
-        stored_service = yield self.store.get_app_service_by_token(
+        stored_service = self.store.get_app_service_by_token(
             self.as_token
         )
         self.assertEquals(stored_service.token, self.as_token)
@@ -97,9 +95,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
             []
         )
 
-    @defer.inlineCallbacks
     def test_retrieval_of_all_services(self):
-        services = yield self.store.get_app_services()
+        services = self.store.get_app_services()
         self.assertEquals(len(services), 3)