diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
new file mode 100644
index 0000000000..d3ec24c975
--- /dev/null
+++ b/tests/config/test_cache.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config, RootConfig
+from synapse.config.cache import CacheConfig, add_resizable_cache
+from synapse.util.caches.lrucache import LruCache
+
+from tests.unittest import TestCase
+
+
+class FakeServer(Config):
+ section = "server"
+
+
+class TestConfig(RootConfig):
+ config_classes = [FakeServer, CacheConfig]
+
+
+class CacheConfigTests(TestCase):
+ def setUp(self):
+ # Reset caches before each test
+ TestConfig().caches.reset()
+
+ def test_individual_caches_from_environ(self):
+ """
+ Individual cache factors will be loaded from the environment.
+ """
+ config = {}
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+ "SYNAPSE_NOT_CACHE": "BLAH",
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0})
+
+ def test_config_overrides_environ(self):
+ """
+ Individual cache factors defined in the environment will take precedence
+ over those in the config.
+ """
+ config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
+ "SYNAPSE_CACHE_FACTOR_FOO": 1,
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(
+ dict(t.caches.cache_factors),
+ {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
+ )
+
+ def test_individual_instantiated_before_config_load(self):
+ """
+ If a cache is instantiated before the config is read, it will be given
+ the default cache size in the interim, and then resized once the config
+ is loaded.
+ """
+ cache = LruCache(100)
+
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 50)
+
+ config = {"caches": {"per_cache_factors": {"foo": 3}}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(cache.max_size, 300)
+
+ def test_individual_instantiated_after_config_load(self):
+ """
+ If a cache is instantiated after the config is read, it will be
+ immediately resized to the correct size given the per_cache_factor if
+ there is one.
+ """
+ config = {"caches": {"per_cache_factors": {"foo": 2}}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 200)
+
+ def test_global_instantiated_before_config_load(self):
+ """
+ If a cache is instantiated before the config is read, it will be given
+ the default cache size in the interim, and then resized to the new
+ default cache size once the config is loaded.
+ """
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 50)
+
+ config = {"caches": {"global_factor": 4}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ self.assertEqual(cache.max_size, 400)
+
+ def test_global_instantiated_after_config_load(self):
+ """
+ If a cache is instantiated after the config is read, it will be
+ immediately resized to the correct size given the global factor if there
+ is no per-cache factor.
+ """
+ config = {"caches": {"global_factor": 1.5}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(100)
+ add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
+ self.assertEqual(cache.max_size, 150)
+
+ def test_cache_with_asterisk_in_name(self):
+ """Some caches have asterisks in their name, test that they are set correctly.
+ """
+
+ config = {
+ "caches": {
+ "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
+ }
+ }
+ t = TestConfig()
+ t.caches._environ = {
+ "SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
+ "SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
+ }
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
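+        # asterisks are stripped and matching is case-insensitive when looking up
+        # environment overrides, so CACHE_A and CACHE_B override *cache_a* and cache_b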
+ cache_a = LruCache(100)
+ add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
+ self.assertEqual(cache_a.max_size, 200)
+
+ cache_b = LruCache(100)
+ add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
+ self.assertEqual(cache_b.max_size, 300)
+
+ cache_c = LruCache(100)
+ add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
+ self.assertEqual(cache_c.max_size, 200)
+
+ def test_apply_cache_factor_from_config(self):
+ """Caches can disable applying cache factor updates, mainly used by
+ event cache size.
+ """
+
+ config = {"caches": {"event_cache_size": "10k"}}
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cache = LruCache(
+ max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False,
+ )
+ add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
+
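+        # "10k" is parsed as 10 * 1024 = 10240 entries; since
+        # apply_cache_factor_from_config is False, no cache factor is applied on top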
+ self.assertEqual(cache.max_size, 10240)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index ab5f5ac549..c1274c14af 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -156,7 +156,7 @@ class PruneEventTestCase(unittest.TestCase):
"signatures": {},
"unsigned": {},
},
- room_version=RoomVersions.MSC2432_DEV,
+ room_version=RoomVersions.V6,
)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 94980733c4..0c9987be54 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=defer.succeed(("", 1))
+ )
d = handler._remote_join(
None,
@@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
- handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=defer.succeed(("", 1))
+ )
# Artificially raise the complexity
self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 132e35651d..96fea58673 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from unittest import TestCase
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes
+from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
@@ -207,3 +210,65 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
return join_event
+
+
+class EventFromPduTestCase(TestCase):
+ def test_valid_json(self):
+ """Valid JSON should be turned into an event."""
+ ev = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"bool": True, "null": None, "int": 1, "str": "foobar"},
+ "room_id": "!room:test",
+ "sender": "@user:test",
+ "depth": 1,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 1234,
+ },
+ RoomVersions.V6,
+ )
+
+ self.assertIsInstance(ev, EventBase)
+
+ def test_invalid_numbers(self):
+ """Invalid values for an integer should be rejected, all floats should be rejected."""
+ for value in [
+ -(2 ** 53),
+ 2 ** 53,
+ 1.0,
+ float("inf"),
+ float("-inf"),
+ float("nan"),
+ ]:
+ with self.assertRaises(SynapseError):
+ event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"foo": value},
+ "room_id": "!room:test",
+ "sender": "@user:test",
+ "depth": 1,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 1234,
+ },
+ RoomVersions.V6,
+ )
+
+ def test_invalid_nested(self):
+ """List and dictionaries are recursively searched."""
+ with self.assertRaises(SynapseError):
+ event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"foo": [{"bar": 2 ** 56}]},
+ "room_id": "!room:test",
+ "sender": "@user:test",
+ "depth": 1,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 1234,
+ },
+ RoomVersions.V6,
+ )
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
new file mode 100644
index 0000000000..1bb25ab684
--- /dev/null
+++ b/tests/handlers/test_oidc.py
@@ -0,0 +1,570 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from urllib.parse import parse_qs, urlparse
+
+from mock import Mock, patch
+
+import attr
+import pymacaroons
+
+from twisted.internet import defer
+from twisted.python.failure import Failure
+from twisted.web._newclient import ResponseDone
+
+from synapse.handlers.oidc_handler import (
+ MappingException,
+ OidcError,
+ OidcHandler,
+ OidcMappingProvider,
+)
+from synapse.types import UserID
+
+from tests.unittest import HomeserverTestCase, override_config
+
+
+@attr.s
+class FakeResponse:
+ code = attr.ib()
+ body = attr.ib()
+ phrase = attr.ib()
+
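+    # mimic just enough of twisted's IResponse to hand a canned body to a protocol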
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
+
+
+# These are a few constants that are used as config parameters in the tests.
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+SCOPES = ["openid"]
+
+AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
+TOKEN_ENDPOINT = ISSUER + "token"
+USERINFO_ENDPOINT = ISSUER + "userinfo"
+WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
+JWKS_URI = ISSUER + ".well-known/jwks.json"
+
+# config for common cases
+COMMON_CONFIG = {
+ "discover": False,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+}
+
+
+# The cookie name and path don't really matter, they just have to be consistent
+# between the callback & redirect handlers.
+COOKIE_NAME = b"oidc_session"
+COOKIE_PATH = "/_synapse/oidc"
+
+MockedMappingProvider = Mock(OidcMappingProvider)
+
+
+def simple_async_mock(return_value=None, raises=None):
+    # AsyncMock is not available in Python 3.5; this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
+async def get_json(url):
+ # Mock get_json calls to handle jwks & oidc discovery endpoints
+ if url == WELL_KNOWN:
+ # Minimal discovery document, as defined in OpenID.Discovery
+ # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+ return {
+ "issuer": ISSUER,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+ "userinfo_endpoint": USERINFO_ENDPOINT,
+ "response_types_supported": ["code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+ elif url == JWKS_URI:
+ return {"keys": []}
+
+
+class OidcHandlerTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ config = self.default_config()
+ config["public_baseurl"] = BASE_URL
+ oidc_config = config.get("oidc_config", {})
+ oidc_config["enabled"] = True
+ oidc_config["client_id"] = CLIENT_ID
+ oidc_config["client_secret"] = CLIENT_SECRET
+ oidc_config["issuer"] = ISSUER
+ oidc_config["scopes"] = SCOPES
+ oidc_config["user_mapping_provider"] = {
+ "module": __name__ + ".MockedMappingProvider"
+ }
+ config["oidc_config"] = oidc_config
+
+ hs = self.setup_test_homeserver(
+ http_client=self.http_client,
+ proxied_http_client=self.http_client,
+ config=config,
+ )
+
+ self.handler = OidcHandler(hs)
+
+ return hs
+
+ def metadata_edit(self, values):
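+        """Temporarily patch the provider metadata (usable as a context manager)."""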
+ return patch.dict(self.handler._provider_metadata, values)
+
+ def assertRenderedError(self, error, error_description=None):
+ args = self.handler._render_error.call_args[0]
+ self.assertEqual(args[1], error)
+ if error_description is not None:
+ self.assertEqual(args[2], error_description)
+ # Reset the render_error mock
+ self.handler._render_error.reset_mock()
+
+ def test_config(self):
+ """Basic config correctly sets up the callback URL and client auth correctly."""
+ self.assertEqual(self.handler._callback_url, CALLBACK_URL)
+ self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+
+ @override_config({"oidc_config": {"discover": True}})
+ @defer.inlineCallbacks
+ def test_discovery(self):
+ """The handler should discover the endpoints from OIDC discovery document."""
+ # This would throw if some metadata were invalid
+ metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+
+ self.assertEqual(metadata.issuer, ISSUER)
+ self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
+ self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
+ self.assertEqual(metadata.jwks_uri, JWKS_URI)
+ # FIXME: it seems like authlib does not have that defined in its metadata models
+ # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+
+ # subsequent calls should be cached
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_no_discovery(self):
+ """When discovery is disabled, it should not try to load from discovery document."""
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_load_jwks(self):
+ """JWKS loading is done once (then cached) if used."""
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.assertEqual(jwks, {"keys": []})
+
+ # subsequent calls should be cached…
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_not_called()
+
+ # …unless forced
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+
+        # Throw if the JWKS URI is missing
+ with self.metadata_edit({"jwks_uri": None}):
+ with self.assertRaises(RuntimeError):
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+
+ # Return empty key set if JWKS are not used
+ self.handler._scopes = [] # not asking the openid scope
+ self.http_client.get_json.reset_mock()
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_not_called()
+ self.assertEqual(jwks, {"keys": []})
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ def test_validate_config(self):
+ """Provider metadatas are extensively validated."""
+ h = self.handler
+
+ # Default test config does not throw
+ h._validate_metadata()
+
+ with self.metadata_edit({"issuer": None}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "http://insecure/"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"authorization_endpoint": None}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"token_endpoint": None}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": None}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"response_types_supported": ["id_token"]}):
+ self.assertRaisesRegex(
+ ValueError, "response_types_supported", h._validate_metadata
+ )
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
+ ):
+ # should not throw, as client_secret_basic is the default auth method
+ h._validate_metadata()
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
+ ):
+ self.assertRaisesRegex(
+ ValueError,
+ "token_endpoint_auth_methods_supported",
+ h._validate_metadata,
+ )
+
+        # Tests for configs that require the userinfo endpoint
+ self.assertFalse(h._uses_userinfo)
+ h._scopes = [] # do not request the openid scope
+ self.assertTrue(h._uses_userinfo)
+ self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+
+ with self.metadata_edit(
+ {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
+ ):
+            # Shouldn't raise with a valid userinfo endpoint, even without a jwks_uri
+ h._validate_metadata()
+
+ @override_config({"oidc_config": {"skip_verification": True}})
+ def test_skip_verification(self):
+ """Provider metadata validation can be disabled by config."""
+ with self.metadata_edit({"issuer": "http://insecure"}):
+ # This should not throw
+ self.handler._validate_metadata()
+
+ @defer.inlineCallbacks
+ def test_redirect_request(self):
+ """The redirect request has the right arguments & generates a valid session cookie."""
+ req = Mock(spec=["addCookie"])
+ url = yield defer.ensureDeferred(
+ self.handler.handle_redirect_request(req, b"http://client/redirect")
+ )
+ url = urlparse(url)
+ auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+
+ self.assertEqual(url.scheme, auth_endpoint.scheme)
+ self.assertEqual(url.netloc, auth_endpoint.netloc)
+ self.assertEqual(url.path, auth_endpoint.path)
+
+ params = parse_qs(url.query)
+ self.assertEqual(params["redirect_uri"], [CALLBACK_URL])
+ self.assertEqual(params["response_type"], ["code"])
+ self.assertEqual(params["scope"], [" ".join(SCOPES)])
+ self.assertEqual(params["client_id"], [CLIENT_ID])
+ self.assertEqual(len(params["state"]), 1)
+ self.assertEqual(len(params["nonce"]), 1)
+
+ # Check what is in the cookie
+        # note: the mock library in Python 3.5 does not have the .assert_called_once() method
+ calls = req.addCookie.call_args_list
+ self.assertEqual(len(calls), 1) # called once
+        # For some reason, call.args does not work with Python 3.5
+ args = calls[0][0]
+ kwargs = calls[0][1]
+ self.assertEqual(args[0], COOKIE_NAME)
+ self.assertEqual(kwargs["path"], COOKIE_PATH)
+ cookie = args[1]
+
+ macaroon = pymacaroons.Macaroon.deserialize(cookie)
+ state = self.handler._get_value_from_macaroon(macaroon, "state")
+ nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
+ redirect = self.handler._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+
+ self.assertEqual(params["state"], [state])
+ self.assertEqual(params["nonce"], [nonce])
+ self.assertEqual(redirect, "http://client/redirect")
+
+ @defer.inlineCallbacks
+ def test_callback_error(self):
+ """Errors from the provider returned in the callback are displayed."""
+ self.handler._render_error = Mock()
+ request = Mock(args={})
+ request.args[b"error"] = [b"invalid_client"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "")
+
+ request.args[b"error_description"] = [b"some description"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "some description")
+
+ @defer.inlineCallbacks
+ def test_callback(self):
+ """Code callback works and display errors if something went wrong.
+
+ A lot of scenarios are tested here:
+ - when the callback works, with userinfo from ID token
+ - when the user mapping fails
+ - when ID token verification fails
+ - when the callback works, with userinfo fetched from the userinfo endpoint
+ - when the userinfo fetching fails
+ - when the code exchange fails
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "preferred_username": "bar",
+ }
+ user_id = UserID("foo", "domain.org")
+ self.handler._render_error = Mock(return_value=None)
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ code = "code"
+ state = "state"
+ nonce = "nonce"
+ client_redirect_url = "http://client/redirect"
+ session = self.handler._generate_oidc_session_token(
+ state=state,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
+ )
+ request.getCookie.return_value = session
+
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_not_called()
+ self.handler._render_error.assert_not_called()
+
+ # Handle mapping errors
+ self.handler._map_userinfo_to_user = simple_async_mock(
+ raises=MappingException()
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+
+ # Handle ID token errors
+ self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_token")
+
+ self.handler._auth_handler.complete_sso_login.reset_mock()
+ self.handler._exchange_code.reset_mock()
+ self.handler._parse_id_token.reset_mock()
+ self.handler._map_userinfo_to_user.reset_mock()
+ self.handler._fetch_userinfo.reset_mock()
+
+ # With userinfo fetching
+        self.handler._scopes = []  # do not ask for the "openid" scope
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_not_called()
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.handler._render_error.assert_not_called()
+
+ # Handle userinfo fetching error
+ self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("fetch_error")
+
+ # Handle code exchange failure
+ self.handler._exchange_code = simple_async_mock(
+ raises=OidcError("invalid_request")
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @defer.inlineCallbacks
+ def test_callback_session(self):
+ """The callback verifies the session presence and validity"""
+ self.handler._render_error = Mock(return_value=None)
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ # Missing cookie
+ request.args = {}
+ request.getCookie.return_value = None
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("missing_session", "No session cookie found")
+
+ # Missing session parameter
+ request.args = {}
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request", "State parameter is missing")
+
+ # Invalid cookie
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_session")
+
+ # Mismatching session
+ session = self.handler._generate_oidc_session_token(
+ state="state",
+ nonce="nonce",
+ client_redirect_url="http://client/redirect",
+ ui_auth_session_id=None,
+ )
+ request.args = {}
+ request.args[b"state"] = [b"mismatching state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mismatching_session")
+
+ # Valid session
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = session
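+        # the session is now valid, but the request is still missing the "code"
+        # parameter, so we expect an invalid_request error rather than a session error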
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+ @defer.inlineCallbacks
+ def test_exchange_code(self):
+ """Code exchange behaves correctly and handles various error scenarios."""
+ token = {"type": "bearer"}
+ token_json = json.dumps(token).encode("utf-8")
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ )
+ code = "code"
+ ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ kwargs = self.http_client.request.call_args[1]
+
+ self.assertEqual(ret, token)
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["client_secret"], [CLIENT_SECRET])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ # Test error handling
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=400,
+ phrase=b"Bad Request",
+ body=b'{"error": "foo", "error_description": "bar"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "foo")
+ self.assertEqual(exc.exception.error_description, "bar")
+
+ # Internal server error with no JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # Internal server error with JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500,
+ phrase=b"Internal Server Error",
+ body=b'{"error": "internal_server_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "internal_server_error")
+
+ # 4xx error without "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # 2xx error with "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "some_error")
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 51e2b37218..2fa8d4739b 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
reactor.pump((1000,))
hs = self.setup_test_homeserver(
- notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
+ notifier=Mock(),
+ http_client=mock_federation_client,
+ keyring=mock_keyring,
+ replication_streams={},
)
hs.datastores = datastores
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 7b92bdbc47..572df8d80b 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -185,7 +185,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Allow all users.
return False
- spam_checker.spam_checker = AllowAll()
+ spam_checker.spam_checkers = [AllowAll()]
# The results do not change:
# We get one search result when searching for user2 by user1.
@@ -198,7 +198,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# All users are spammy.
return True
- spam_checker.spam_checker = BlockAll()
+ spam_checker.spam_checkers = [BlockAll()]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/_base.py
index 7b56d2028d..9d4f0bbe44 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/_base.py
@@ -27,6 +27,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest
+from synapse.replication.http import streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -42,6 +43,10 @@ logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ servlets = [
+ streams.register_servlets,
+ ]
+
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -49,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.server = server_factory.buildProtocol(None)
# Make a new HomeServer object for the worker
- config = self.default_config()
- config["worker_app"] = "synapse.app.generic_worker"
- config["worker_replication_host"] = "testserv"
- config["worker_replication_http_port"] = "8765"
-
self.reactor.lookups["testserv"] = "1.2.3.4"
-
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
- config=config,
+ config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -78,6 +77,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def _get_worker_hs_config(self) -> dict:
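+        """Build the config for the worker homeserver (overridable by subclasses)."""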
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.generic_worker"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 1615dfab5e..32cb04645f 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,23 +15,13 @@
from mock import Mock, NonCallableMock
-from synapse.replication.tcp.client import (
- DirectTcpReplicationClientFactory,
- ReplicationDataHandler,
-)
-from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import make_conn
+from tests.replication._base import BaseStreamTestCase
-from tests import unittest
-from tests.server import FakeTransport
-
-class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
+class BaseSlavedStoreTestCase(BaseStreamTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "blue",
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
@@ -41,39 +31,13 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
+ super().prepare(reactor, clock, hs)
- db_config = hs.config.database.get_single_database()
- self.master_store = self.hs.get_datastore()
- self.storage = hs.get_storage()
- database = hs.get_datastores().databases[0]
- self.slaved_store = self.STORE_TYPE(
- database, make_conn(db_config, database.engine), self.hs
- )
- self.event_id = 0
-
- server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = hs.get_replication_streamer()
-
- # We now do some gut wrenching so that we have a client that is based
- # off of the slave store rather than the main store.
- self.replication_handler = ReplicationCommandHandler(self.hs)
- self.replication_handler._instance_name = "worker"
- self.replication_handler._replication_data_handler = ReplicationDataHandler(
- self.slaved_store
- )
+ self.reconnect()
- client_factory = DirectTcpReplicationClientFactory(
- self.hs, "client_name", self.replication_handler
- )
- client_factory.handler = self.replication_handler
-
- server = server_factory.buildProtocol(None)
- client = client_factory.buildProtocol(None)
-
- client.makeConnection(FakeTransport(server, reactor))
-
- self.server_to_client_transport = FakeTransport(client, reactor)
- server.makeConnection(self.server_to_client_transport)
+ self.master_store = hs.get_datastore()
+ self.slaved_store = self.worker_hs.get_datastore()
+ self.storage = hs.get_storage()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index f0561b30e3..1a88c7fb80 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -17,17 +17,18 @@ from canonicaljson import encode_canonical_json
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
-from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
+from tests.server import FakeTransport
+
from ._base import BaseSlavedStoreTestCase
-USER_ID = "@feeling:blue"
-USER_ID_2 = "@bright:blue"
+USER_ID = "@feeling:test"
+USER_ID_2 = "@bright:test"
OUTLIER = {"outlier": True}
-ROOM_ID = "!room:blue"
+ROOM_ID = "!room:test"
logger = logging.getLogger(__name__)
@@ -239,7 +240,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
# limit the replication rate
- repl_transport = self.server_to_client_transport
+ repl_transport = self._server_transport
+ assert isinstance(repl_transport, FakeTransport)
repl_transport.autoflush = False
# build the join and message events and persist them in the same batch.
@@ -322,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.message",
key=None,
internal={},
- state=None,
depth=None,
prev_events=[],
auth_events=[],
@@ -362,15 +363,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
self.event_id += 1
-
- if state is not None:
- state_ids = {key: e.event_id for key, e in state.items()}
- context = EventContext.with_state(
- state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
- )
- else:
- state_handler = self.hs.get_state_handler()
- context = self.get_success(state_handler.compute_event_context(event))
+ state_handler = self.hs.get_state_handler()
+ context = self.get_success(state_handler.compute_event_context(event))
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
new file mode 100644
index 0000000000..6a5116dd2a
--- /dev/null
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.tcp.streams._base import (
+ _STREAM_UPDATE_TARGET_ROW_COUNT,
+ AccountDataStream,
+)
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class AccountDataStreamTestCase(BaseStreamTestCase):
+ def test_update_function_room_account_data_limit(self):
+ """Test replication with many room account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", update, {})
+ )
+ updates.append(update)
+
+ # also one global update
+ self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertEqual(row.room_id, "test_room")
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.global")
+ self.assertIsNone(row.room_id)
+
+ self.assertEqual([], received_rows)
+
+ def test_update_function_global_account_data_limit(self):
+ """Test replication with many global account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(store.add_account_data_for_user("test_user", update, {}))
+ updates.append(update)
+
+ # also one per-room update
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
+ )
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertIsNone(row.room_id)
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.per_room")
+ self.assertEqual(row.room_id, "test_room")
+
+ self.assertEqual([], received_rows)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 8bd67bb9f1..51bf0ef4e9 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -26,7 +26,7 @@ from synapse.replication.tcp.streams.events import (
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
-from tests.replication.tcp.streams._base import BaseStreamTestCase
+from tests.replication._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
new file mode 100644
index 0000000000..2babea4e3e
--- /dev/null
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.federation.send_queue import EduRow
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class FederationStreamTestCase(BaseStreamTestCase):
+ def _get_worker_hs_config(self) -> dict:
+ # enable federation sending on the worker
+ config = super()._get_worker_hs_config()
+ # TODO: make it so we don't need both of these
+ config["send_federation"] = True
+ config["worker_app"] = "synapse.app.federation_sender"
+ return config
+
+ def test_catchup(self):
+ """Basic test of catchup on reconnect
+
+ Makes sure that updates sent while we are offline are received later.
+ """
+ fed_sender = self.hs.get_federation_sender()
+ received_rows = self.test_handler.received_rdata_rows
+
+ fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
+
+ self.reconnect()
+ self.reactor.advance(0)
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual(received_rows, [])
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "federation")
+
+ # we should have received an update row
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test_edu")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"a": "b"})
+
+ self.assertEqual(received_rows, [])
+
+ # additional updates should be transferred without an HTTP hit
+ fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
+ self.reactor.advance(0)
+ # there should be no http hit
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ # ... but we should have a row
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test1")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"c": "d"})
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 5853314fd4..56b062ecc1 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -19,7 +19,7 @@ from mock import Mock
from synapse.replication.tcp.streams._base import ReceiptsStream
-from tests.replication.tcp.streams._base import BaseStreamTestCase
+from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index d25a7b194e..fd62b26356 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -15,19 +15,14 @@
from mock import Mock
from synapse.handlers.typing import RoomMember
-from synapse.replication.http import streams
from synapse.replication.tcp.streams import TypingStream
-from tests.replication.tcp.streams._base import BaseStreamTestCase
+from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
class TypingStreamTestCase(BaseStreamTestCase):
- servlets = [
- streams.register_servlets,
- ]
-
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index 7ddfd0a733..60c10a441a 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -30,7 +30,7 @@ class ParseCommandTestCase(TestCase):
def test_parse_rdata(self):
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line)
- self.assertIsInstance(cmd, RdataCommand)
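+        # use a bare assert rather than assertIsInstance so that type checkers
+        # can narrow cmd to RdataCommand for the attribute accesses below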
+ assert isinstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "events")
self.assertEqual(cmd.instance_name, "master")
self.assertEqual(cmd.token, 6287863)
@@ -38,7 +38,7 @@ class ParseCommandTestCase(TestCase):
def test_parse_rdata_batch(self):
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
cmd = parse_command_from_line(line)
- self.assertIsInstance(cmd, RdataCommand)
+ assert isinstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "presence")
self.assertEqual(cmd.instance_name, "master")
self.assertIsNone(cmd.token)
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
new file mode 100644
index 0000000000..5448d9f0dc
--- /dev/null
+++ b/tests/replication/test_federation_ack.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mock
+
+from synapse.app.generic_worker import GenericWorkerServer
+from synapse.replication.tcp.commands import FederationAckCommand
+from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.unittest import HomeserverTestCase
+
+
+class FederationAckTestCase(HomeserverTestCase):
+ def default_config(self) -> dict:
+ config = super().default_config()
+ config["worker_app"] = "synapse.app.federation_sender"
+ config["send_federation"] = True
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+ return hs
+
+ def test_federation_ack_sent(self):
+ """A FEDERATION_ACK should be sent back after each RDATA federation
+
+ This test checks that the federation sender is correctly sending back
+ FEDERATION_ACK messages. The test works by spinning up a federation_sender
+ worker server, and then fishing out its ReplicationCommandHandler. We wire
+ the RCH up to a mock connection (so that we can observe the command being sent)
+ and then poke in an RDATA row.
+
+ XXX: it might be nice to do this by pretending to be a synapse master worker
+ (or a redis server), and having the worker connect to us via a mocked-up TCP
+ transport, rather than assuming that the implementation has a
+ ReplicationCommandHandler.
+ """
+ rch = self.hs.get_tcp_replication()
+
+ # wire up the ReplicationCommandHandler to a mock connection
+ mock_connection = mock.Mock(spec=AbstractConnection)
+ rch.new_connection(mock_connection)
+
+ # tell it it received an RDATA row
+ self.get_success(
+ rch.on_rdata(
+ "federation",
+ "master",
+ token=10,
+ rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])],
+ )
+ )
+
+ # now check that the FEDERATION_ACK was sent
+ mock_connection.send_command.assert_called_once()
+ cmd = mock_connection.send_command.call_args[0][0]
+ assert isinstance(cmd, FederationAckCommand)
+ self.assertEqual(cmd.token, 10)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 249c93722f..54cd24bf64 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -701,6 +701,47 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
+ def test_single_room(self):
+ """Test that a single room can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertIn("room_id", channel.json_body)
+ self.assertIn("name", channel.json_body)
+ self.assertIn("canonical_alias", channel.json_body)
+ self.assertIn("joined_members", channel.json_body)
+ self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("version", channel.json_body)
+ self.assertIn("creator", channel.json_body)
+ self.assertIn("encryption", channel.json_body)
+ self.assertIn("federatable", channel.json_body)
+ self.assertIn("public", channel.json_body)
+ self.assertIn("join_rules", channel.json_body)
+ self.assertIn("guest_access", channel.json_body)
+ self.assertIn("history_visibility", channel.json_body)
+ self.assertIn("state_events", channel.json_body)
+
+ self.assertEqual(room_id_1, channel.json_body["room_id"])
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 1856c7ffd5..eb8f6264fd 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -4,7 +4,7 @@ import urllib.parse
from mock import Mock
import synapse.rest.admin
-from synapse.rest.client.v1 import login
+from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
@@ -20,6 +20,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
+ logout.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
]
@@ -256,6 +257,72 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+        # ... and we should be soft-logged-out
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard logout this session
+ request, channel = self.make_request(
+ b"POST", "/logout", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ @override_config({"session_lifetime": "24h"})
+ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
+ self.register_user("kermit", "monkey")
+
+ # log in as normal
+ access_token = self.login("kermit", "monkey")
+
+ # we should now be able to make requests with the access token
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.result)
+
+ # time passes
+ self.reactor.advance(24 * 3600)
+
+        # ... and we should be soft-logged-out
+ request, channel = self.make_request(
+ b"GET", TEST_URL, access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.code, 401, channel.result)
+ self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+ self.assertEquals(channel.json_body["soft_logout"], True)
+
+ # Now try to hard log out all of the user's sessions
+ request, channel = self.make_request(
+ b"POST", "/logout/all", access_token=access_token
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
class CASTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 0d6936fd36..3ab611f618 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -46,7 +46,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
- def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+ async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
self.email_attempts.append(msg)
return
@@ -358,7 +358,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
- def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
+ async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
self.email_attempts.append(msg)
config["email"] = {
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index b0f3e183e5..2e07cddfce 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -29,7 +29,7 @@ import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
-from synapse.rest.client.v1 import login
+from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import account, account_validity, register, sync
from tests import unittest
@@ -358,6 +358,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
sync.register_servlets,
+ logout.register_servlets,
account_validity.register_servlets,
account.register_servlets,
]
@@ -451,6 +452,39 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
+ def test_logging_out_expired_user(self):
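+ """A user whose account has expired can still log out, and log out all sessions."""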
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Try to log the user out
+ request, channel = self.make_request(b"POST", "/logout", access_token=tok)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Log the user in again (allowed for expired accounts)
+ tok = self.login("kermit", "monkey")
+
+ # Try to log out all of the user's sessions
+ request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 406f29a7c0..99908edba3 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -27,20 +27,33 @@ from synapse.server_notices.resource_limits_server_notices import (
)
from tests import unittest
+from tests.unittest import override_config
+from tests.utils import default_config
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- hs_config = self.default_config()
- hs_config["server_notices"] = {
- "system_mxid_localpart": "server",
- "system_mxid_display_name": "test display name",
- "system_mxid_avatar_url": None,
- "room_name": "Server Notices",
- }
+ def default_config(self):
+ config = default_config("test")
+
+ config.update(
+ {
+ "admin_contact": "mailto:user@test.com",
+ "limit_usage_by_mau": True,
+ "server_notices": {
+ "system_mxid_localpart": "server",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ "room_name": "Server Notices",
+ },
+ }
+ )
+
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
- hs = self.setup_test_homeserver(config=hs_config)
- return hs
+ return config
def prepare(self, reactor, clock, hs):
self.server_notices_sender = self.hs.get_server_notices_sender()
@@ -60,7 +73,6 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
- self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
@@ -68,21 +80,17 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
- self.hs.config.admin_contact = "mailto:user@test.com"
-
- def test_maybe_send_server_notice_to_user_flag_off(self):
- """Tests cases where the flags indicate nothing to do"""
- # test hs disabled case
- self.hs.config.hs_disabled = True
+ @override_config({"hs_disabled": True})
+ def test_maybe_send_server_notice_disabled_hs(self):
+ """If the HS is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
-
self._send_notice.assert_not_called()
- # Test when mau limiting disabled
- self.hs.config.hs_disabled = False
- self.hs.config.limit_usage_by_mau = False
- self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
+ @override_config({"limit_usage_by_mau": False})
+ def test_maybe_send_server_notice_to_user_flag_off(self):
+ """If mau limiting is disabled, we should not send notices"""
+ self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
@@ -153,13 +161,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
+ @override_config({"mau_limit_alerting": False})
def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
"""
Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room
"""
- self.hs.config.mau_limit_alerting = False
-
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
@@ -170,12 +177,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 0)
+ @override_config({"mau_limit_alerting": False})
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
- self.hs.config.mau_limit_alerting = False
-
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
@@ -187,12 +193,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
# Would be better to check contents, but 2 calls == set blocking event
self.assertEqual(self._send_notice.call_count, 2)
+ @override_config({"mau_limit_alerting": False})
def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
"""
When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state.
"""
- self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index e37260a820..5a50e4fdd4 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached
from tests import unittest
-class CacheTestCase(unittest.TestCase):
- def setUp(self):
+class CacheTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
self.cache = Cache("test")
def test_empty(self):
@@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
-class CacheDecoratorTestCase(unittest.TestCase):
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
class A(object):
@@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0]
class A(object):
- @cached(max_entries=4) # HACK: This makes it 2 due to cache factor
+ @cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 31710949a8..ef296e7dab 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = self.as_yaml_files
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
self.as_token = "token1"
@@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = self.as_yaml_files
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
self.as_list = [
@@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
database = hs.get_datastores().databases[0]
@@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
@@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
hs.config.app_service_config_files = [f1, f2]
- hs.config.event_cache_size = 1
+ hs.config.caches.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index cdee0a9e60..278961c331 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config._disable_native_upserts = True
- config.event_cache_size = 1
+ config.caches = Mock()
+ config.caches.event_cache_size = 1
hs = TestHomeServer("test", config=config)
sqlite_config = {"name": "sqlite3"}
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 0e04b2cf92..43425c969a 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
self.requester = Requester(self.user, None, False, None, None)
- info = self.get_success(self.room_creator.create_room(self.requester, {}))
+ info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
def run_background_update(self):
@@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
self.requester = Requester(self.user, None, False, None, None)
- info = self.get_success(self.room_creator.create_room(self.requester, {}))
+ info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.user_consent_version = self.CONSENT_VERSION
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index bf674dd184..3b483bc7f0 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.unittest import override_config
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
@@ -137,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
def test_disabled_monthly_active_user(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.max_mau_value = 50
user_id = "@user:server"
self.get_success(
self.store.insert_client_ip(
@@ -149,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
+ @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_full(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 50
lots_of_users = 100
user_id = "@user:server"
@@ -166,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
+ @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_space(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 50
user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@@ -184,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
+ @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_updating_monthly_active_user_when_space(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 50
user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b7fd36d3..a7b85004e5 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
events = [(3, 2), (6, 2), (4, 6)]
for event_count, extrems in events:
- info = self.get_success(room_creator.create_room(requester, {}))
+ info, _ = self.get_success(room_creator.create_room(requester, {}))
room_id = info["room_id"]
last_event = None
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index d4bcf1821e..b45bc9c115 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
def setUp(self):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
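+ # _set_push_actions_for_event_and_users_txn now lives on the events
+ # persistence store, so keep a reference to that store too.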
+ self.persist_events_store = hs.get_datastores().persist_events
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
@@ -76,7 +77,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
yield self.store.db.runInteraction(
"",
- self.store._set_push_actions_for_event_and_users_txn,
+ self.persist_events_store._set_push_actions_for_event_and_users_txn,
[(event, None)],
[(event, None)],
)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 0000000000..55e9ecf264
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
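+ # MultiWriterIdGenerator is backed by a database sequence, so these
+ # tests can only be run against Postgres.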
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db = self.store.db # type: Database
+
+ self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
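+ # Create the sequence the ID generator allocates from, plus a table
+ # recording which instance wrote each stream ID.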
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ )
+
+ return self.get_success(self.db.runWithConnection(_create))
+
+ def _insert_rows(self, instance_name: str, number: int):
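+ """Insert `number` rows written by the given instance, with stream IDs drawn from the sequence."""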
+ def _insert(txn):
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+ (instance_name,),
+ )
+
+ self.get_success(self.db.runInteraction("_insert_rows", _insert))
+
+ def test_empty(self):
+ """Test an ID generator against an empty database gives sensible
+ current positions.
+ """
+
+ id_gen = self._create_id_generator()
+
+ # The table is empty so we expect an empty map for positions
+ self.assertEqual(id_gen.get_positions(), {})
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new stream ID and check that we only see the
+ # position advance after we leave the context manager.
+
+ async def _get_next_async():
+ with await id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
+
+ def test_multi_instance(self):
+ """Test that reads and writes from multiple processes are handled
+ correctly.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first")
+ second_id_gen = self._create_id_generator("second")
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+ # Try allocating a new stream ID and check that we only see the
+ # position advance after we leave the context manager.
+
+ async def _get_next_async():
+ with await first_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+ # However the ID gen on the second instance won't have seen the update
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+ # ... but calling `get_next` on the second instance should give a unique
+ # stream ID
+
+ async def _get_next_async():
+ with await second_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 9)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+ # If the second ID gen gets told about the first, it correctly updates
+ second_id_gen.advance("first", 8)
+ self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+ def test_get_next_txn(self):
+ """Test that the `get_next_txn` function works correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new stream ID and check that we only see the
+ # position advance after the transaction completes.
+
+ def _get_next_txn(txn):
+ stream_id = id_gen.get_next_txn(txn)
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 3c79eebbfa..0155ffd04e 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase):
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
yield self.store.create_profile(self.user.localpart)
- yield self.store.set_profile_displayname(self.user.localpart, self.displayname, 1)
+ yield self.store.set_profile_displayname(
+ self.user.localpart, self.displayname, 1
+ )
users, total = yield self.store.get_users_paginate(
0, 10, name="bc", guests=False
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index bc53bf0951..447fcb3a1c 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,94 +19,106 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from tests import unittest
+from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
+def gen_3pids(count):
+ """Generate `count` threepids as a list."""
+ return [
+ {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count)
+ ]
+
+
class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def default_config(self):
+ config = default_config("test")
+
+ config.update({"limit_usage_by_mau": True, "max_mau_value": 50})
- hs = self.setup_test_homeserver()
- self.store = hs.get_datastore()
- hs.config.limit_usage_by_mau = True
- hs.config.max_mau_value = 50
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
# Advance the clock a bit
reactor.advance(FORTY_DAYS)
- return hs
-
+ @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
def test_initialise_reserved_users(self):
- self.hs.config.max_mau_value = 5
+ threepids = self.hs.config.mau_limits_reserved_threepids
+
+ # Register three users: two with reserved 3pids, and a third which is
+ # a support user.
user1 = "@user1:server"
- user1_email = "user1@matrix.org"
+ user1_email = threepids[0]["address"]
user2 = "@user2:server"
- user2_email = "user2@matrix.org"
+ user2_email = threepids[1]["address"]
user3 = "@user3:server"
- user3_email = "user3@matrix.org"
- threepids = [
- {"medium": "email", "address": user1_email},
- {"medium": "email", "address": user2_email},
- {"medium": "email", "address": user3_email},
- ]
- self.hs.config.mau_limits_reserved_threepids = threepids
- # -1 because user3 is a support user and does not count
- user_num = len(threepids) - 1
-
- self.store.register_user(user_id=user1, password_hash=None)
- self.store.register_user(user_id=user2, password_hash=None)
- self.store.register_user(
- user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT
- )
+ self.store.register_user(user_id=user1)
+ self.store.register_user(user_id=user2)
+ self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT)
self.pump()
now = int(self.hs.get_clock().time_msec())
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ # XXX why are we doing this here? this function is only run at startup
+ # so it is odd to re-run it here.
self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.pump()
- active_count = self.store.get_monthly_active_count()
+ # the number of users we expect will be counted against the mau limit
+ # -1 because user3 is a support user and does not count
+ user_num = len(threepids) - 1
- # Test total counts, ensure user3 (support user) is not counted
- self.assertEquals(self.get_success(active_count), user_num)
+ # Check the number of active users. Ensure user3 (support user) is not counted
+ active_count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEquals(active_count, user_num)
- # Test user is marked as active
+ # Test each of the registered users is marked as active
timestamp = self.store.user_last_seen_monthly_active(user1)
self.assertTrue(self.get_success(timestamp))
timestamp = self.store.user_last_seen_monthly_active(user2)
self.assertTrue(self.get_success(timestamp))
- # Test that users are never removed from the db.
+ # Test that users with reserved 3pids are not removed from the MAU table
+ # XXX some of this is redundant. poking things into the config shouldn't
+ # work, and in any case it's not obvious what we expect to happen when
+ # we advance the reactor.
self.hs.config.max_mau_value = 0
-
self.reactor.advance(FORTY_DAYS)
self.hs.config.max_mau_value = 5
-
self.store.reap_monthly_active_users()
self.pump()
active_count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(active_count), user_num)
- # Test that regular users are removed from the db
+ # Add some more users and check they are counted as active
ru_count = 2
self.store.upsert_monthly_active_user("@ru1:server")
self.store.upsert_monthly_active_user("@ru2:server")
self.pump()
-
active_count = self.store.get_monthly_active_count()
self.assertEqual(self.get_success(active_count), user_num + ru_count)
- self.hs.config.max_mau_value = user_num
+
+ # now run the reaper and check that the number of active users is reduced
+ # to max_mau_value
self.store.reap_monthly_active_users()
self.pump()
active_count = self.store.get_monthly_active_count()
- self.assertEquals(self.get_success(active_count), user_num)
+ self.assertEquals(self.get_success(active_count), 3)
def test_can_insert_and_count_mau(self):
count = self.store.get_monthly_active_count()
@@ -136,8 +148,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
result = self.store.user_last_seen_monthly_active(user_id3)
self.assertNotEqual(self.get_success(result), 0)
+ @override_config({"max_mau_value": 5})
def test_reap_monthly_active_users(self):
- self.hs.config.max_mau_value = 5
initial_users = 10
for i in range(initial_users):
self.store.upsert_monthly_active_user("@user%d:server" % i)
@@ -158,19 +170,19 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(count), 0)
+ # Note that the option below is mau_limit_reserved_threepids ("limit",
+ # no "s"): that is the name of the config value, although it gets stored
+ # on the config object as mau_limits_reserved_threepids.
+ @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
def test_reap_monthly_active_users_reserved_users(self):
""" Tests that reaping correctly handles reaping where reserved users are
present"""
-
- self.hs.config.max_mau_value = 5
- initial_users = 5
+ threepids = self.hs.config.mau_limits_reserved_threepids
+ initial_users = len(threepids)
reserved_user_number = initial_users - 1
- threepids = []
for i in range(initial_users):
user = "@user%d:server" % i
- email = "user%d@example.com" % i
+ email = "user%d@matrix.org" % i
self.get_success(self.store.upsert_monthly_active_user(user))
- threepids.append({"medium": "email", "address": email})
# Need to ensure that the most recent entries in the
# monthly_active_users table are reserved
now = int(self.hs.get_clock().time_msec())
@@ -182,7 +194,6 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user, "email", email, now, now)
)
- self.hs.config.mau_limits_reserved_threepids = threepids
self.store.db.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
@@ -279,11 +290,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(self.get_success(count), 0)
+ # Note that the max_mau_value setting should not matter.
+ @override_config(
+ {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
+ )
def test_track_monthly_users_without_cap(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.mau_stats_only = True
- self.hs.config.max_mau_value = 1 # should not matter
-
count = self.store.get_monthly_active_count()
self.assertEqual(0, self.get_success(count))
@@ -294,9 +305,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.store.get_monthly_active_count()
self.assertEqual(2, self.get_success(count))
+ @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.mau_stats_only = False
self.store.upsert_monthly_active_user = Mock()
self.store.populate_monthly_active_users("@user:sever")
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 086adeb8fd..3b78d48896 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -55,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room(self.room.to_string())),
)
+ @defer.inlineCallbacks
+ def test_get_room_with_stats(self):
+ self.assertDictContainsSubset(
+ {
+ "room_id": self.room.to_string(),
+ "creator": self.u_creator.to_string(),
+ "public": True,
+ },
+ (yield self.store.get_room_with_stats(self.room.to_string())),
+ )
+
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 6c2351cf55..69b4c5d6c2 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -136,21 +136,18 @@ class EventAuthTestCase(unittest.TestCase):
# creator should be able to send aliases
event_auth.check(
- RoomVersions.MSC2432_DEV,
- _alias_event(creator),
- auth_events,
- do_sig_check=False,
+ RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False,
)
# No particular checks are done on the state key.
event_auth.check(
- RoomVersions.MSC2432_DEV,
+ RoomVersions.V6,
_alias_event(creator, state_key=""),
auth_events,
do_sig_check=False,
)
event_auth.check(
- RoomVersions.MSC2432_DEV,
+ RoomVersions.V6,
_alias_event(creator, state_key="test.com"),
auth_events,
do_sig_check=False,
@@ -159,8 +156,38 @@ class EventAuthTestCase(unittest.TestCase):
# Per standard auth rules, the member must be in the room.
with self.assertRaises(AuthError):
event_auth.check(
- RoomVersions.MSC2432_DEV,
- _alias_event(other),
+ RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False,
+ )
+
+ def test_msc2209(self):
+ """
+ Notifications power levels get checked due to MSC2209.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.power_levels", ""): _power_levels_event(
+ creator, {"state_default": "30", "users": {pleb: "30"}}
+ ),
+ ("m.room.member", pleb): _join_event(pleb),
+ }
+
+ # pleb should be able to modify the notifications power level.
+ event_auth.check(
+ RoomVersions.V1,
+ _power_levels_event(pleb, {"notifications": {"room": 100}}),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # But an MSC2209 room rejects this change.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _power_levels_event(pleb, {"notifications": {"room": 100}}),
auth_events,
do_sig_check=False,
)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index f297de95f1..c5099dd039 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -6,12 +6,13 @@ from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID
from synapse.util import Clock
+from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
-class MessageAcceptTests(unittest.TestCase):
+class MessageAcceptTests(unittest.HomeserverTestCase):
def setUp(self):
self.http_client = Mock()
@@ -27,13 +28,13 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
- room = ensureDeferred(
+ room_deferred = ensureDeferred(
room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
)
)
self.reactor.advance(0.1)
- self.room_id = self.successResultOf(room)["room_id"]
+ self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
self.store = self.homeserver.get_datastore()
@@ -145,3 +146,63 @@ class MessageAcceptTests(unittest.TestCase):
# Make sure the invalid event isn't there
extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+
+ def test_retry_device_list_resync(self):
+ """Tests that device lists are marked as stale if they couldn't be synced, and
+ that stale device lists are retried periodically.
+ """
+ remote_user_id = "@john:test_remote"
+ remote_origin = "test_remote"
+
+ # Track the number of attempts to resync the user's device list.
+ self.resync_attempts = 0
+
+ # When this function is called, increment the number of resync attempts (only if
+ # we're querying devices for the right user ID), then raise a
+ # NotRetryingDestination error to fail the resync gracefully.
+ def query_user_devices(destination, user_id):
+ if user_id == remote_user_id:
+ self.resync_attempts += 1
+
+ raise NotRetryingDestination(0, 0, destination)
+
+ # Register the mock on the federation client.
+ federation_client = self.homeserver.get_federation_client()
+ federation_client.query_user_devices = Mock(side_effect=query_user_devices)
+
+ # Register a mock on the store so that the incoming update doesn't fail because
+ # we don't share a room with the user.
+ store = self.homeserver.get_datastore()
+ store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
+
+ # Manually inject a fake device list update. We need this update to include at
+ # least one prev_id so that the user's device list will need to be retried.
+ device_list_updater = self.homeserver.get_device_handler().device_list_updater
+ self.get_success(
+ device_list_updater.incoming_device_list_update(
+ origin=remote_origin,
+ edu_content={
+ "deleted": False,
+ "device_display_name": "Mobile",
+ "device_id": "QBUAZIFURK",
+ "prev_id": [5],
+ "stream_id": 6,
+ "user_id": remote_user_id,
+ },
+ )
+ )
+
+ # Check that there was one resync attempt.
+ self.assertEqual(self.resync_attempts, 1)
+
+ # Check that the resync attempt failed and caused the user's device list to be
+ # marked as stale.
+ need_resync = self.get_success(
+ store.get_user_ids_requiring_device_list_resync()
+ )
+ self.assertIn(remote_user_id, need_resync)
+
+ # Check that waiting for 30 seconds caused Synapse to retry resyncing the device
+ # list.
+ self.reactor.advance(30)
+ self.assertEqual(self.resync_attempts, 2)
diff --git a/tests/test_mau.py b/tests/test_mau.py
index eb159e3ba5..8a97f0998d 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -17,47 +17,44 @@
import json
-from mock import Mock
-
-from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
+from tests.unittest import override_config
+from tests.utils import default_config
class TestMauLimit(unittest.HomeserverTestCase):
servlets = [register.register_servlets, sync.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def default_config(self):
+ config = default_config("test")
- self.hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
+ config.update(
+ {
+ "registrations_require_3pid": [],
+ "limit_usage_by_mau": True,
+ "max_mau_value": 2,
+ "mau_trial_days": 0,
+ "server_notices": {
+ "system_mxid_localpart": "server",
+ "room_name": "Test Server Notice Room",
+ },
+ }
)
- self.store = self.hs.get_datastore()
-
- self.hs.config.registrations_require_3pid = []
- self.hs.config.enable_registration_captcha = False
- self.hs.config.recaptcha_public_key = []
-
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.hs_disabled = False
- self.hs.config.max_mau_value = 2
- self.hs.config.server_notices_mxid = "@server:red"
- self.hs.config.server_notices_mxid_display_name = None
- self.hs.config.server_notices_mxid_avatar_url = None
- self.hs.config.server_notices_room_name = "Test Server Notice Room"
- self.hs.config.mau_trial_days = 0
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
- # AuthBlocking reads config options during hs creation. Recreate the
- # hs' copy of AuthBlocking after we've updated config values above
- self.auth_blocking = AuthBlocking(self.hs)
- self.hs.get_auth()._auth_blocking = self.auth_blocking
+ return config
- return self.hs
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
def test_simple_deny_mau(self):
# Create and sync so that the MAU counts get updated
@@ -66,6 +63,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
+ # check we're testing what we think we are: there should be two active users
+ self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
+
# We've created and activated two users, we shouldn't be able to
# register new users
with self.assertRaises(SynapseError) as cm:
@@ -93,9 +93,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
token3 = self.create_user("kermit3")
self.do_sync_for_user(token3)
+ @override_config({"mau_trial_days": 1})
def test_trial_delay(self):
- self.hs.config.mau_trial_days = 1
-
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -127,8 +126,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ @override_config({"mau_trial_days": 1})
def test_trial_users_cant_come_back(self):
- self.auth_blocking._mau_trial_days = 1
self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -176,11 +175,11 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ @override_config(
+ # max_mau_value should not matter
+ {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
+ )
def test_tracked_but_not_limited(self):
- self.auth_blocking._max_mau_value = 1 # should not matter
- self.auth_blocking._limit_usage_by_mau = False
- self.hs.config.mau_stats_only = True
-
# Simply being able to create 2 users indicates that the
# limit was not reached.
token1 = self.create_user("kermit1")
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 270f853d60..f5f63d8ed6 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -15,6 +15,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
+from synapse.util.caches.descriptors import Cache
from tests import unittest
@@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase):
self.assertTrue(b"osversion=" in items[0])
self.assertTrue(b"pythonversion=" in items[0])
self.assertTrue(b"version=" in items[0])
+
+
+class CacheMetricsTests(unittest.HomeserverTestCase):
+ def test_cache_metric(self):
+ """
+ Caches produce metrics reflecting their state when scraped.
+ """
+ CACHE_NAME = "cache_metrics_test_fgjkbdfg"
+ cache = Cache(CACHE_NAME, max_entries=777)
+
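+ # Scrape the prometheus registry and pull out the samples relating to
+ # our cache, as a dict of metric name -> value.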
+ items = {
+ x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
+ for x in filter(
+ lambda x: b"cache_metrics_test_fgjkbdfg" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ }
+
+ self.assertEqual(items["synapse_util_caches_cache_size"], "0.0")
+ self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
+
+ cache.prefill("1", "hi")
+
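+ # Re-scrape: the size metric should now reflect the added entry.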
+ items = {
+ x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
+ for x in filter(
+ lambda x: b"cache_metrics_test_fgjkbdfg" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ }
+
+ self.assertEqual(items["synapse_util_caches_cache_size"], "1.0")
+ self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
diff --git a/tests/test_server.py b/tests/test_server.py
index 0d57eed268..e9a43b1e45 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -27,6 +27,7 @@ from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.http.server import (
DirectServeResource,
JsonResource,
+ OptionsResource,
wrap_html_request_handler,
)
from synapse.http.site import SynapseSite, logger
@@ -168,6 +169,86 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+class OptionsResourceTests(unittest.TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
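+ # A trivial resource which echoes the request path back in the response body.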
+ class DummyResource(Resource):
+ isLeaf = True
+
+ def render(self, request):
+ return request.path
+
+ # Setup a resource with some children.
+ self.resource = OptionsResource()
+ self.resource.putChild(b"res", DummyResource())
+
+ def _make_request(self, method, path):
+ """Create a request from the method/path and return a channel with the response."""
+ request, channel = make_request(self.reactor, method, path, shorthand=False)
+ request.prepath = [] # This doesn't get set properly by make_request.
+
+ # Create a site and query for the resource.
+ site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+ request.site = site
+ resource = site.getResourceFor(request)
+
+ # Finally, render the resource and return the channel.
+ render(request, resource, self.reactor)
+ return channel
+
+ def test_unknown_options_request(self):
+ """An OPTIONS requests to an unknown URL still returns 200 OK."""
+ channel = self._make_request(b"OPTIONS", b"/foo/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"{}")
+
+ # Ensure the correct CORS headers have been added
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+ "has CORS Origin header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+ "has CORS Methods header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+ "has CORS Headers header",
+ )
+
+ def test_known_options_request(self):
+ """An OPTIONS requests to an known URL still returns 200 OK."""
+ channel = self._make_request(b"OPTIONS", b"/res/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"{}")
+
+ # Ensure the correct CORS headers have been added
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+ "has CORS Origin header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+ "has CORS Methods header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+ "has CORS Headers header",
+ )
+
+ def test_unknown_request(self):
+ """A non-OPTIONS request to an unknown URL should 404."""
+ channel = self._make_request(b"GET", b"/foo/")
+ self.assertEqual(channel.result["code"], b"404")
+
+ def test_known_request(self):
+ """A non-OPTIONS request to an known URL should query the proper resource."""
+ channel = self._make_request(b"GET", b"/res/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"/res/")
+
+
class WrapHtmlRequestHandlerTests(unittest.TestCase):
class TestResource(DirectServeResource):
callback = None
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 50bc7702d2..49ffeebd0e 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -21,7 +21,7 @@ from tests.utils import MockClock
from .. import unittest
-class ExpiringCacheTestCase(unittest.TestCase):
+class ExpiringCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1)
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 786947375d..0adb2174af 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache
from .. import unittest
-class LruCacheTestCase(unittest.TestCase):
+class LruCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self):
cache = LruCache(1)
cache["key"] = "value"
@@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(len(cache), 0)
-class LruCacheCallbacksTestCase(unittest.TestCase):
+class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self):
m = Mock()
cache = LruCache(1)
@@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
self.assertEquals(m3.call_count, 1)
-class LruCacheSizedTestCase(unittest.TestCase):
+class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_evict(self):
cache = LruCache(5, size_callback=len)
cache["key1"] = [0]
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 6857933540..13b753e367 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -1,11 +1,9 @@
-from mock import patch
-
from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests import unittest
-class StreamChangeCacheTests(unittest.TestCase):
+class StreamChangeCacheTests(unittest.HomeserverTestCase):
"""
Tests for StreamChangeCache.
"""
@@ -54,7 +52,6 @@ class StreamChangeCacheTests(unittest.TestCase):
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
- @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
def test_entity_has_changed_pops_off_start(self):
"""
StreamChangeCache.entity_has_changed will respect the max size and
diff --git a/tests/utils.py b/tests/utils.py
index f9be62b499..59c020a051 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -167,6 +167,7 @@ def default_config(name, parse=False):
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
"update_user_directory": False,
+ "caches": {"global_factor": 1},
}
if parse:
|