summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDavid Baker <dave@matrix.org>2016-06-02 13:35:25 +0100
committerDavid Baker <dave@matrix.org>2016-06-02 13:35:25 +0100
commit3a3fb2f6f9d958d62a96391491f794042d1a419c (patch)
treeb0085e37f02db4f019440086a91905057843889b /synapse
parentEmail unsubscribing that may in theory, work (diff)
parentSplit out the auth handler (diff)
downloadsynapse-3a3fb2f6f9d958d62a96391491f794042d1a419c.tar.xz
Merge branch 'dbkr/split_out_auth_handler' into dbkr/email_unsubscribe
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/__init__.py2
-rw-r--r--synapse/handlers/register.py2
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/push/mailer.py2
-rw-r--r--synapse/rest/client/v1/login.py11
-rw-r--r--synapse/rest/client/v2_alpha/account.py4
-rw-r--r--synapse/rest/client/v2_alpha/auth.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py2
-rw-r--r--synapse/rest/client/v2_alpha/tokenrefresh.py2
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/push_rule.py3
-rw-r--r--synapse/storage/signatures.py25
12 files changed, 38 insertions, 24 deletions
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index c0069e23d6..d28e07f0d9 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -24,7 +24,6 @@ from .federation import FederationHandler
 from .profile import ProfileHandler
 from .directory import DirectoryHandler
 from .admin import AdminHandler
-from .auth import AuthHandler
 from .identity import IdentityHandler
 from .receipts import ReceiptsHandler
 from .search import SearchHandler
@@ -50,7 +49,6 @@ class Handlers(object):
         self.directory_handler = DirectoryHandler(hs)
         self.admin_handler = AdminHandler(hs)
         self.receipts_handler = ReceiptsHandler(hs)
-        self.auth_handler = AuthHandler(hs)
         self.identity_handler = IdentityHandler(hs)
         self.search_handler = SearchHandler(hs)
         self.room_context_handler = RoomContextHandler(hs)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 16f33f8371..bbc07b045e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler):
         defer.returnValue((user_id, token))
 
     def auth_handler(self):
-        return self.hs.get_handlers().auth_handler
+        return self.hs.get_auth_handler()
 
     @defer.inlineCallbacks
     def guest_access_token_for(self, medium, address, inviter_user_id):
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 77063b021a..9fd34588dd 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -441,7 +441,7 @@ class RoomListHandler(BaseHandler):
             self.hs.config.secondary_directory_servers
         )
         self.remote_list_request_cache.set((), deferred)
-        yield deferred
+        self.remote_list_cache = yield deferred
 
     @defer.inlineCallbacks
     def get_aggregated_public_room_list(self):
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 95250bad7d..b96cab5ef8 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -124,6 +124,8 @@ class Mailer(object):
             user_display_name = yield self.store.get_profile_displayname(
                 UserID.from_string(user_id).localpart
             )
+            if user_display_name is None:
+                user_display_name = user_id
         except StoreError:
             user_display_name = user_id
 
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 3b5544851b..8df9d10efa 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet):
         self.cas_required_attributes = hs.config.cas_required_attributes
         self.servername = hs.config.server_name
         self.http_client = hs.get_simple_http_client()
+        self.auth_handler = self.hs.get_auth_handler()
 
     def on_GET(self, request):
         flows = []
@@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
                 user_id, self.hs.hostname
             ).to_string()
 
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_id, access_token, refresh_token = yield auth_handler.login_with_password(
             user_id=user_id,
             password=login_submission["password"])
@@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_token_login(self, login_submission):
         token = login_submission['token']
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_id = (
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
@@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if user_exists:
             user_id, access_token, refresh_token = (
@@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet):
             raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if user_exists:
             user_id, access_token, refresh_token = (
@@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if not user_exists:
             user_id, _ = (
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index c88c270537..9a84873a5f 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
         super(PasswordRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()
 
     @defer.inlineCallbacks
     def on_GET(self, request):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 78181b7b18..58d3cad6a1 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
         super(AuthRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler
 
     @defer.inlineCallbacks
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 1ecc02d94d..2088c316d1 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.auth_handler = hs.get_handlers().auth_handler
+        self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
 
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index a158c2209a..8270e8787f 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
         body = parse_json_object_from_request(request)
         try:
             old_refresh_token = body["refresh_token"]
-            auth_handler = self.hs.get_handlers().auth_handler
+            auth_handler = self.hs.get_auth_handler()
             (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
                 old_refresh_token, auth_handler.generate_refresh_token)
             new_access_token = yield auth_handler.issue_access_token(user_id)
diff --git a/synapse/server.py b/synapse/server.py
index 7cf22b1eea..dd4b81c658 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import TypingHandler
 from synapse.handlers.room import RoomListHandler
+from synapse.handlers.auth import AuthHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.state import StateHandler
 from synapse.storage import DataStore
@@ -89,6 +90,7 @@ class HomeServer(object):
         'sync_handler',
         'typing_handler',
         'room_list_handler',
+        'auth_handler',
         'application_service_api',
         'application_service_scheduler',
         'application_service_handler',
@@ -190,6 +192,9 @@ class HomeServer(object):
     def build_room_list_handler(self):
         return RoomListHandler(self)
 
+    def build_auth_handler(self):
+        return AuthHandler(self)
+
     def build_application_service_api(self):
         return ApplicationServiceApi(self)
 
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 216a8bf69c..ebb97c8474 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -106,7 +106,8 @@ class PushRuleStore(SQLBaseStore):
             desc="bulk_get_push_rules_enabled",
         )
         for row in rows:
-            results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
+            enabled = bool(row['enabled'])
+            results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
         defer.returnValue(results)
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index b10f2a5787..ea6823f18d 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -19,17 +19,24 @@ from ._base import SQLBaseStore
 
 from unpaddedbase64 import encode_base64
 from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.util.caches.descriptors import cached, cachedList
 
 
 class SignatureStore(SQLBaseStore):
     """Persistence for event signatures and hashes"""
 
+    @cached(lru=True)
+    def get_event_reference_hash(self, event_id):
+        return self._get_event_reference_hashes_txn(event_id)
+
+    @cachedList(cached_method_name="get_event_reference_hash",
+                list_name="event_ids", num_args=1)
     def get_event_reference_hashes(self, event_ids):
         def f(txn):
-            return [
-                self._get_event_reference_hashes_txn(txn, ev)
-                for ev in event_ids
-            ]
+            return {
+                event_id: self._get_event_reference_hashes_txn(txn, event_id)
+                for event_id in event_ids
+            }
 
         return self.runInteraction(
             "get_event_reference_hashes",
@@ -41,15 +48,15 @@ class SignatureStore(SQLBaseStore):
         hashes = yield self.get_event_reference_hashes(
             event_ids
         )
-        hashes = [
-            {
+        hashes = {
+            e_id: {
                 k: encode_base64(v) for k, v in h.items()
                 if k == "sha256"
             }
-            for h in hashes
-        ]
+            for e_id, h in hashes.items()
+        }
 
-        defer.returnValue(zip(event_ids, hashes))
+        defer.returnValue(hashes.items())
 
     def _get_event_reference_hashes_txn(self, txn, event_id):
         """Get all the hashes for a given PDU.