diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index d4e75b5b2e..c0cb8ef296 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -21,7 +21,14 @@ from twisted.internet import defer
import synapse.handlers.auth
from synapse.api.auth import Auth
-from synapse.api.errors import AuthError, Codes, ResourceLimitError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InvalidClientCredentialsError,
+ InvalidClientTokenError,
+ MissingClientTokenError,
+ ResourceLimitError,
+)
from synapse.types import UserID
from tests import unittest
@@ -70,7 +77,9 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
- self.failureResultOf(d, AuthError)
+ f = self.failureResultOf(d, InvalidClientTokenError).value
+ self.assertEqual(f.code, 401)
+ self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"}
@@ -79,7 +88,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
- self.failureResultOf(d, AuthError)
+ f = self.failureResultOf(d, MissingClientTokenError).value
+ self.assertEqual(f.code, 401)
+ self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self):
@@ -133,7 +144,9 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
- self.failureResultOf(d, AuthError)
+ f = self.failureResultOf(d, InvalidClientTokenError).value
+ self.assertEqual(f.code, 401)
+ self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
@@ -143,7 +156,9 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
- self.failureResultOf(d, AuthError)
+ f = self.failureResultOf(d, InvalidClientTokenError).value
+ self.assertEqual(f.code, 401)
+ self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
@@ -153,7 +168,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
- self.failureResultOf(d, AuthError)
+ f = self.failureResultOf(d, MissingClientTokenError).value
+ self.assertEqual(f.code, 401)
+ self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
@@ -244,10 +261,12 @@ class AuthTestCase(unittest.TestCase):
USER_ID = "@percy:matrix.org"
self.store.add_access_token_to_user = Mock()
- token = yield self.hs.handlers.auth_handler.issue_access_token(
- USER_ID, "DEVICE"
+ token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
+ USER_ID, "DEVICE", valid_until_ms=None
+ )
+ self.store.add_access_token_to_user.assert_called_with(
+ USER_ID, token, "DEVICE", None
)
- self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE")
def get_user(tok):
if token != tok:
@@ -280,7 +299,7 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [guest_tok.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- with self.assertRaises(AuthError) as cm:
+ with self.assertRaises(InvalidClientCredentialsError) as cm:
yield self.auth.get_user_by_req(request, allow_guest=True)
self.assertEqual(401, cm.exception.code)
@@ -325,7 +344,7 @@ class AuthTestCase(unittest.TestCase):
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.hs.config.mau_limits_reserved_threepids = [threepid]
- yield self.store.register(user_id="user1", token="123", password_hash=None)
+ yield self.store.register_user(user_id="user1", password_hash=None)
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking()
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b204a0700d..b03103d96f 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -117,7 +117,9 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.auth_handler.get_access_token_for_user_id("user_a")
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
@@ -131,7 +133,9 @@ class AuthTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id("user_a")
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
@@ -150,7 +154,9 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id("user_a")
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
@@ -166,7 +172,9 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
- yield self.auth_handler.get_access_token_for_user_id("user_a")
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
@@ -185,7 +193,9 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
- yield self.auth_handler.get_access_token_for_user_id("user_a")
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 4edce7af43..90d0129374 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.constants import UserTypes
-from synapse.api.errors import ResourceLimitError, SynapseError
+from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
from synapse.types import RoomAlias, UserID, create_requester
@@ -67,7 +67,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = frank.to_string()
requester = create_requester(user_id)
result_user_id, result_token = self.get_success(
- self.handler.get_or_create_user(requester, frank.localpart, "Frankie")
+ self.get_or_create_user(requester, frank.localpart, "Frankie")
)
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
@@ -77,17 +77,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
frank = UserID.from_string("@frank:test")
self.get_success(
- store.register(
- user_id=frank.to_string(),
- token="jkv;g498752-43gj['eamb!-5",
- password_hash=None,
- )
+ store.register_user(user_id=frank.to_string(), password_hash=None)
)
local_part = frank.localpart
user_id = frank.to_string()
requester = create_requester(user_id)
result_user_id, result_token = self.get_success(
- self.handler.get_or_create_user(requester, local_part, None)
+ self.get_or_create_user(requester, local_part, None)
)
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
@@ -95,9 +91,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- self.get_success(
- self.handler.get_or_create_user(self.requester, "a", "display_name")
- )
+ self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
@@ -105,7 +99,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
- self.get_success(self.handler.get_or_create_user(self.requester, "c", "User"))
+ self.get_success(self.get_or_create_user(self.requester, "c", "User"))
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
@@ -113,7 +107,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.lots_of_users)
)
self.get_failure(
- self.handler.get_or_create_user(self.requester, "b", "display_name"),
+ self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
@@ -121,7 +115,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.hs.config.max_mau_value)
)
self.get_failure(
- self.handler.get_or_create_user(self.requester, "b", "display_name"),
+ self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
@@ -131,21 +125,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.lots_of_users)
)
self.get_failure(
- self.handler.register(localpart="local_part"), ResourceLimitError
+ self.handler.register_user(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
self.get_failure(
- self.handler.register(localpart="local_part"), ResourceLimitError
+ self.handler.register_user(localpart="local_part"), ResourceLimitError
)
def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- res = self.get_success(self.handler.register(localpart="jeff"))
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -156,25 +150,25 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_with_no_rooms(self):
self.hs.config.auto_join_rooms = []
frank = UserID.from_string("@frank:test")
- res = self.get_success(self.handler.register(frank.localpart))
- self.assertEqual(res[0], frank.to_string())
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(frank.localpart))
+ self.assertEqual(user_id, frank.to_string())
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
def test_auto_create_auto_join_where_room_is_another_domain(self):
self.hs.config.auto_join_rooms = ["#room:another"]
frank = UserID.from_string("@frank:test")
- res = self.get_success(self.handler.register(frank.localpart))
- self.assertEqual(res[0], frank.to_string())
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(frank.localpart))
+ self.assertEqual(user_id, frank.to_string())
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
def test_auto_create_auto_join_where_auto_create_is_false(self):
self.hs.config.autocreate_auto_join_rooms = False
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- res = self.get_success(self.handler.register(localpart="jeff"))
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
def test_auto_create_auto_join_rooms_when_support_user_exists(self):
@@ -182,8 +176,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_support_user = Mock(return_value=True)
- res = self.get_success(self.handler.register(localpart="support"))
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(localpart="support"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str)
@@ -211,24 +205,82 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# When:-
# * the user is registered and post consent actions are called
- res = self.get_success(self.handler.register(localpart="jeff"))
- self.get_success(self.handler.post_consent_actions(res[0]))
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+ self.get_success(self.handler.post_consent_actions(user_id))
# Then:-
# * Ensure that they have not been joined to the room
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
def test_register_support_user(self):
- res = self.get_success(
- self.handler.register(localpart="user", user_type=UserTypes.SUPPORT)
+ user_id = self.get_success(
+ self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
)
- self.assertTrue(self.store.is_support_user(res[0]))
+ d = self.store.is_support_user(user_id)
+ self.assertTrue(self.get_success(d))
def test_register_not_support_user(self):
- res = self.get_success(self.handler.register(localpart="user"))
- self.assertFalse(self.store.is_support_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+ d = self.store.is_support_user(user_id)
+ self.assertFalse(self.get_success(d))
def test_invalid_user_id_length(self):
invalid_user_id = "x" * 256
- self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError)
+ self.get_failure(
+ self.handler.register_user(localpart=invalid_user_id), SynapseError
+ )
+
+ @defer.inlineCallbacks
+ def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+ """Creates a new user if the user does not exist,
+ else revokes all previous access tokens and generates a new one.
+
+ XXX: this used to be in the main codebase, but was only used by this file,
+ so got moved here. TODO: get rid of it, probably
+
+ Args:
+ localpart : The local part of the user ID to register. If None,
+ one will be randomly generated.
+ Returns:
+ A tuple of (user_id, access_token).
+ Raises:
+ RegistrationError if there was a problem registering.
+ """
+ if localpart is None:
+ raise SynapseError(400, "Request must include user id")
+ yield self.hs.get_auth().check_auth_blocking()
+ need_register = True
+
+ try:
+ yield self.handler.check_username(localpart)
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ need_register = False
+ else:
+ raise
+
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+ token = self.macaroon_generator.generate_access_token(user_id)
+
+ if need_register:
+ yield self.handler.register_with_store(
+ user_id=user_id,
+ password_hash=password_hash,
+ create_profile_with_displayname=user.localpart,
+ )
+ else:
+ yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+
+ yield self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+
+ if displayname is not None:
+ # logger.info("setting user display name: %s -> %s", user_id, displayname)
+ yield self.hs.get_profile_handler().set_displayname(
+ user, requester, displayname, by_admin=True
+ )
+
+ defer.returnValue((user_id, token))
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index b135486c48..c5e91a8c41 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -47,11 +47,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_handle_local_profile_change_with_support_user(self):
support_user_id = "@support:test"
self.get_success(
- self.store.register(
- user_id=support_user_id,
- token="123",
- password_hash=None,
- user_type=UserTypes.SUPPORT,
+ self.store.register_user(
+ user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
)
@@ -73,11 +70,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_handle_user_deactivated_support_user(self):
s_user_id = "@support:test"
self.get_success(
- self.store.register(
- user_id=s_user_id,
- token="123",
- password_hash=None,
- user_type=UserTypes.SUPPORT,
+ self.store.register_user(
+ user_id=s_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
)
@@ -90,7 +84,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_handle_user_deactivated_regular_user(self):
r_user_id = "@regular:test"
self.get_success(
- self.store.register(user_id=r_user_id, token="123", password_hash=None)
+ self.store.register_user(user_id=r_user_id, password_hash=None)
)
self.store.remove_from_user_dir = Mock()
self.get_success(self.handler.handle_user_deactivated(r_user_id))
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 0397f91a9e..eae5411325 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -2,10 +2,14 @@ import json
import synapse.rest.admin
from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import devices
+from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from tests import unittest
+from tests.unittest import override_config
LOGIN_URL = b"/_matrix/client/r0/login"
+TEST_URL = b"/_matrix/client/r0/account/whoami"
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -13,6 +17,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
+ devices.register_servlets,
+ lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
]
def make_homeserver(self, reactor, clock):
@@ -144,3 +150,105 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ @override_config({"session_lifetime": "24h"})
+ def test_soft_logout(self):
+ self.register_user("kermit", "monkey")
+
+ # we shouldn't be able to make requests without an access token
+ request, channel = self.make_request(b"GET", TEST_URL)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
+
+ # log in as normal
+ params = {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": "kermit"},
+ "password": "monkey",
+ }
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+ self.render(request)
+
+ self.assertEquals(channel.code, 200, channel.result)
+ access_token = channel.json_body["access_token"]
+ device_id = channel.json_body["device_id"]
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+ # ... and we should be soft-logouted
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ #
+ # test behaviour after deleting the expired device
+ #
+
+ # we now log in as a different device
+ access_token_2 = self.login("kermit", "monkey")
+
+ # more requests with the expired token should still return a soft-logout
+ self.reactor.advance(3600)
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # ... but if we delete that device, it will be a proper logout
+ self._delete_device(access_token_2, "kermit", "monkey", device_id)
+
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], False)
+
+ def _delete_device(self, access_token, user_id, password, device_id):
+ """Perform the UI-Auth to delete a device"""
+ request, channel = self.make_request(
+ b"DELETE", "devices/" + device_id, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ # check it's a UI-Auth fail
+ self.assertEqual(
+ set(channel.json_body.keys()),
+ {"flows", "params", "session"},
+ channel.result,
+ )
+
+ auth = {
+ "type": "m.login.password",
+ # https://github.com/matrix-org/synapse/issues/5665
+ # "identifier": {"type": "m.id.user", "user": user_id},
+ "user": user_id,
+ "password": password,
+ "session": channel.json_body["session"],
+ }
+
+ request, channel = self.make_request(
+ b"DELETE",
+ "devices/" + device_id,
+ access_token=access_token,
+ content={"auth": auth},
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index dff9b2f10c..140d8b3772 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -288,3 +288,50 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
# if the user isn't already in the room), because we only want to
# make sure the user isn't in the room.
pass
+
+
+class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ profile.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["require_auth_for_profile_requests"] = True
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ # User requesting the profile.
+ self.requester = self.register_user("requester", "pass")
+ self.requester_tok = self.login("requester", "pass")
+
+ def test_can_lookup_own_profile(self):
+ """Tests that a user can lookup their own profile without having to be in a room
+ if 'require_auth_for_profile_requests' is set to true in the server's config.
+ """
+ request, channel = self.make_request(
+ "GET", "/profile/" + self.requester, access_token=self.requester_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ request, channel = self.make_request(
+ "GET",
+ "/profile/" + self.requester + "/displayname",
+ access_token=self.requester_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ request, channel = self.make_request(
+ "GET",
+ "/profile/" + self.requester + "/avatar_url",
+ access_token=self.requester_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 3deeed3a70..58c6951852 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -126,6 +126,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel.json_body["chunk"][0],
)
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEquals(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
# Make sure next_batch has something in it that looks like it could be a
# valid token.
self.assertIsInstance(
@@ -466,9 +471,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["content"], new_body)
- self.assertEquals(
- channel.json_body["unsigned"].get("m.relations"),
- {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
def test_multi_edit(self):
@@ -518,9 +529,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["content"], new_body)
- self.assertEquals(
- channel.json_body["unsigned"].get("m.relations"),
- {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
def _send_relation(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 59c6f8c227..09305c3bf1 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -185,9 +185,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
- self.get_success(
- self.store.register(user_id=user_id, token="123", password_hash=None)
- )
+ self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 0ce0b991f9..1494650d10 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -53,10 +53,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# -1 because user3 is a support user and does not count
user_num = len(threepids) - 1
- self.store.register(user_id=user1, token="123", password_hash=None)
- self.store.register(user_id=user2, token="456", password_hash=None)
- self.store.register(
- user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT
+ self.store.register_user(user_id=user1, password_hash=None)
+ self.store.register_user(user_id=user2, password_hash=None)
+ self.store.register_user(
+ user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT
)
self.pump()
@@ -161,9 +161,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
user_id = "@user_id:host"
- self.store.register(
- user_id=user_id, token="123", password_hash=None, make_guest=True
- )
+ self.store.register_user(user_id=user_id, password_hash=None, make_guest=True)
self.store.upsert_monthly_active_user = Mock()
self.store.populate_monthly_active_users(user_id)
self.pump()
@@ -216,8 +214,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.get_success(count), 0)
# Test reserved registed users
- self.store.register(user_id=user1, token="123", password_hash=None)
- self.store.register(user_id=user2, token="456", password_hash=None)
+ self.store.register_user(user_id=user1, password_hash=None)
+ self.store.register_user(user_id=user2, password_hash=None)
self.pump()
now = int(self.hs.get_clock().time_msec())
@@ -232,11 +230,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(self.get_success(count), 0)
- self.store.register(
- user_id=support_user_id,
- token="123",
- password_hash=None,
- user_type=UserTypes.SUPPORT,
+ self.store.register_user(
+ user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
self.store.upsert_monthly_active_user(support_user_id)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 625b651e91..0253c4ac05 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -37,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_register(self):
- yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
+ yield self.store.register_user(self.user_id, self.pwhash)
self.assertEquals(
{
@@ -53,17 +53,11 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id)),
)
- result = yield self.store.get_user_by_access_token(self.tokens[0])
-
- self.assertDictContainsSubset({"name": self.user_id}, result)
-
- self.assertTrue("token_id" in result)
-
@defer.inlineCallbacks
def test_add_tokens(self):
- yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
+ yield self.store.register_user(self.user_id, self.pwhash)
yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[1], self.device_id
+ self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
result = yield self.store.get_user_by_access_token(self.tokens[1])
@@ -77,9 +71,12 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
- yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
+ yield self.store.register_user(self.user_id, self.pwhash)
yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[1], self.device_id
+ self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+ )
+ yield self.store.add_access_token_to_user(
+ self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
# now delete some
@@ -108,24 +105,12 @@ class RegistrationStoreTestCase(unittest.TestCase):
res = yield self.store.is_support_user(None)
self.assertFalse(res)
- yield self.store.register(user_id=TEST_USER, token="123", password_hash=None)
+ yield self.store.register_user(user_id=TEST_USER, password_hash=None)
res = yield self.store.is_support_user(TEST_USER)
self.assertFalse(res)
- yield self.store.register(
- user_id=SUPPORT_USER,
- token="456",
- password_hash=None,
- user_type=UserTypes.SUPPORT,
+ yield self.store.register_user(
+ user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
)
res = yield self.store.is_support_user(SUPPORT_USER)
self.assertTrue(res)
-
-
-class TokenGenerator:
- def __init__(self):
- self._last_issued_token = 0
-
- def generate(self, user_id):
- self._last_issued_token += 1
- return "%s-%d" % (user_id, self._last_issued_token)
diff --git a/tests/unittest.py b/tests/unittest.py
index 3ebed6e612..cabe787cb4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -157,6 +157,21 @@ class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
+ Defines a setUp method which creates a mock reactor, and instantiates a homeserver
+ running on that reactor.
+
+ There are various hooks for modifying the way that the homeserver is instantiated:
+
+ * override make_homeserver, for example by making it pass different parameters into
+ setup_test_homeserver.
+
+ * override default_config, to return a modified configuration dictionary for use
+ by setup_test_homeserver.
+
+ * On a per-test basis, you can use the @override_config decorator to give a
+ dictionary containing additional configuration settings to be added to the basic
+ config dict.
+
Attributes:
servlets (list[function]): List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
@@ -168,6 +183,13 @@ class HomeserverTestCase(TestCase):
hijack_auth = True
needs_threadpool = False
+ def __init__(self, methodName, *args, **kwargs):
+ super().__init__(methodName, *args, **kwargs)
+
+ # see if we have any additional config for this test
+ method = getattr(self, methodName)
+ self._extra_config = getattr(method, "_extra_config", None)
+
def setUp(self):
"""
Set up the TestCase by calling the homeserver constructor, optionally
@@ -276,7 +298,14 @@ class HomeserverTestCase(TestCase):
Args:
name (str): The homeserver name/domain.
"""
- return default_config(name)
+ config = default_config(name)
+
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
+
+ return config
def prepare(self, reactor, clock, homeserver):
"""
@@ -534,3 +563,27 @@ class HomeserverTestCase(TestCase):
)
self.render(request)
self.assertEqual(channel.code, 403, channel.result)
+
+
+def override_config(extra_config):
+ """A decorator which can be applied to test functions to give additional HS config
+
+ For use
+
+ For example:
+
+ class MyTestCase(HomeserverTestCase):
+ @override_config({"enable_registration": False, ...})
+ def test_foo(self):
+ ...
+
+ Args:
+ extra_config(dict): Additional config settings to be merged into the default
+ config dict before instantiating the test homeserver.
+ """
+
+ def decorator(func):
+ func._extra_config = extra_config
+ return func
+
+ return decorator
|