summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2017-02-02 16:50:28 +0000
committerErik Johnston <erik@matrix.org>2017-02-02 16:50:28 +0000
commitf8c407a13b04f906dc761bcf33992e05af219098 (patch)
tree0479c242f3d7df94437755296616e657a0a15858
parentUpdate changelog (diff)
parentMerge pull request #1877 from matrix-org/erikj/device_list_fixes (diff)
downloadsynapse-f8c407a13b04f906dc761bcf33992e05af219098.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.19.0
-rw-r--r--synapse/handlers/auth.py80
-rw-r--r--synapse/handlers/device.py42
-rw-r--r--synapse/handlers/presence.py44
-rw-r--r--synapse/handlers/register.py10
-rw-r--r--synapse/handlers/room.py1
-rw-r--r--synapse/notifier.py1
-rw-r--r--synapse/push/mailer.py4
-rw-r--r--synapse/replication/slave/storage/events.py3
-rw-r--r--synapse/rest/client/v1/login.py5
-rw-r--r--synapse/rest/client/v2_alpha/keys.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py3
-rw-r--r--synapse/server.py6
-rw-r--r--synapse/storage/roommember.py17
-rw-r--r--synapse/util/caches/descriptors.py5
-rw-r--r--tests/handlers/test_auth.py12
-rw-r--r--tests/handlers/test_register.py7
16 files changed, 135 insertions, 107 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 221d7ea7a2..fffba34383 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ class AuthHandler(BaseHandler):
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.device_handler = hs.get_device_handler()
+        self.macaroon_gen = hs.get_macaroon_generator()
 
     @defer.inlineCallbacks
     def check_auth(self, flows, clientdict, clientip):
@@ -529,37 +530,11 @@ class AuthHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def issue_access_token(self, user_id, device_id=None):
-        access_token = self.generate_access_token(user_id)
+        access_token = self.macaroon_gen.generate_access_token(user_id)
         yield self.store.add_access_token_to_user(user_id, access_token,
                                                   device_id)
         defer.returnValue(access_token)
 
-    def generate_access_token(self, user_id, extra_caveats=None):
-        extra_caveats = extra_caveats or []
-        macaroon = self._generate_base_macaroon(user_id)
-        macaroon.add_first_party_caveat("type = access")
-        # Include a nonce, to make sure that each login gets a different
-        # access token.
-        macaroon.add_first_party_caveat("nonce = %s" % (
-            stringutils.random_string_with_symbols(16),
-        ))
-        for caveat in extra_caveats:
-            macaroon.add_first_party_caveat(caveat)
-        return macaroon.serialize()
-
-    def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
-        macaroon = self._generate_base_macaroon(user_id)
-        macaroon.add_first_party_caveat("type = login")
-        now = self.hs.get_clock().time_msec()
-        expiry = now + duration_in_ms
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-        return macaroon.serialize()
-
-    def generate_delete_pusher_token(self, user_id):
-        macaroon = self._generate_base_macaroon(user_id)
-        macaroon.add_first_party_caveat("type = delete_pusher")
-        return macaroon.serialize()
-
     def validate_short_term_login_token_and_get_user_id(self, login_token):
         auth_api = self.hs.get_auth()
         try:
@@ -570,15 +545,6 @@ class AuthHandler(BaseHandler):
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
 
-    def _generate_base_macaroon(self, user_id):
-        macaroon = pymacaroons.Macaroon(
-            location=self.hs.config.server_name,
-            identifier="key",
-            key=self.hs.config.macaroon_secret_key)
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        return macaroon
-
     @defer.inlineCallbacks
     def set_password(self, user_id, newpassword, requester=None):
         password_hash = self.hash(newpassword)
@@ -673,6 +639,48 @@ class AuthHandler(BaseHandler):
             return False
 
 
+class MacaroonGeneartor(object):
+    def __init__(self, hs):
+        self.clock = hs.get_clock()
+        self.server_name = hs.config.server_name
+        self.macaroon_secret_key = hs.config.macaroon_secret_key
+
+    def generate_access_token(self, user_id, extra_caveats=None):
+        extra_caveats = extra_caveats or []
+        macaroon = self._generate_base_macaroon(user_id)
+        macaroon.add_first_party_caveat("type = access")
+        # Include a nonce, to make sure that each login gets a different
+        # access token.
+        macaroon.add_first_party_caveat("nonce = %s" % (
+            stringutils.random_string_with_symbols(16),
+        ))
+        for caveat in extra_caveats:
+            macaroon.add_first_party_caveat(caveat)
+        return macaroon.serialize()
+
+    def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
+        macaroon = self._generate_base_macaroon(user_id)
+        macaroon.add_first_party_caveat("type = login")
+        now = self.clock.time_msec()
+        expiry = now + duration_in_ms
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        return macaroon.serialize()
+
+    def generate_delete_pusher_token(self, user_id):
+        macaroon = self._generate_base_macaroon(user_id)
+        macaroon.add_first_party_caveat("type = delete_pusher")
+        return macaroon.serialize()
+
+    def _generate_base_macaroon(self, user_id):
+        macaroon = pymacaroons.Macaroon(
+            location=self.server_name,
+            identifier="key",
+            key=self.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+        return macaroon
+
+
 class _AccountHandler(object):
     """A proxy object that gets passed to password auth providers so they
     can register new users etc if necessary.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 815410969c..158206aef6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -17,7 +17,7 @@ from synapse.api import errors
 from synapse.api.constants import EventTypes
 from synapse.util import stringutils
 from synapse.util.async import Linearizer
-from synapse.types import get_domain_from_id
+from synapse.types import get_domain_from_id, RoomStreamToken
 from twisted.internet import defer
 from ._base import BaseHandler
 
@@ -198,20 +198,22 @@ class DeviceHandler(BaseHandler):
         """Notify that a user's device(s) has changed. Pokes the notifier, and
         remote servers if the user is local.
         """
-        rooms = yield self.store.get_rooms_for_user(user_id)
-        room_ids = [r.room_id for r in rooms]
+        users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+            user_id
+        )
 
         hosts = set()
         if self.hs.is_mine_id(user_id):
-            for room_id in room_ids:
-                users = yield self.store.get_users_in_room(room_id)
-                hosts.update(get_domain_from_id(u) for u in users)
+            hosts.update(get_domain_from_id(u) for u in users_who_share_room)
             hosts.discard(self.server_name)
 
         position = yield self.store.add_device_change_to_streams(
             user_id, device_ids, list(hosts)
         )
 
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        room_ids = [r.room_id for r in rooms]
+
         yield self.notifier.on_new_event(
             "device_list_key", position, rooms=room_ids,
         )
@@ -243,15 +245,15 @@ class DeviceHandler(BaseHandler):
 
         possibly_changed = set(changed)
         for room_id in rooms_changed:
-            # Fetch (an approximation) of the current state at the time.
-            event_rows, token = yield self.store.get_recent_event_ids_for_room(
-                room_id, end_token=from_token.room_key, limit=1,
-            )
+            # Fetch  the current state at the time.
+            stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key)
 
-            if event_rows:
-                last_event_id = event_rows[-1]["event_id"]
-                prev_state_ids = yield self.store.get_state_ids_for_event(last_event_id)
-            else:
+            try:
+                event_ids = yield self.store.get_forward_extremeties_for_room(
+                    room_id, stream_ordering=stream_ordering
+                )
+                prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+            except:
                 prev_state_ids = {}
 
             current_state_ids = yield self.state.get_current_state_ids(room_id)
@@ -266,13 +268,13 @@ class DeviceHandler(BaseHandler):
                     if not prev_event_id or prev_event_id != event_id:
                         possibly_changed.add(state_key)
 
-        user_ids_changed = set()
-        for other_user_id in possibly_changed:
-            other_rooms = yield self.store.get_rooms_for_user(other_user_id)
-            if room_ids.intersection(e.room_id for e in other_rooms):
-                user_ids_changed.add(other_user_id)
+        users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+            user_id
+        )
 
-        defer.returnValue(user_ids_changed)
+        # Take the intersection of the users whose devices may have changed
+        # and those that actually still share a room with the user
+        defer.returnValue(users_who_share_room & possibly_changed)
 
     @defer.inlineCallbacks
     def _incoming_device_list_update(self, origin, edu_content):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 9982ae0fed..fdfce2a88c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1011,7 +1011,7 @@ class PresenceEventSource(object):
     @defer.inlineCallbacks
     @log_function
     def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
-                       **kwargs):
+                       explicit_room_id=None, **kwargs):
         # The process for getting presence events are:
         #  1. Get the rooms the user is in.
         #  2. Get the list of user in the rooms.
@@ -1028,22 +1028,24 @@ class PresenceEventSource(object):
             user_id = user.to_string()
             if from_key is not None:
                 from_key = int(from_key)
-            room_ids = room_ids or []
 
             presence = self.get_presence_handler()
             stream_change_cache = self.store.presence_stream_cache
 
-            if not room_ids:
-                rooms = yield self.store.get_rooms_for_user(user_id)
-                room_ids = set(e.room_id for e in rooms)
-            else:
-                room_ids = set(room_ids)
-
             max_token = self.store.get_current_presence_token()
 
             plist = yield self.store.get_presence_list_accepted(user.localpart)
-            friends = set(row["observed_user_id"] for row in plist)
-            friends.add(user_id)  # So that we receive our own presence
+            users_interested_in = set(row["observed_user_id"] for row in plist)
+            users_interested_in.add(user_id)  # So that we receive our own presence
+
+            users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+                user_id
+            )
+            users_interested_in.update(users_who_share_room)
+
+            if explicit_room_id:
+                user_ids = yield self.store.get_users_in_room(explicit_room_id)
+                users_interested_in.update(user_ids)
 
             user_ids_changed = set()
             changed = None
@@ -1055,35 +1057,19 @@ class PresenceEventSource(object):
                 # work out if we share a room or they're in our presence list
                 get_updates_counter.inc("stream")
                 for other_user_id in changed:
-                    if other_user_id in friends:
+                    if other_user_id in users_interested_in:
                         user_ids_changed.add(other_user_id)
-                        continue
-                    other_rooms = yield self.store.get_rooms_for_user(other_user_id)
-                    if room_ids.intersection(e.room_id for e in other_rooms):
-                        user_ids_changed.add(other_user_id)
-                        continue
             else:
                 # Too many possible updates. Find all users we can see and check
                 # if any of them have changed.
                 get_updates_counter.inc("full")
 
-                user_ids_to_check = set()
-                for room_id in room_ids:
-                    users = yield self.store.get_users_in_room(room_id)
-                    user_ids_to_check.update(users)
-
-                user_ids_to_check.update(friends)
-
-                # Always include yourself. Only really matters for when the user is
-                # not in any rooms, but still.
-                user_ids_to_check.add(user_id)
-
                 if from_key:
                     user_ids_changed = stream_change_cache.get_entities_changed(
-                        user_ids_to_check, from_key,
+                        users_interested_in, from_key,
                     )
                 else:
-                    user_ids_changed = user_ids_to_check
+                    user_ids_changed = users_interested_in
 
             updates = yield presence.current_state_for_users(user_ids_changed)
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 286f0cef0a..03c6a85fc6 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler):
 
         self._next_generated_user_id = None
 
+        self.macaroon_gen = hs.get_macaroon_generator()
+
     @defer.inlineCallbacks
     def check_username(self, localpart, guest_access_token=None,
                        assigned_user_id=None):
@@ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler):
 
             token = None
             if generate_token:
-                token = self.auth_handler().generate_access_token(user_id)
+                token = self.macaroon_gen.generate_access_token(user_id)
             yield self.store.register(
                 user_id=user_id,
                 token=token,
@@ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler):
                 user_id = user.to_string()
                 yield self.check_user_id_not_appservice_exclusive(user_id)
                 if generate_token:
-                    token = self.auth_handler().generate_access_token(user_id)
+                    token = self.macaroon_gen.generate_access_token(user_id)
                 try:
                     yield self.store.register(
                         user_id=user_id,
@@ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler):
         user_id = user.to_string()
 
         yield self.check_user_id_not_appservice_exclusive(user_id)
-        token = self.auth_handler().generate_access_token(user_id)
+        token = self.macaroon_gen.generate_access_token(user_id)
         try:
             yield self.store.register(
                 user_id=user_id,
@@ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler):
 
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
-        token = self.auth_handler().generate_access_token(user_id)
+        token = self.macaroon_gen.generate_access_token(user_id)
 
         if need_register:
             yield self.store.register(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 5f18007e90..7e7671c9a2 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -437,6 +437,7 @@ class RoomEventSource(object):
             limit,
             room_ids,
             is_guest,
+            explicit_room_id=None,
     ):
         # We just ignore the key for now.
 
diff --git a/synapse/notifier.py b/synapse/notifier.py
index acbd4bb5ae..8051a7a842 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -378,6 +378,7 @@ class Notifier(object):
                     limit=limit,
                     is_guest=is_peeking,
                     room_ids=room_ids,
+                    explicit_room_id=explicit_room_id,
                 )
 
                 if name == "room":
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index ce2d31fb98..62d794f22b 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -81,7 +81,7 @@ class Mailer(object):
     def __init__(self, hs, app_name):
         self.hs = hs
         self.store = self.hs.get_datastore()
-        self.auth_handler = self.hs.get_auth_handler()
+        self.macaroon_gen = self.hs.get_macaroon_generator()
         self.state_handler = self.hs.get_state_handler()
         loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
         self.app_name = app_name
@@ -466,7 +466,7 @@ class Mailer(object):
 
     def make_unsubscribe_link(self, user_id, app_id, email_address):
         params = {
-            "access_token": self.auth_handler.generate_delete_pusher_token(user_id),
+            "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
             "app_id": app_id,
             "pushkey": email_address,
         }
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 15a025a019..d72ff6055c 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -73,6 +73,9 @@ class SlavedEventStore(BaseSlavedStore):
     # to reach inside the __dict__ to extract them.
     get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
     get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
+    get_users_who_share_room_with_user = (
+        RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
+    )
     get_latest_event_ids_in_room = EventFederationStore.__dict__[
         "get_latest_event_ids_in_room"
     ]
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0c9cdff3b8..72057f1b0c 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -330,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet):
         self.cas_required_attributes = hs.config.cas_required_attributes
         self.auth_handler = hs.get_auth_handler()
         self.handlers = hs.get_handlers()
+        self.macaroon_gen = hs.get_macaroon_generator()
 
     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -368,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet):
                 yield self.handlers.registration_handler.register(localpart=user)
             )
 
-        login_token = auth_handler.generate_short_term_login_token(registered_user_id)
+        login_token = self.macaroon_gen.generate_short_term_login_token(
+            registered_user_id
+        )
         redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
                                                             login_token)
         request.redirect(redirect_url)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index f99b53530a..6a3cfe84f8 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -193,7 +193,7 @@ class KeyChangesServlet(RestServlet):
         )
 
         defer.returnValue((200, {
-            "changed": changed
+            "changed": list(changed),
         }))
 
 
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 3e7a285e10..ccca5a12d5 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet):
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
         self.device_handler = hs.get_device_handler()
+        self.macaroon_gen = hs.get_macaroon_generator()
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet):
             user_id, device_id, initial_display_name
         )
 
-        access_token = self.auth_handler.generate_access_token(
+        access_token = self.macaroon_gen.generate_access_token(
             user_id, ["guest = true"]
         )
         defer.returnValue((200, {
diff --git a/synapse/server.py b/synapse/server.py
index 0bfb411269..c577032041 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient
 from synapse.federation.transaction_queue import TransactionQueue
 from synapse.handlers import Handlers
 from synapse.handlers.appservice import ApplicationServicesHandler
-from synapse.handlers.auth import AuthHandler
+from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
 from synapse.handlers.devicemessage import DeviceMessageHandler
 from synapse.handlers.device import DeviceHandler
 from synapse.handlers.e2e_keys import E2eKeysHandler
@@ -131,6 +131,7 @@ class HomeServer(object):
         'federation_transport_client',
         'federation_sender',
         'receipts_handler',
+        'macaroon_generator',
     ]
 
     def __init__(self, hostname, **kwargs):
@@ -213,6 +214,9 @@ class HomeServer(object):
     def build_auth_handler(self):
         return AuthHandler(self)
 
+    def build_macaroon_generator(self):
+        return MacaroonGeneartor(self)
+
     def build_device_handler(self):
         return DeviceHandler(self)
 
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index ee800d074f..249217e114 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -280,6 +280,23 @@ class RoomMemberStore(SQLBaseStore):
             user_id, membership_list=[Membership.JOIN],
         )
 
+    @cachedInlineCallbacks(max_entries=50000, cache_context=True, iterable=True)
+    def get_users_who_share_room_with_user(self, user_id, cache_context):
+        """Returns the set of users who share a room with `user_id`
+        """
+        rooms = yield self.get_rooms_for_user(
+            user_id, on_invalidate=cache_context.invalidate,
+        )
+
+        user_who_share_room = set()
+        for room in rooms:
+            user_ids = yield self.get_users_in_room(
+                room.room_id, on_invalidate=cache_context.invalidate,
+            )
+            user_who_share_room.update(user_ids)
+
+        defer.returnValue(user_who_share_room)
+
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
         def f(txn):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 675bfd5feb..998de70d29 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -478,6 +478,11 @@ class CacheListDescriptor(object):
 
 
 class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
+    # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
+    # which namedtuple does for us (i.e. two _CacheContext are the same if
+    # their caches and keys match). This is important in particular to
+    # dedupe when we add callbacks to lru cache nodes, otherwise the number
+    # of callbacks would grow.
     def invalidate(self):
         self.cache.invalidate(self.key)
 
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 9d013e5ca7..1822dcf1e0 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -34,11 +34,10 @@ class AuthTestCase(unittest.TestCase):
         self.hs = yield setup_test_homeserver(handlers=None)
         self.hs.handlers = AuthHandlers(self.hs)
         self.auth_handler = self.hs.handlers.auth_handler
+        self.macaroon_generator = self.hs.get_macaroon_generator()
 
     def test_token_is_a_macaroon(self):
-        self.hs.config.macaroon_secret_key = "this key is a huge secret"
-
-        token = self.auth_handler.generate_access_token("some_user")
+        token = self.macaroon_generator.generate_access_token("some_user")
         # Check that we can parse the thing with pymacaroons
         macaroon = pymacaroons.Macaroon.deserialize(token)
         # The most basic of sanity checks
@@ -46,10 +45,9 @@ class AuthTestCase(unittest.TestCase):
             self.fail("some_user was not in %s" % macaroon.inspect())
 
     def test_macaroon_caveats(self):
-        self.hs.config.macaroon_secret_key = "this key is a massive secret"
         self.hs.clock.now = 5000
 
-        token = self.auth_handler.generate_access_token("a_user")
+        token = self.macaroon_generator.generate_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
         def verify_gen(caveat):
@@ -74,7 +72,7 @@ class AuthTestCase(unittest.TestCase):
     def test_short_term_login_token_gives_user_id(self):
         self.hs.clock.now = 1000
 
-        token = self.auth_handler.generate_short_term_login_token(
+        token = self.macaroon_generator.generate_short_term_login_token(
             "a_user", 5000
         )
 
@@ -93,7 +91,7 @@ class AuthTestCase(unittest.TestCase):
             )
 
     def test_short_term_login_token_cannot_replace_user_id(self):
-        token = self.auth_handler.generate_short_term_login_token(
+        token = self.macaroon_generator.generate_short_term_login_token(
             "a_user", 5000
         )
         macaroon = pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index a4380c48b4..c8cf9a63ec 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -41,15 +41,12 @@ class RegistrationTestCase(unittest.TestCase):
             handlers=None,
             http_client=None,
             expire_access_token=True)
-        self.auth_handler = Mock(
+        self.macaroon_generator = Mock(
             generate_access_token=Mock(return_value='secret'))
+        self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
         self.hs.handlers = RegistrationHandlers(self.hs)
         self.handler = self.hs.get_handlers().registration_handler
         self.hs.get_handlers().profile_handler = Mock()
-        self.mock_handler = Mock(spec=[
-            "generate_access_token",
-        ])
-        self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
 
     @defer.inlineCallbacks
     def test_user_is_created_and_logged_in_if_doesnt_exist(self):