diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0e42013bb9..c9f889b511 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
)
- self.assertEqual("a_user", user_id)
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("", res.auth_provider_id)
# when we advance the clock, the token should be rejected
self.reactor.advance(6)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+ self.auth_handler.validate_short_term_login_token(token),
AuthError,
)
+ def test_short_term_login_token_gives_auth_provider(self):
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", auth_provider_id="my_idp"
+ )
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("my_idp", res.auth_provider_id)
+
def test_short_term_login_token_cannot_replace_user_id(self):
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
+ )
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- )
+ res = self.get_success(
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize())
)
- self.assertEqual("a_user", user_id)
+ self.assertEqual("a_user", res.user_id)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- ),
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
AuthError,
)
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.large_number_of_users)
)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.small_number_of_users)
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
def _get_macaroon(self):
- token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "user_a", "", 5000
+ )
return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 6f992291b8..7975af243c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
def test_map_cas_user_to_existing_user(self):
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=False
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=False
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+ "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
)
@override_config(
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cf1de28fa9..02d4b2de0d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from typing import Optional
from urllib.parse import parse_qs, urlparse
from mock import ANY, Mock, patch
@@ -23,6 +22,7 @@ import pymacaroons
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -360,15 +360,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(name, b"oidc_session")
macaroon = pymacaroons.Macaroon.deserialize(cookie)
- state = self.handler._token_generator._get_value_from_macaroon(
- macaroon, "state"
- )
- nonce = self.handler._token_generator._get_value_from_macaroon(
- macaroon, "nonce"
- )
- redirect = self.handler._token_generator._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
+ state = get_value_from_macaroon(macaroon, "state")
+ nonce = get_value_from_macaroon(macaroon, "nonce")
+ redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
self.assertEqual(params["state"], [state])
self.assertEqual(params["nonce"], [nonce])
@@ -434,7 +428,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, None, new_user=True
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -465,7 +459,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, None, new_user=False
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_not_called()
@@ -651,6 +645,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_called_once_with(
"@foo:test",
+ "oidc",
request,
client_redirect_url,
{"phone": "1234567"},
@@ -668,7 +663,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", ANY, ANY, None, new_user=True
+ "@test_user:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -679,7 +674,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user_2:test", ANY, ANY, None, new_user=True
+ "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -716,14 +711,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None, new_user=False
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None, new_user=False
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -738,7 +733,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None, new_user=False
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -774,7 +769,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@TEST_USER_2:test", ANY, ANY, None, new_user=False
+ "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
)
def test_map_userinfo_to_invalid_localpart(self):
@@ -810,7 +805,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test", ANY, ANY, None, new_user=True
+ "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -866,7 +861,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
state: str,
nonce: str,
client_redirect_url: str,
- ui_auth_session_id: Optional[str] = None,
+ ui_auth_session_id: str = "",
) -> str:
from synapse.handlers.oidc_handler import OidcSessionData
@@ -909,6 +904,7 @@ async def _make_callback_with_userinfo(
idp_id="oidc",
nonce="nonce",
client_redirect_url=client_redirect_url,
+ ui_auth_session_id="",
),
)
request = _build_callback_request("code", state, session)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 029af2853e..30efd43b40 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "", None, new_user=False
+ "@test_user:test", "saml", request, "", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "", None, new_user=False
+ "@test_user:test", "saml", request, "", None, new_user=False
)
def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test", request, "", None, new_user=True
+ "@test_user1:test", "saml", request, "", None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index a06ad2c03e..41af8c4847 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from synapse.api.errors import NotFoundError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -33,9 +31,12 @@ class PurgeTests(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
- def test_purge(self):
+ self.store = hs.get_datastore()
+ self.storage = self.hs.get_storage()
+
+ def test_purge_history(self):
"""
- Purging a room will delete everything before the topological point.
+ Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@@ -43,30 +44,27 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
- store = self.hs.get_datastore()
- storage = self.hs.get_storage()
-
# Get the topological token
token = self.get_success(
- store.get_topological_token_for_event(last["event_id"])
+ self.store.get_topological_token_for_event(last["event_id"])
)
token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
self.get_success(
- storage.purge_events.purge_history(self.room_id, token_str, True)
+ self.storage.purge_events.purge_history(self.room_id, token_str, True)
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
- self.get_failure(store.get_event(first["event_id"]), NotFoundError)
- self.get_failure(store.get_event(second["event_id"]), NotFoundError)
- self.get_failure(store.get_event(third["event_id"]), NotFoundError)
- self.get_success(store.get_event(last["event_id"]))
+ self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+ self.get_failure(self.store.get_event(second["event_id"]), NotFoundError)
+ self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
+ self.get_success(self.store.get_event(last["event_id"]))
- def test_purge_wont_delete_extrems(self):
+ def test_purge_history_wont_delete_extrems(self):
"""
- Purging a room will delete everything before the topological point.
+ Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@@ -74,22 +72,43 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
- storage = self.hs.get_datastore()
-
# Set the topological token higher than it should be
token = self.get_success(
- storage.get_topological_token_for_event(last["event_id"])
+ self.store.get_topological_token_for_event(last["event_id"])
)
event = "t{}-{}".format(token.topological + 1, token.stream + 1)
# Purge everything before this topological token
- purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
- self.pump()
- f = self.failureResultOf(purge)
+ f = self.get_failure(
+ self.storage.purge_events.purge_history(self.room_id, event, True),
+ SynapseError,
+ )
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
- self.get_success(storage.get_event(first["event_id"]))
- self.get_success(storage.get_event(second["event_id"]))
- self.get_success(storage.get_event(third["event_id"]))
- self.get_success(storage.get_event(last["event_id"]))
+ self.get_success(self.store.get_event(first["event_id"]))
+ self.get_success(self.store.get_event(second["event_id"]))
+ self.get_success(self.store.get_event(third["event_id"]))
+ self.get_success(self.store.get_event(last["event_id"]))
+
+ def test_purge_room(self):
+ """
+ Purging a room will delete everything about it.
+ """
+ # Send four messages to the room
+ first = self.helper.send(self.room_id, body="test1")
+
+ # Get the current room state.
+ state_handler = self.hs.get_state_handler()
+ create_event = self.get_success(
+ state_handler.get_current_state(self.room_id, "m.room.create", "")
+ )
+ self.assertIsNotNone(create_event)
+
+ # Purge everything before this topological token
+ self.get_success(self.storage.purge_events.purge_room(self.room_id))
+
+ # The events aren't found.
+ self.store._invalidate_get_event_cache(create_event.event_id)
+ self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
+ self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
diff --git a/tests/util/caches/test_responsecache.py b/tests/util/caches/test_responsecache.py
new file mode 100644
index 0000000000..f9a187b8de
--- /dev/null
+++ b/tests/util/caches/test_responsecache.py
@@ -0,0 +1,131 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches.response_cache import ResponseCache
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+
+
+class DeferredCacheTestCase(TestCase):
+ """
+ A TestCase class for ResponseCache.
+
+ The test-case function naming has some logic to it in it's parts, here's some notes about it:
+ wait: Denotes tests that have an element of "waiting" before its wrapped result becomes available
+ (Generally these just use .delayed_return instead of .instant_return in it's wrapped call.)
+ expire: Denotes tests that test expiry after assured existence.
+ (These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
+ """
+
+ def setUp(self):
+ self.reactor, self.clock = get_clock()
+
+ def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
+ return ResponseCache(self.clock, name, timeout_ms=ms)
+
+ @staticmethod
+ async def instant_return(o: str) -> str:
+ return o
+
+ async def delayed_return(self, o: str) -> str:
+ await self.clock.sleep(1)
+ return o
+
+ def test_cache_hit(self):
+ cache = self.with_cache("keeping_cache", ms=9001)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(wrap_d),
+ "initial wrap result should be the same",
+ )
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(cache.get(0)),
+ "cache should have the result",
+ )
+
+ def test_cache_miss(self):
+ cache = self.with_cache("trashing_cache", ms=0)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(wrap_d),
+ "initial wrap result should be the same",
+ )
+ self.assertIsNone(cache.get(0), "cache should not have the result now")
+
+ def test_cache_expire(self):
+ cache = self.with_cache("short_cache", ms=1000)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+ self.assertEqual(expected_result, self.successResultOf(wrap_d))
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(cache.get(0)),
+ "cache should still have the result",
+ )
+
+ # cache eviction timer is handled
+ self.reactor.pump((2,))
+
+ self.assertIsNone(cache.get(0), "cache should not have the result now")
+
+ def test_cache_wait_hit(self):
+ cache = self.with_cache("neutral_cache")
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+ self.assertNoResult(wrap_d)
+
+ # function wakes up, returns result
+ self.reactor.pump((2,))
+
+ self.assertEqual(expected_result, self.successResultOf(wrap_d))
+
+ def test_cache_wait_expire(self):
+ cache = self.with_cache("medium_cache", ms=3000)
+
+ expected_result = "howdy"
+
+ wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+ self.assertNoResult(wrap_d)
+
+ # stop at 1 second to callback cache eviction callLater at that time, then another to set time at 2
+ self.reactor.pump((1, 1))
+
+ self.assertEqual(expected_result, self.successResultOf(wrap_d))
+ self.assertEqual(
+ expected_result,
+ self.successResultOf(cache.get(0)),
+ "cache should still have the result",
+ )
+
+ # (1 + 1 + 2) > 3.0, cache eviction timer is handled
+ self.reactor.pump((2,))
+
+ self.assertIsNone(cache.get(0), "cache should not have the result now")
|