diff --git a/contrib/vertobot/bot.pl b/contrib/vertobot/bot.pl
index 0430a38aa8..31eed40925 100755
--- a/contrib/vertobot/bot.pl
+++ b/contrib/vertobot/bot.pl
@@ -126,12 +126,26 @@ sub on_unknown_event
if (!$bridgestate->{$room_id}->{gathered_candidates}) {
$bridgestate->{$room_id}->{gathered_candidates} = 1;
my $offer = $bridgestate->{$room_id}->{offer};
- my $candidate_block = "";
+ my $candidate_block = {
+ audio => '',
+ video => '',
+ };
foreach (@{$event->{content}->{candidates}}) {
- $candidate_block .= "a=" . $_->{candidate} . "\r\n";
+ if ($_->{sdpMid}) {
+ $candidate_block->{$_->{sdpMid}} .= "a=" . $_->{candidate} . "\r\n";
+ }
+ else {
+ $candidate_block->{audio} .= "a=" . $_->{candidate} . "\r\n";
+ $candidate_block->{video} .= "a=" . $_->{candidate} . "\r\n";
+ }
}
- # XXX: collate using the right m= line - for now assume audio call
- $offer =~ s/(a=rtcp.*[\r\n]+)/$1$candidate_block/;
+
+ # XXX: assumes audio comes first
+ #$offer =~ s/(a=rtcp-mux[\r\n]+)/$1$candidate_block->{audio}/;
+ #$offer =~ s/(a=rtcp-mux[\r\n]+)/$1$candidate_block->{video}/;
+
+ $offer =~ s/(m=video)/$candidate_block->{audio}$1/;
+ $offer =~ s/(.$)/$1\n$candidate_block->{video}$1/;
my $f = send_verto_json_request("verto.invite", {
"sdp" => $offer,
@@ -172,22 +186,18 @@ sub on_room_message
warn "[Matrix] in $room_id: $from: " . $content->{body} . "\n";
}
-my $verto_connecting = $loop->new_future;
-$bot_verto->connect(
- %{ $CONFIG{"verto-bot"} },
- on_connect_error => sub { die "Cannot connect to verto - $_[-1]" },
- on_resolve_error => sub { die "Cannot resolve to verto - $_[-1]" },
-)->then( sub {
- warn("[Verto] connected to websocket");
- $verto_connecting->done($bot_verto) if not $verto_connecting->is_done;
-});
-
Future->needs_all(
$bot_matrix->login( %{ $CONFIG{"matrix-bot"} } )->then( sub {
$bot_matrix->start;
}),
- $verto_connecting,
+ $bot_verto->connect(
+ %{ $CONFIG{"verto-bot"} },
+ on_connect_error => sub { die "Cannot connect to verto - $_[-1]" },
+ on_resolve_error => sub { die "Cannot resolve to verto - $_[-1]" },
+ )->on_done( sub {
+ warn("[Verto] connected to websocket");
+ }),
)->get;
$loop->attach_signal(
diff --git a/contrib/vertobot/cpanfile b/contrib/vertobot/cpanfile
index c29fcaa6f6..800dc288ed 100644
--- a/contrib/vertobot/cpanfile
+++ b/contrib/vertobot/cpanfile
@@ -11,7 +11,4 @@ requires 'YAML', 0;
requires 'JSON', 0;
requires 'Getopt::Long', 0;
-on 'test' => sub {
- requires 'Test::More', '>= 0.98';
-};
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1e3b0fbfb7..3d9237ccc3 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -361,7 +361,7 @@ class Auth(object):
except KeyError:
pass # normal users won't have the user_id query parameter set.
- user_info = yield self.get_user_by_token(access_token)
+ user_info = yield self.get_user_by_access_token(access_token)
user = user_info["user"]
device_id = user_info["device_id"]
token_id = user_info["token_id"]
@@ -390,7 +390,7 @@ class Auth(object):
)
@defer.inlineCallbacks
- def get_user_by_token(self, token):
+ def get_user_by_access_token(self, token):
""" Get a registered user's ID.
Args:
@@ -401,7 +401,7 @@ class Auth(object):
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
- ret = yield self.store.get_user_by_token(token)
+ ret = yield self.store.get_user_by_access_token(token)
if not ret:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 67e780864e..62de4b399f 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -32,9 +32,11 @@ class RegistrationConfig(Config):
)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.macaroon_secret_key = config.get("macaroon_secret_key")
def default_config(self, config_dir, server_name):
registration_shared_secret = random_string_with_symbols(50)
+ macaroon_secret_key = random_string_with_symbols(50)
return """\
## Registration ##
@@ -44,6 +46,8 @@ class RegistrationConfig(Config):
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
+
+ macaroon_secret_key: "%(macaroon_secret_key)s"
""" % locals()
def add_arguments(self, parser):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index ff2c66f442..af602bee44 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -26,6 +26,7 @@ from twisted.web.client import PartialDownloadError
import logging
import bcrypt
+import pymacaroons
import simplejson
import synapse.util.stringutils as stringutils
@@ -278,18 +279,18 @@ class AuthHandler(BaseHandler):
user_id (str): User ID
password (str): Password
Returns:
- The access token for the user's session.
+ A tuple of:
+ The access token for the user's session.
+ The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
yield self._check_password(user_id, password)
-
- reg_handler = self.hs.get_handlers().registration_handler
- access_token = reg_handler.generate_token(user_id)
logger.info("Logging in user %s", user_id)
- yield self.store.add_access_token_to_user(user_id, access_token)
- defer.returnValue(access_token)
+ access_token = yield self.issue_access_token(user_id)
+ refresh_token = yield self.issue_refresh_token(user_id)
+ defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks
def _check_password(self, user_id, password):
@@ -305,6 +306,49 @@ class AuthHandler(BaseHandler):
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
+ def issue_access_token(self, user_id):
+ access_token = self.generate_access_token(user_id)
+ yield self.store.add_access_token_to_user(user_id, access_token)
+ defer.returnValue(access_token)
+
+ @defer.inlineCallbacks
+ def issue_refresh_token(self, user_id):
+ refresh_token = self.generate_refresh_token(user_id)
+ yield self.store.add_refresh_token_to_user(user_id, refresh_token)
+ defer.returnValue(refresh_token)
+
+ def generate_access_token(self, user_id):
+ macaroon = pymacaroons.Macaroon(
+ location = self.hs.config.server_name,
+ identifier = "key",
+ key = self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ macaroon.add_first_party_caveat("type = access")
+ now = self.hs.get_clock().time_msec()
+ expiry = now + (60 * 60 * 1000)
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+ return macaroon.serialize()
+
+ def generate_refresh_token(self, user_id):
+ m = self._generate_base_macaroon(user_id)
+ m.add_first_party_caveat("type = refresh")
+ # Important to add a nonce, because otherwise every refresh token for a
+ # user will be the same.
+ m.add_first_party_caveat("nonce = %s" % stringutils.random_string_with_symbols(16))
+ return m.serialize()
+
+ def _generate_base_macaroon(self, user_id):
+ macaroon = pymacaroons.Macaroon(
+ location = self.hs.config.server_name,
+ identifier = "key",
+ key = self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ return macaroon
+
+ @defer.inlineCallbacks
def set_password(self, user_id, newpassword):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 39392d9fdd..3d1b6531c2 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -25,7 +25,6 @@ import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
-import base64
import bcrypt
import logging
import urllib
@@ -91,7 +90,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
- token = self.generate_token(user_id)
+ token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -111,7 +110,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
- token = self.generate_token(user_id)
+ token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -161,7 +160,7 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE
)
- token = self.generate_token(user_id)
+ token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -208,7 +207,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
- token = self.generate_token(user_id)
+ token = self.auth_handler().generate_access_token(user_id)
try:
yield self.store.register(
user_id=user_id,
@@ -273,13 +272,6 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE
)
- def generate_token(self, user_id):
- # urlsafe variant uses _ and - so use . as the separator and replace
- # all =s with .s so http clients don't quote =s when it is used as
- # query params.
- return (base64.urlsafe_b64encode(user_id).replace('=', '.') + '.' +
- stringutils.random_string(18))
-
def _generate_user_id(self):
return "-" + stringutils.random_string(18)
@@ -322,3 +314,6 @@ class RegistrationHandler(BaseHandler):
}
)
defer.returnValue(data)
+
+ def auth_handler(self):
+ return self.hs.get_handlers().auth_handler
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index fa06480ad1..fa24199377 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -33,6 +33,7 @@ REQUIREMENTS = {
"ujson": ["ujson"],
"blist": ["blist"],
"pysaml2": ["saml2"],
+ "pymacaroons-pynacl": ["pymacaroons"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0d5eafd0fa..67323a16bb 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -85,13 +85,15 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(
user_id, self.hs.hostname).to_string()
- token = yield self.handlers.auth_handler.login_with_password(
+ auth_handler = self.handlers.auth_handler
+ access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id,
password=login_submission["password"])
result = {
- "user_id": user_id, # may have changed
- "access_token": token,
+ "user_id": login_submission["user"], # may have changed
+ "access_token": access_token,
+ "refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index 33f961e898..5831ff0e62 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -21,6 +21,7 @@ from . import (
auth,
receipts,
keys,
+ tokenrefresh,
)
from synapse.http.server import JsonResource
@@ -42,3 +43,4 @@ class ClientV2AlphaRestResource(JsonResource):
auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)
+ tokenrefresh.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
new file mode 100644
index 0000000000..901e777983
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import AuthError, StoreError, SynapseError
+from synapse.http.servlet import RestServlet
+
+from ._base import client_v2_pattern, parse_json_dict_from_request
+
+
+class TokenRefreshRestServlet(RestServlet):
+ """
+ Exchanges refresh tokens for a pair of an access token and a new refresh
+ token.
+ """
+ PATTERN = client_v2_pattern("/tokenrefresh")
+
+ def __init__(self, hs):
+ super(TokenRefreshRestServlet, self).__init__()
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ body = parse_json_dict_from_request(request)
+ try:
+ old_refresh_token = body["refresh_token"]
+ auth_handler = self.hs.get_handlers().auth_handler
+ (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
+ old_refresh_token, auth_handler.generate_refresh_token)
+ new_access_token = yield auth_handler.issue_access_token(user_id)
+ defer.returnValue((200, {
+ "access_token": new_access_token,
+ "refresh_token": new_refresh_token,
+ }))
+ except KeyError:
+ raise SynapseError(400, "Missing required key 'refresh_token'.")
+ except StoreError:
+ raise AuthError(403, "Did not recognize refresh token")
+
+
+def register_servlets(hs, http_server):
+ TokenRefreshRestServlet(hs).register(http_server)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f154b1c8ae..53673b3bf5 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 22
+SCHEMA_VERSION = 23
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 1444767a52..ce71389f02 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -181,6 +181,7 @@ class SQLBaseStore(object):
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+ self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index bf803f2c6e..f632306688 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -51,6 +51,28 @@ class RegistrationStore(SQLBaseStore):
)
@defer.inlineCallbacks
+ def add_refresh_token_to_user(self, user_id, token):
+ """Adds a refresh token for the given user.
+
+ Args:
+ user_id (str): The user ID.
+ token (str): The new refresh token to add.
+ Raises:
+ StoreError if there was a problem adding this.
+ """
+ next_id = yield self._refresh_tokens_id_gen.get_next()
+
+ yield self._simple_insert(
+ "refresh_tokens",
+ {
+ "id": next_id,
+ "user_id": user_id,
+ "token": token
+ },
+ desc="add_refresh_token_to_user",
+ )
+
+ @defer.inlineCallbacks
def register(self, user_id, token, password_hash):
"""Attempts to register an account.
@@ -132,10 +154,10 @@ class RegistrationStore(SQLBaseStore):
user_id
)
for r in rows:
- self.get_user_by_token.invalidate((r,))
+ self.get_user_by_access_token.invalidate((r,))
@cached()
- def get_user_by_token(self, token):
+ def get_user_by_access_token(self, token):
"""Get a user from the given access token.
Args:
@@ -147,11 +169,51 @@ class RegistrationStore(SQLBaseStore):
StoreError if no user was found.
"""
return self.runInteraction(
- "get_user_by_token",
+ "get_user_by_access_token",
self._query_for_auth,
token
)
+ def exchange_refresh_token(self, refresh_token, token_generator):
+ """Exchange a refresh token for a new access token and refresh token.
+
+ Doing so invalidates the old refresh token - refresh tokens are single
+ use.
+
+ Args:
+ token (str): The refresh token of a user.
+ token_generator (fn: str -> str): Function which, when given a
+ user ID, returns a unique refresh token for that user. This
+ function must never return the same value twice.
+ Returns:
+ tuple of (user_id, refresh_token)
+ Raises:
+ StoreError if no user was found with that refresh token.
+ """
+ return self.runInteraction(
+ "exchange_refresh_token",
+ self._exchange_refresh_token,
+ refresh_token,
+ token_generator
+ )
+
+ def _exchange_refresh_token(self, txn, old_token, token_generator):
+ sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
+ txn.execute(sql, (old_token,))
+ rows = self.cursor_to_dict(txn)
+ if not rows:
+ raise StoreError(403, "Did not recognize refresh token")
+ user_id = rows[0]["user_id"]
+
+ # TODO(danielwh): Maybe perform a validation on the macaroon that
+ # macaroon.user_id == user_id.
+
+ new_token = token_generator(user_id)
+ sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
+ txn.execute(sql, (new_token, old_token,))
+
+ return user_id, new_token
+
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/23/refresh_tokens.sql
new file mode 100644
index 0000000000..46839e016c
--- /dev/null
+++ b/synapse/storage/schema/delta/23/refresh_tokens.sql
@@ -0,0 +1,21 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS refresh_tokens(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ token TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ UNIQUE (token)
+);
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 4f83db5e84..3343c635cc 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -44,7 +44,7 @@ class AuthTestCase(unittest.TestCase):
"token_id": "ditto",
"admin": False
}
- self.store.get_user_by_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
@@ -54,7 +54,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_user_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
@@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
"token_id": "ditto",
"admin": False
}
- self.store.get_user_by_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@@ -81,7 +81,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
@@ -91,7 +91,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
@@ -102,7 +102,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
@@ -129,7 +129,7 @@ class AuthTestCase(unittest.TestCase):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
new file mode 100644
index 0000000000..978e4d0d2e
--- /dev/null
+++ b/tests/handlers/test_auth.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pymacaroons
+
+from mock import Mock, NonCallableMock
+from synapse.handlers.auth import AuthHandler
+from tests import unittest
+from tests.utils import setup_test_homeserver
+from twisted.internet import defer
+
+
+class AuthHandlers(object):
+ def __init__(self, hs):
+ self.auth_handler = AuthHandler(hs)
+
+
+class AuthTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver(handlers=None)
+ self.hs.handlers = AuthHandlers(self.hs)
+
+ def test_token_is_a_macaroon(self):
+ self.hs.config.macaroon_secret_key = "this key is a huge secret"
+
+ token = self.hs.handlers.auth_handler.generate_access_token("some_user")
+ # Check that we can parse the thing with pymacaroons
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ # The most basic of sanity checks
+ if "some_user" not in macaroon.inspect():
+ self.fail("some_user was not in %s" % macaroon.inspect())
+
+ def test_macaroon_caveats(self):
+ self.hs.config.macaroon_secret_key = "this key is a massive secret"
+ self.hs.clock.now = 5000
+
+ token = self.hs.handlers.auth_handler.generate_access_token("a_user")
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+
+ def verify_gen(caveat):
+ return caveat == "gen = 1"
+
+ def verify_user(caveat):
+ return caveat == "user_id = a_user"
+
+ def verify_type(caveat):
+ return caveat == "type = access"
+
+ def verify_expiry(caveat):
+ return caveat == "time < 8600000"
+
+ v = pymacaroons.Verifier()
+ v.satisfy_general(verify_gen)
+ v.satisfy_general(verify_user)
+ v.satisfy_general(verify_type)
+ v.satisfy_general(verify_expiry)
+ v.verify(macaroon, self.hs.config.macaroon_secret_key)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 089a71568c..0b78a82a66 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -70,7 +70,7 @@ class PresenceStateTestCase(unittest.TestCase):
return defer.succeed([])
self.datastore.get_presence_list = get_presence_list
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(myid),
"admin": False,
@@ -78,7 +78,7 @@ class PresenceStateTestCase(unittest.TestCase):
"token_id": 1,
}
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[
@@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase):
)
self.datastore.has_presence_state = has_presence_state
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(myid),
"admin": False,
@@ -173,7 +173,7 @@ class PresenceListTestCase(unittest.TestCase):
]
)
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
presence.register_servlets(hs, self.mock_resource)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index c83348acf9..2e55cc08a1 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -54,14 +54,14 @@ class RoomPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -441,14 +441,14 @@ class RoomsMemberListTestCase(RestTestCase):
self.auth_user_id = self.user_id
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -521,14 +521,14 @@ class RoomsCreateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(self.auth_user_id),
"admin": False,
@@ -622,7 +622,7 @@ class RoomTopicTestCase(RestTestCase):
"token_id": 1,
}
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -721,14 +721,14 @@ class RoomMemberStateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -848,14 +848,14 @@ class RoomMessagesTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@@ -945,14 +945,14 @@ class RoomInitialSyncTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 7d8b1c2683..dc8bbaaf0e 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -61,7 +61,7 @@ class RoomTypingTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(self.auth_user_id),
"admin": False,
@@ -69,7 +69,7 @@ class RoomTypingTestCase(RestTestCase):
"token_id": 1,
}
- hs.get_v1auth().get_user_by_token = _get_user_by_token
+ hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 579441fb4a..c472d53043 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -37,7 +37,7 @@ class RestTestCase(unittest.TestCase):
self.mock_resource = None
self.auth_user_id = None
- def mock_get_user_by_token(self, token=None):
+ def mock_get_user_by_access_token(self, token=None):
return self.auth_user_id
@defer.inlineCallbacks
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index de5a917e6a..15568b36cd 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -43,14 +43,14 @@ class V2AlphaRestTestCase(unittest.TestCase):
resource_for_federation=self.mock_resource,
)
- def _get_user_by_token(token=None):
+ def _get_user_by_access_token(token=None):
return {
"user": UserID.from_string(self.USER_ID),
"admin": False,
"device_id": None,
"token_id": 1,
}
- hs.get_auth().get_user_by_token = _get_user_by_token
+ hs.get_auth().get_user_by_access_token = _get_user_by_access_token
for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 2702291178..a4f929796a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,7 +17,9 @@
from tests import unittest
from twisted.internet import defer
+from synapse.api.errors import StoreError
from synapse.storage.registration import RegistrationStore
+from synapse.util import stringutils
from tests.utils import setup_test_homeserver
@@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver()
+ self.db_pool = hs.get_db_pool()
self.store = RegistrationStore(hs)
@@ -46,7 +49,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id))
)
- result = yield self.store.get_user_by_token(self.tokens[0])
+ result = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertDictContainsSubset(
{
@@ -64,7 +67,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
- result = yield self.store.get_user_by_token(self.tokens[1])
+ result = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertDictContainsSubset(
{
@@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.assertTrue("token_id" in result)
+ @defer.inlineCallbacks
+ def test_exchange_refresh_token_valid(self):
+ uid = stringutils.random_string(32)
+ generator = TokenGenerator()
+ last_token = generator.generate(uid)
+
+ self.db_pool.runQuery(
+ "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
+ (uid, last_token,))
+
+ (found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
+ last_token, generator.generate)
+ self.assertEqual(uid, found_user_id)
+
+ rows = yield self.db_pool.runQuery(
+ "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
+ self.assertEqual([(refresh_token,)], rows)
+ # We issued token 1, then exchanged it for token 2
+ expected_refresh_token = u"%s-%d" % (uid, 2,)
+ self.assertEqual(expected_refresh_token, refresh_token)
+
+ @defer.inlineCallbacks
+ def test_exchange_refresh_token_none(self):
+ uid = stringutils.random_string(32)
+ generator = TokenGenerator()
+ last_token = generator.generate(uid)
+
+ with self.assertRaises(StoreError):
+ yield self.store.exchange_refresh_token(last_token, generator.generate)
+
+ @defer.inlineCallbacks
+ def test_exchange_refresh_token_invalid(self):
+ uid = stringutils.random_string(32)
+ generator = TokenGenerator()
+ last_token = generator.generate(uid)
+ wrong_token = "%s-wrong" % (last_token,)
+
+ self.db_pool.runQuery(
+ "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
+ (uid, wrong_token,))
+
+ with self.assertRaises(StoreError):
+ yield self.store.exchange_refresh_token(last_token, generator.generate)
+
+
+class TokenGenerator:
+ def __init__(self):
+ self._last_issued_token = 0
+
+ def generate(self, user_id):
+ self._last_issued_token += 1
+ return u"%s-%d" % (user_id, self._last_issued_token,)
diff --git a/tests/utils.py b/tests/utils.py
index eb035cf48f..d0fba2252d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -44,6 +44,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.signing_key = [MockKey()]
config.event_cache_size = 1
config.disable_registration = False
+ config.macaroon_secret_key = "not even a little secret"
+ config.server_name = "server.under.test"
if "clock" not in kargs:
kargs["clock"] = MockClock()
@@ -275,7 +277,7 @@ class MemoryDataStore(object):
raise StoreError(400, "User in use.")
self.tokens_to_users[token] = user_id
- def get_user_by_token(self, token):
+ def get_user_by_access_token(self, token):
try:
return {
"name": self.tokens_to_users[token],
|