diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
new file mode 100644
index 0000000000..61963aa90d
--- /dev/null
+++ b/tests/handlers/test_oidc.py
@@ -0,0 +1,565 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from urllib.parse import parse_qs, urlparse
+
+from mock import Mock, patch
+
+import attr
+import pymacaroons
+
+from twisted.internet import defer
+from twisted.python.failure import Failure
+from twisted.web._newclient import ResponseDone
+
+from synapse.handlers.oidc_handler import (
+ MappingException,
+ OidcError,
+ OidcHandler,
+ OidcMappingProvider,
+)
+from synapse.types import UserID
+
+from tests.unittest import HomeserverTestCase, override_config
+
+
+@attr.s
+class FakeResponse:
+ code = attr.ib()
+ body = attr.ib()
+ phrase = attr.ib()
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
+
+
+# These are a few constants that are used as config parameters in the tests.
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+SCOPES = ["openid"]
+
+AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
+TOKEN_ENDPOINT = ISSUER + "token"
+USERINFO_ENDPOINT = ISSUER + "userinfo"
+WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
+JWKS_URI = ISSUER + ".well-known/jwks.json"
+
+# config for common cases
+COMMON_CONFIG = {
+ "discover": False,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+}
+
+
+# The cookie name and path don't really matter, just that it has to be coherent
+# between the callback & redirect handlers.
+COOKIE_NAME = b"oidc_session"
+COOKIE_PATH = "/_synapse/oidc"
+
+MockedMappingProvider = Mock(OidcMappingProvider)
+
+
+def simple_async_mock(return_value=None, raises=None):
+ # AsyncMock is not available in python3.5, this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
+async def get_json(url):
+ # Mock get_json calls to handle jwks & oidc discovery endpoints
+ if url == WELL_KNOWN:
+ # Minimal discovery document, as defined in OpenID.Discovery
+ # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+ return {
+ "issuer": ISSUER,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+ "userinfo_endpoint": USERINFO_ENDPOINT,
+ "response_types_supported": ["code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+ elif url == JWKS_URI:
+ return {"keys": []}
+
+
+class OidcHandlerTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ config = self.default_config()
+ config["public_baseurl"] = BASE_URL
+ oidc_config = config.get("oidc_config", {})
+ oidc_config["enabled"] = True
+ oidc_config["client_id"] = CLIENT_ID
+ oidc_config["client_secret"] = CLIENT_SECRET
+ oidc_config["issuer"] = ISSUER
+ oidc_config["scopes"] = SCOPES
+ oidc_config["user_mapping_provider"] = {
+ "module": __name__ + ".MockedMappingProvider"
+ }
+ config["oidc_config"] = oidc_config
+
+ hs = self.setup_test_homeserver(
+ http_client=self.http_client,
+ proxied_http_client=self.http_client,
+ config=config,
+ )
+
+ self.handler = OidcHandler(hs)
+
+ return hs
+
+ def metadata_edit(self, values):
+ return patch.dict(self.handler._provider_metadata, values)
+
+ def assertRenderedError(self, error, error_description=None):
+ args = self.handler._render_error.call_args[0]
+ self.assertEqual(args[1], error)
+ if error_description is not None:
+ self.assertEqual(args[2], error_description)
+ # Reset the render_error mock
+ self.handler._render_error.reset_mock()
+
+ def test_config(self):
+ """Basic config correctly sets up the callback URL and client auth correctly."""
+ self.assertEqual(self.handler._callback_url, CALLBACK_URL)
+ self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+
+ @override_config({"oidc_config": {"discover": True}})
+ @defer.inlineCallbacks
+ def test_discovery(self):
+ """The handler should discover the endpoints from OIDC discovery document."""
+ # This would throw if some metadata were invalid
+ metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+
+ self.assertEqual(metadata.issuer, ISSUER)
+ self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
+ self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
+ self.assertEqual(metadata.jwks_uri, JWKS_URI)
+ # FIXME: it seems like authlib does not have that defined in its metadata models
+ # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+
+ # subsequent calls should be cached
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_no_discovery(self):
+ """When discovery is disabled, it should not try to load from discovery document."""
+ yield defer.ensureDeferred(self.handler.load_metadata())
+ self.http_client.get_json.assert_not_called()
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ @defer.inlineCallbacks
+ def test_load_jwks(self):
+ """JWKS loading is done once (then cached) if used."""
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.assertEqual(jwks, {"keys": []})
+
+ # subsequent calls should be cached…
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks())
+ self.http_client.get_json.assert_not_called()
+
+ # …unless forced
+ self.http_client.reset_mock()
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_called_once_with(JWKS_URI)
+
+ # Throw if the JWKS uri is missing
+ with self.metadata_edit({"jwks_uri": None}):
+ with self.assertRaises(RuntimeError):
+ yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+
+ # Return empty key set if JWKS are not used
+ self.handler._scopes = [] # not asking the openid scope
+ self.http_client.get_json.reset_mock()
+ jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.http_client.get_json.assert_not_called()
+ self.assertEqual(jwks, {"keys": []})
+
+ @override_config({"oidc_config": COMMON_CONFIG})
+ def test_validate_config(self):
+ """Provider metadatas are extensively validated."""
+ h = self.handler
+
+ # Default test config does not throw
+ h._validate_metadata()
+
+ with self.metadata_edit({"issuer": None}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "http://insecure/"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
+ self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+
+ with self.metadata_edit({"authorization_endpoint": None}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
+ self.assertRaisesRegex(
+ ValueError, "authorization_endpoint", h._validate_metadata
+ )
+
+ with self.metadata_edit({"token_endpoint": None}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
+ self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": None}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
+ self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+
+ with self.metadata_edit({"response_types_supported": ["id_token"]}):
+ self.assertRaisesRegex(
+ ValueError, "response_types_supported", h._validate_metadata
+ )
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
+ ):
+ # should not throw, as client_secret_basic is the default auth method
+ h._validate_metadata()
+
+ with self.metadata_edit(
+ {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
+ ):
+ self.assertRaisesRegex(
+ ValueError,
+ "token_endpoint_auth_methods_supported",
+ h._validate_metadata,
+ )
+
+ # Tests for configs that the userinfo endpoint
+ self.assertFalse(h._uses_userinfo)
+ h._scopes = [] # do not request the openid scope
+ self.assertTrue(h._uses_userinfo)
+ self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+
+ with self.metadata_edit(
+ {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
+ ):
+ # Shouldn't raise with a valid userinfo, even without
+ h._validate_metadata()
+
+ @override_config({"oidc_config": {"skip_verification": True}})
+ def test_skip_verification(self):
+ """Provider metadata validation can be disabled by config."""
+ with self.metadata_edit({"issuer": "http://insecure"}):
+ # This should not throw
+ self.handler._validate_metadata()
+
+ @defer.inlineCallbacks
+ def test_redirect_request(self):
+ """The redirect request has the right arguments & generates a valid session cookie."""
+ req = Mock(spec=["addCookie", "redirect", "finish"])
+ yield defer.ensureDeferred(
+ self.handler.handle_redirect_request(req, b"http://client/redirect")
+ )
+ url = req.redirect.call_args[0][0]
+ url = urlparse(url)
+ auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+
+ self.assertEqual(url.scheme, auth_endpoint.scheme)
+ self.assertEqual(url.netloc, auth_endpoint.netloc)
+ self.assertEqual(url.path, auth_endpoint.path)
+
+ params = parse_qs(url.query)
+ self.assertEqual(params["redirect_uri"], [CALLBACK_URL])
+ self.assertEqual(params["response_type"], ["code"])
+ self.assertEqual(params["scope"], [" ".join(SCOPES)])
+ self.assertEqual(params["client_id"], [CLIENT_ID])
+ self.assertEqual(len(params["state"]), 1)
+ self.assertEqual(len(params["nonce"]), 1)
+
+ # Check what is in the cookie
+ # note: python3.5 mock does not have the .called_once() method
+ calls = req.addCookie.call_args_list
+ self.assertEqual(len(calls), 1) # called once
+ # For some reason, call.args does not work with python3.5
+ args = calls[0][0]
+ kwargs = calls[0][1]
+ self.assertEqual(args[0], COOKIE_NAME)
+ self.assertEqual(kwargs["path"], COOKIE_PATH)
+ cookie = args[1]
+
+ macaroon = pymacaroons.Macaroon.deserialize(cookie)
+ state = self.handler._get_value_from_macaroon(macaroon, "state")
+ nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
+ redirect = self.handler._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+
+ self.assertEqual(params["state"], [state])
+ self.assertEqual(params["nonce"], [nonce])
+ self.assertEqual(redirect, "http://client/redirect")
+
+ @defer.inlineCallbacks
+ def test_callback_error(self):
+ """Errors from the provider returned in the callback are displayed."""
+ self.handler._render_error = Mock()
+ request = Mock(args={})
+ request.args[b"error"] = [b"invalid_client"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "")
+
+ request.args[b"error_description"] = [b"some description"]
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_client", "some description")
+
+ @defer.inlineCallbacks
+ def test_callback(self):
+ """Code callback works and display errors if something went wrong.
+
+ A lot of scenarios are tested here:
+ - when the callback works, with userinfo from ID token
+ - when the user mapping fails
+ - when ID token verification fails
+ - when the callback works, with userinfo fetched from the userinfo endpoint
+ - when the userinfo fetching fails
+ - when the code exchange fails
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "preferred_username": "bar",
+ }
+ user_id = UserID("foo", "domain.org")
+ self.handler._render_error = Mock(return_value=None)
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ code = "code"
+ state = "state"
+ nonce = "nonce"
+ client_redirect_url = "http://client/redirect"
+ session = self.handler._generate_oidc_session_token(
+ state=state, nonce=nonce, client_redirect_url=client_redirect_url,
+ )
+ request.getCookie.return_value = session
+
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_not_called()
+ self.handler._render_error.assert_not_called()
+
+ # Handle mapping errors
+ self.handler._map_userinfo_to_user = simple_async_mock(
+ raises=MappingException()
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+
+ # Handle ID token errors
+ self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_token")
+
+ self.handler._auth_handler.complete_sso_login.reset_mock()
+ self.handler._exchange_code.reset_mock()
+ self.handler._parse_id_token.reset_mock()
+ self.handler._map_userinfo_to_user.reset_mock()
+ self.handler._fetch_userinfo.reset_mock()
+
+ # With userinfo fetching
+ self.handler._scopes = [] # do not ask the "openid" scope
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url,
+ )
+ self.handler._exchange_code.assert_called_once_with(code)
+ self.handler._parse_id_token.assert_not_called()
+ self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.handler._render_error.assert_not_called()
+
+ # Handle userinfo fetching error
+ self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("fetch_error")
+
+ # Handle code exchange failure
+ self.handler._exchange_code = simple_async_mock(
+ raises=OidcError("invalid_request")
+ )
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @defer.inlineCallbacks
+ def test_callback_session(self):
+ """The callback verifies the session presence and validity"""
+ self.handler._render_error = Mock(return_value=None)
+ request = Mock(spec=["args", "getCookie", "addCookie"])
+
+ # Missing cookie
+ request.args = {}
+ request.getCookie.return_value = None
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("missing_session", "No session cookie found")
+
+ # Missing session parameter
+ request.args = {}
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request", "State parameter is missing")
+
+ # Invalid cookie
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = "session"
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_session")
+
+ # Mismatching session
+ session = self.handler._generate_oidc_session_token(
+ state="state", nonce="nonce", client_redirect_url="http://client/redirect",
+ )
+ request.args = {}
+ request.args[b"state"] = [b"mismatching state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mismatching_session")
+
+ # Valid session
+ request.args = {}
+ request.args[b"state"] = [b"state"]
+ request.getCookie.return_value = session
+ yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("invalid_request")
+
+ @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+ @defer.inlineCallbacks
+ def test_exchange_code(self):
+ """Code exchange behaves correctly and handles various error scenarios."""
+ token = {"type": "bearer"}
+ token_json = json.dumps(token).encode("utf-8")
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ )
+ code = "code"
+ ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ kwargs = self.http_client.request.call_args[1]
+
+ self.assertEqual(ret, token)
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["client_secret"], [CLIENT_SECRET])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ # Test error handling
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=400,
+ phrase=b"Bad Request",
+ body=b'{"error": "foo", "error_description": "bar"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "foo")
+ self.assertEqual(exc.exception.error_description, "bar")
+
+ # Internal server error with no JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # Internal server error with JSON body
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=500,
+ phrase=b"Internal Server Error",
+ body=b'{"error": "internal_server_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "internal_server_error")
+
+ # 4xx error without "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "server_error")
+
+ # 2xx error with "error" field
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+ )
+ )
+ with self.assertRaises(OidcError) as exc:
+ yield defer.ensureDeferred(self.handler._exchange_code(code))
+ self.assertEqual(exc.exception.error, "some_error")
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 7b92bdbc47..572df8d80b 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -185,7 +185,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Allow all users.
return False
- spam_checker.spam_checker = AllowAll()
+ spam_checker.spam_checkers = [AllowAll()]
# The results do not change:
# We get one search result when searching for user2 by user1.
@@ -198,7 +198,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# All users are spammy.
return True
- spam_checker.spam_checker = BlockAll()
+ spam_checker.spam_checkers = [BlockAll()]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 7b56d2028d..9d4f0bbe44 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -27,6 +27,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest
+from synapse.replication.http import streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -42,6 +43,10 @@ logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ servlets = [
+ streams.register_servlets,
+ ]
+
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -49,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.server = server_factory.buildProtocol(None)
# Make a new HomeServer object for the worker
- config = self.default_config()
- config["worker_app"] = "synapse.app.generic_worker"
- config["worker_replication_host"] = "testserv"
- config["worker_replication_http_port"] = "8765"
-
self.reactor.lookups["testserv"] = "1.2.3.4"
-
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
- config=config,
+ config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -78,6 +77,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.generic_worker"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs)
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
new file mode 100644
index 0000000000..eea4565da3
--- /dev/null
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.federation.send_queue import EduRow
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.replication.tcp.streams._base import BaseStreamTestCase
+
+
+class FederationStreamTestCase(BaseStreamTestCase):
+ def _get_worker_hs_config(self) -> dict:
+ # enable federation sending on the worker
+ config = super()._get_worker_hs_config()
+ # TODO: make it so we don't need both of these
+ config["send_federation"] = True
+ config["worker_app"] = "synapse.app.federation_sender"
+ return config
+
+ def test_catchup(self):
+ """Basic test of catchup on reconnect
+
+ Makes sure that updates sent while we are offline are received later.
+ """
+ fed_sender = self.hs.get_federation_sender()
+ received_rows = self.test_handler.received_rdata_rows
+
+ fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
+
+ self.reconnect()
+ self.reactor.advance(0)
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual(received_rows, [])
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "federation")
+
+ # we should have received an update row
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test_edu")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"a": "b"})
+
+ self.assertEqual(received_rows, [])
+
+ # additional updates should be transferred without an HTTP hit
+ fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
+ self.reactor.advance(0)
+ # there should be no http hit
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ # ... but we should have a row
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test1")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"c": "d"})
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index d25a7b194e..125c63dab5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -15,7 +15,6 @@
from mock import Mock
from synapse.handlers.typing import RoomMember
-from synapse.replication.http import streams
from synapse.replication.tcp.streams import TypingStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -24,10 +23,6 @@ USER_ID = "@feeling:blue"
class TypingStreamTestCase(BaseStreamTestCase):
- servlets = [
- streams.register_servlets,
- ]
-
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 249c93722f..54cd24bf64 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -701,6 +701,47 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
+ def test_single_room(self):
+ """Test that a single room can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertIn("room_id", channel.json_body)
+ self.assertIn("name", channel.json_body)
+ self.assertIn("canonical_alias", channel.json_body)
+ self.assertIn("joined_members", channel.json_body)
+ self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("version", channel.json_body)
+ self.assertIn("creator", channel.json_body)
+ self.assertIn("encryption", channel.json_body)
+ self.assertIn("federatable", channel.json_body)
+ self.assertIn("public", channel.json_body)
+ self.assertIn("join_rules", channel.json_body)
+ self.assertIn("guest_access", channel.json_body)
+ self.assertIn("history_visibility", channel.json_body)
+ self.assertIn("state_events", channel.json_body)
+
+ self.assertEqual(room_id_1, channel.json_body["room_id"])
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 0000000000..55e9ecf264
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db = self.store.db # type: Database
+
+ self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ )
+
+ return self.get_success(self.db.runWithConnection(_create))
+
+ def _insert_rows(self, instance_name: str, number: int):
+ def _insert(txn):
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+ (instance_name,),
+ )
+
+ self.get_success(self.db.runInteraction("test_single_instance", _insert))
+
+ def test_empty(self):
+ """Test an ID generator against an empty database gives sensible
+ current positions.
+ """
+
+ id_gen = self._create_id_generator()
+
+ # The table is empty so we expect an empty map for positions
+ self.assertEqual(id_gen.get_positions(), {})
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async():
+ with await id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
+
+ def test_multi_instance(self):
+ """Test that reads and writes from multiple processes are handled
+ correctly.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first")
+ second_id_gen = self._create_id_generator("second")
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async():
+ with await first_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+ # However the ID gen on the second instance won't have seen the update
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+ # ... but calling `get_next` on the second instance should give a unique
+ # stream ID
+
+ async def _get_next_async():
+ with await second_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 9)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+ # If the second ID gen gets told about the first, it correctly updates
+ second_id_gen.advance("first", 8)
+ self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+ def test_get_next_txn(self):
+ """Test that the `get_next_txn` function works correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ def _get_next_txn(txn):
+ stream_id = id_gen.get_next_txn(txn)
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 086adeb8fd..3b78d48896 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -55,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase):
(yield self.store.get_room(self.room.to_string())),
)
+ @defer.inlineCallbacks
+ def test_get_room_with_stats(self):
+ self.assertDictContainsSubset(
+ {
+ "room_id": self.room.to_string(),
+ "creator": self.u_creator.to_string(),
+ "public": True,
+ },
+ (yield self.store.get_room_with_stats(self.room.to_string())),
+ )
+
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
|