-rwxr-xr-x  .buildkite/scripts/create_postgres_db.py  1
-rw-r--r--  changelog.d/9036.feature  1
-rw-r--r--  changelog.d/9038.misc  1
-rw-r--r--  changelog.d/9039.removal  1
-rw-r--r--  changelog.d/9042.feature  1
-rw-r--r--  changelog.d/9043.feature  1
-rw-r--r--  changelog.d/9044.feature  1
-rw-r--r--  changelog.d/9051.bugfix  1
-rw-r--r--  changelog.d/9053.bugfix  1
-rw-r--r--  changelog.d/9054.bugfix  1
-rw-r--r--  demo/webserver.py  59
-rw-r--r--  mypy.ini  1
-rwxr-xr-x  scripts/synapse_port_db  27
-rw-r--r--  stubs/frozendict.pyi  11
-rw-r--r--  stubs/sortedcontainers/sorteddict.pyi  6
-rw-r--r--  stubs/txredisapi.pyi  2
-rw-r--r--  synapse/app/generic_worker.py  3
-rw-r--r--  synapse/config/_util.py  2
-rw-r--r--  synapse/config/workers.py  10
-rw-r--r--  synapse/federation/federation_server.py  21
-rw-r--r--  synapse/handlers/devicemessage.py  48
-rw-r--r--  synapse/notifier.py  39
-rw-r--r--  synapse/replication/slave/storage/deviceinbox.py  32
-rw-r--r--  synapse/replication/tcp/handler.py  9
-rw-r--r--  synapse/rest/client/v1/login.py  4
-rw-r--r--  synapse/storage/databases/main/__init__.py  33
-rw-r--r--  synapse/storage/databases/main/client_ips.py  54
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py  147
-rw-r--r--  synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql  18
-rw-r--r--  synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres  25
-rw-r--r--  synapse/util/metrics.py  3
-rw-r--r--  tests/config/test_util.py  53
-rw-r--r--  tests/rest/client/v1/test_login.py  282
-rw-r--r--  tests/rest/client/v1/utils.py  3
-rw-r--r--  tox.ini  20
35 files changed, 641 insertions, 281 deletions
diff --git a/.buildkite/scripts/create_postgres_db.py b/.buildkite/scripts/create_postgres_db.py
index df6082b0ac..956339de5c 100755
--- a/.buildkite/scripts/create_postgres_db.py
+++ b/.buildkite/scripts/create_postgres_db.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+
 from synapse.storage.engines import create_engine
 
 logger = logging.getLogger("create_postgres_db")
diff --git a/changelog.d/9036.feature b/changelog.d/9036.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9036.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/changelog.d/9038.misc b/changelog.d/9038.misc
new file mode 100644
index 0000000000..5b9e21a1db
--- /dev/null
+++ b/changelog.d/9038.misc
@@ -0,0 +1 @@
+Configure the linters to run on a consistent set of files.
diff --git a/changelog.d/9039.removal b/changelog.d/9039.removal
new file mode 100644
index 0000000000..fb99283ed8
--- /dev/null
+++ b/changelog.d/9039.removal
@@ -0,0 +1 @@
+Remove broken and unmaintained `demo/webserver.py` script.
diff --git a/changelog.d/9042.feature b/changelog.d/9042.feature
new file mode 100644
index 0000000000..4ec319f1f2
--- /dev/null
+++ b/changelog.d/9042.feature
@@ -0,0 +1 @@
+Add experimental support for handling and persistence of to-device messages to happen on worker processes.
diff --git a/changelog.d/9043.feature b/changelog.d/9043.feature
new file mode 100644
index 0000000000..4ec319f1f2
--- /dev/null
+++ b/changelog.d/9043.feature
@@ -0,0 +1 @@
+Add experimental support for handling and persistence of to-device messages to happen on worker processes.
diff --git a/changelog.d/9044.feature b/changelog.d/9044.feature
new file mode 100644
index 0000000000..4ec319f1f2
--- /dev/null
+++ b/changelog.d/9044.feature
@@ -0,0 +1 @@
+Add experimental support for handling and persistence of to-device messages to happen on worker processes.
diff --git a/changelog.d/9051.bugfix b/changelog.d/9051.bugfix
new file mode 100644
index 0000000000..272be9d7a3
--- /dev/null
+++ b/changelog.d/9051.bugfix
@@ -0,0 +1 @@
+Fix error handling during insertion of client IPs into the database.
diff --git a/changelog.d/9053.bugfix b/changelog.d/9053.bugfix
new file mode 100644
index 0000000000..3d8bbf11a1
--- /dev/null
+++ b/changelog.d/9053.bugfix
@@ -0,0 +1 @@
+Fix bug where we didn't correctly record CPU time spent in 'on_new_event' block.
diff --git a/changelog.d/9054.bugfix b/changelog.d/9054.bugfix
new file mode 100644
index 0000000000..0bfe951f17
--- /dev/null
+++ b/changelog.d/9054.bugfix
@@ -0,0 +1 @@
+Fix a minor bug which could cause confusing error messages from invalid configurations.
diff --git a/demo/webserver.py b/demo/webserver.py
deleted file mode 100644
index ba176d3bd2..0000000000
--- a/demo/webserver.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import argparse
-import BaseHTTPServer
-import os
-import SimpleHTTPServer
-import cgi, logging
-
-from daemonize import Daemonize
-
-
-class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler):
-    UPLOAD_PATH = "upload"
-
-    """
-    Accept all post request as file upload
-    """
-
-    def do_POST(self):
-
-        path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path))
-        length = self.headers["content-length"]
-        data = self.rfile.read(int(length))
-
-        with open(path, "wb") as fh:
-            fh.write(data)
-
-        self.send_response(200)
-        self.send_header("Content-Type", "application/json")
-        self.end_headers()
-
-        # Return the absolute path of the uploaded file
-        self.wfile.write('{"url":"/%s"}' % path)
-
-
-def setup():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("directory")
-    parser.add_argument("-p", "--port", dest="port", type=int, default=8080)
-    parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid")
-    args = parser.parse_args()
-
-    # Get absolute path to directory to serve, as daemonize changes to '/'
-    os.chdir(args.directory)
-    dr = os.getcwd()
-
-    httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST)
-
-    def run():
-        os.chdir(dr)
-        httpd.serve_forever()
-
-    daemon = Daemonize(
-        app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False
-    )
-
-    daemon.start()
-
-
-if __name__ == "__main__":
-    setup()
diff --git a/mypy.ini b/mypy.ini
index 5d15b7bf1c..b996867121 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -103,6 +103,7 @@ files =
   tests/replication,
   tests/test_utils,
   tests/handlers/test_password_providers.py,
+  tests/rest/client/v1/test_login.py,
   tests/rest/client/v2_alpha/test_auth.py,
   tests/util/test_stream_change_cache.py
 
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index c982ca9350..e238f10d26 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -631,6 +631,7 @@ class Porter(object):
         await self._setup_state_group_id_seq()
         await self._setup_user_id_seq()
         await self._setup_events_stream_seqs()
+        await self._setup_device_inbox_seq()
 
         # Step 3. Get tables.
         self.progress.set_state("Fetching tables")
@@ -913,6 +914,32 @@ class Porter(object):
             "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
         )
 
+    async def _setup_device_inbox_seq(self):
+        """Set the device inbox sequence to the correct value.
+        """
+        curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_inbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_federation_outbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        next_id = max(curr_local_id, curr_federation_id) + 1
+
+        def r(txn):
+            txn.execute(
+                "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,)
+            )
+
+        return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r)
+
 
 ##############################################
 # The following is simply UI stuff
diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi
index 3f3af59f26..0368ba4703 100644
--- a/stubs/frozendict.pyi
+++ b/stubs/frozendict.pyi
@@ -15,16 +15,7 @@
 
 # Stub for frozendict.
 
-from typing import (
-    Any,
-    Hashable,
-    Iterable,
-    Iterator,
-    Mapping,
-    overload,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload
 
 _KT = TypeVar("_KT", bound=Hashable)  # Key type.
 _VT = TypeVar("_VT")  # Value type.
diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi
index 68779f968e..7b9fd079d9 100644
--- a/stubs/sortedcontainers/sorteddict.pyi
+++ b/stubs/sortedcontainers/sorteddict.pyi
@@ -7,17 +7,17 @@ from typing import (
     Callable,
     Dict,
     Hashable,
-    Iterator,
-    Iterable,
     ItemsView,
+    Iterable,
+    Iterator,
     KeysView,
     List,
     Mapping,
     Optional,
     Sequence,
+    Tuple,
     Type,
     TypeVar,
-    Tuple,
     Union,
     ValuesView,
     overload,
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 522244bb57..bfac6840e6 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -16,7 +16,7 @@
 """Contains *incomplete* type hints for txredisapi.
 """
 
-from typing import List, Optional, Union, Type
+from typing import List, Optional, Type, Union
 
 class RedisProtocol:
     def publish(self, channel: str, message: bytes): ...
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 30d53a3aed..038f698801 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -108,6 +108,7 @@ from synapse.rest.client.v2_alpha.account_data import (
 )
 from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
@@ -520,6 +521,8 @@ class GenericWorkerServer(HomeServer):
                     room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)
 
+                    SendToDeviceRestServlet(self).register(resource)
+
                     user_directory.register_servlets(self, resource)
 
                     # If presence is disabled, use the stub servlet that does
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index 1bbe83c317..8fce7f6bb1 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -56,7 +56,7 @@ def json_error_to_config_error(
     """
     # copy `config_path` before modifying it.
     path = list(config_path)
-    for p in list(e.path):
+    for p in list(e.absolute_path):
        if isinstance(p, int):
            path.append("<item %i>" % p)
        else:
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 7ca9efec52..364583f48b 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -53,6 +53,9 @@ class WriterLocations:
         default=["master"], type=List[str], converter=_instance_to_list_converter
     )
     typing = attr.ib(default="master", type=str)
+    to_device = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
 
 
 class WorkerConfig(Config):
@@ -124,7 +127,7 @@ class WorkerConfig(Config):
 
         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing"):
+        for stream in ("events", "typing", "to_device"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -133,6 +136,11 @@ class WorkerConfig(Config):
                         % (instance, stream)
                     )
 
+        if len(self.writers.to_device) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `to_device` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
         # Whether this worker should run background tasks or not.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d8a2caf75f..d029e1cd15 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -931,8 +932,10 @@ class FederationHandlerRegistry:
         )  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
         self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]
 
-        # Map from type to instance name that we should route EDU handling to.
-        self._edu_type_to_instance = {}  # type: Dict[str, str]
+        # Map from type to instance names that we should route EDU handling to.
+        # We randomly choose one instance from the list to route to for each new
+        # EDU received.
+        self._edu_type_to_instance = {}  # type: Dict[str, List[str]]
 
     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
@@ -976,7 +979,12 @@ class FederationHandlerRegistry:
     def register_instance_for_edu(self, edu_type: str, instance_name: str):
         """Register that the EDU handler is on a different instance than master.
         """
-        self._edu_type_to_instance[edu_type] = instance_name
+        self._edu_type_to_instance[edu_type] = [instance_name]
+
+    def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
+        """Register that the EDU handler is on multiple instances.
+        """
+        self._edu_type_to_instance[edu_type] = instance_names
 
     async def on_edu(self, edu_type: str, origin: str, content: dict):
         if not self.config.use_presence and edu_type == "m.presence":
@@ -995,8 +1003,11 @@ class FederationHandlerRegistry:
             return
 
         # Check if we can route it somewhere else that isn't us
-        route_to = self._edu_type_to_instance.get(edu_type, "master")
-        if route_to != self._instance_name:
+        instances = self._edu_type_to_instance.get(edu_type, ["master"])
+        if self._instance_name not in instances:
+            # Pick an instance randomly so that we don't overload one.
+            route_to = random.choice(instances)
+
             try:
                 await self._send_edu(
                     instance_name=route_to,
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 9cac5a8463..fc974a82e8 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -24,6 +24,7 @@ from synapse.logging.opentracing import (
     set_tag,
     start_active_span,
 )
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
@@ -44,13 +45,37 @@ class DeviceMessageHandler:
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
-        self.federation = hs.get_federation_sender()
 
-        hs.get_federation_registry().register_edu_handler(
-            "m.direct_to_device", self.on_direct_to_device_edu
-        )
+        # We only need to poke the federation sender explicitly if its on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the to-device replication stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the to device EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.to_device:
+            hs.get_federation_registry().register_edu_handler(
+                "m.direct_to_device", self.on_direct_to_device_edu
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.direct_to_device", hs.config.worker.writers.to_device,
+            )
 
-        self._device_list_updater = hs.get_device_handler().device_list_updater
+        # The handler to call when we think a user's device list might be out of
+        # sync. We do all device list resyncing on the master instance, so if
+        # we're on a worker we hit the device resync replication API.
+        if hs.config.worker.worker_app is None:
+            self._user_device_resync = (
+                hs.get_device_handler().device_list_updater.user_device_resync
+            )
+        else:
+            self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
+                hs
+            )
 
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         local_messages = {}
@@ -138,9 +163,7 @@ class DeviceMessageHandler:
             await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
-            run_in_background(
-                self._device_list_updater.user_device_resync, sender_user_id
-            )
+            run_in_background(self._user_device_resync, sender_user_id)
 
     async def send_device_message(
         self,
@@ -195,7 +218,8 @@ class DeviceMessageHandler:
         )
 
         log_kv({"remote_messages": remote_messages})
-        for destination in remote_messages.keys():
-            # Enqueue a new federation transaction to send the new
-            # device messages to each remote destination.
-            self.federation.send_device_messages(destination)
+        if self.federation_sender:
+            for destination in remote_messages.keys():
+                # Enqueue a new federation transaction to send the new
+                # device messages to each remote destination.
+                self.federation_sender.send_device_messages(destination)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c4c8bb271d..0745899b48 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -396,31 +396,30 @@ class Notifier:
 
         Will wake up all listeners for the given users and rooms.
         """
-        with PreserveLoggingContext():
-            with Measure(self.clock, "on_new_event"):
-                user_streams = set()
+        with Measure(self.clock, "on_new_event"):
+            user_streams = set()
 
-                for user in users:
-                    user_stream = self.user_to_user_stream.get(str(user))
-                    if user_stream is not None:
-                        user_streams.add(user_stream)
+            for user in users:
+                user_stream = self.user_to_user_stream.get(str(user))
+                if user_stream is not None:
+                    user_streams.add(user_stream)
 
-                for room in rooms:
-                    user_streams |= self.room_to_user_streams.get(room, set())
+            for room in rooms:
+                user_streams |= self.room_to_user_streams.get(room, set())
 
-                time_now_ms = self.clock.time_msec()
-                for user_stream in user_streams:
-                    try:
-                        user_stream.notify(stream_key, new_token, time_now_ms)
-                    except Exception:
-                        logger.exception("Failed to notify listener")
+            time_now_ms = self.clock.time_msec()
+            for user_stream in user_streams:
+                try:
+                    user_stream.notify(stream_key, new_token, time_now_ms)
+                except Exception:
+                    logger.exception("Failed to notify listener")
 
-                self.notify_replication()
+            self.notify_replication()
 
-                # Notify appservices
-                self._notify_app_services_ephemeral(
-                    stream_key, new_token, users,
-                )
+            # Notify appservices
+            self._notify_app_services_ephemeral(
+                stream_key, new_token, users,
+            )
 
     def on_new_replication_data(self) -> None:
         """Used to inform replication listeners that something has happened
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 62b68dd6e9..1260f6d141 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -14,38 +14,8 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-        self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_inbox", "stream_id"
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(instance_name, token)
-            for row in rows:
-                if row.entity.startswith("@"):
-                    self._device_inbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-                else:
-                    self._device_federation_outbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 95e5502bf2..1f89249475 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -56,6 +56,7 @@ from synapse.replication.tcp.streams import (
     EventsStream,
     FederationStream,
     Stream,
+    ToDeviceStream,
     TypingStream,
 )
 
@@ -115,6 +116,14 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, ToDeviceStream):
+                # Only add ToDeviceStream as a source on instances in charge of
+                # sending to device messages.
+                if hs.get_instance_name() in hs.config.worker.writers.to_device:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             if isinstance(stream, TypingStream):
                 # Only add TypingStream as a source on the instance in charge of
                 # typing.
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index ebc346105b..be938df962 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -319,9 +319,9 @@ class SsoRedirectServlet(RestServlet):
         # register themselves with the main SSOHandler.
         if hs.config.cas_enabled:
             hs.get_cas_handler()
-        elif hs.config.saml2_enabled:
+        if hs.config.saml2_enabled:
             hs.get_saml_handler()
-        elif hs.config.oidc_enabled:
+        if hs.config.oidc_enabled:
             hs.get_oidc_handler()
         self._sso_handler = hs.get_sso_handler()
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 701748f93b..c4de07a0a8 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -127,9 +127,6 @@ class DataStore(
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
-        self._device_inbox_id_gen = StreamIdGenerator(
-            db_conn, "device_inbox", "stream_id"
-        )
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
@@ -189,36 +186,6 @@ class DataStore(
             prefilled_cache=presence_cache_prefill,
         )
 
-        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
-        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_inbox",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            min_device_inbox_id,
-            prefilled_cache=device_inbox_prefill,
-        )
-        # The federation outbox and the local device inbox uses the same
-        # stream_id generator.
-        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_federation_outbox",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            min_device_outbox_id,
-            prefilled_cache=device_outbox_prefill,
-        )
-
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
             "DeviceListStreamChangeCache", device_list_max
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index e96a8b3f43..c53c836337 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -470,43 +470,35 @@ class ClientIpStore(ClientIpWorkerStore):
             for entry in to_update.items():
                 (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
 
-                try:
-                    self.db_pool.simple_upsert_txn(
+                self.db_pool.simple_upsert_txn(
+                    txn,
+                    table="user_ips",
+                    keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
+                    values={
+                        "user_agent": user_agent,
+                        "device_id": device_id,
+                        "last_seen": last_seen,
+                    },
+                    lock=False,
+                )
+
+                # Technically an access token might not be associated with
+                # a device so we need to check.
+                if device_id:
+                    # this is always an update rather than an upsert: the row should
+                    # already exist, and if it doesn't, that may be because it has been
+                    # deleted, and we don't want to re-create it.
+                    self.db_pool.simple_update_txn(
                         txn,
-                        table="user_ips",
-                        keyvalues={
-                            "user_id": user_id,
-                            "access_token": access_token,
-                            "ip": ip,
-                        },
-                        values={
+                        table="devices",
+                        keyvalues={"user_id": user_id, "device_id": device_id},
+                        updatevalues={
                             "user_agent": user_agent,
-                            "device_id": device_id,
                             "last_seen": last_seen,
+                            "ip": ip,
                         },
-                        lock=False,
                     )
 
-                    # Technically an access token might not be associated with
-                    # a device so we need to check.
-                    if device_id:
-                        # this is always an update rather than an upsert: the row should
-                        # already exist, and if it doesn't, that may be because it has been
-                        # deleted, and we don't want to re-create it.
-                        self.db_pool.simple_update_txn(
-                            txn,
-                            table="devices",
-                            keyvalues={"user_id": user_id, "device_id": device_id},
-                            updatevalues={
-                                "user_agent": user_agent,
-                                "last_seen": last_seen,
-                                "ip": ip,
-                            },
-                        )
-                except Exception as e:
-                    # Failed to upsert, log and continue
-                    logger.error("Failed to insert client IP %r: %r", entry, e)
-
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
     ) -> Dict[Tuple[str, str], dict]:
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index eb72c21155..58d3f71e45 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -17,10 +17,14 @@ import logging
 from typing import List, Tuple
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
 
@@ -29,6 +33,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
+        self._instance_name = hs.get_instance_name()
+
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
         self._last_device_delete_cache = ExpiringCache(
@@ -38,6 +44,73 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             expiry_ms=30 * 60 * 1000,
         )
 
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_device = (
+                self._instance_name in hs.config.worker.writers.to_device
+            )
+
+            self._device_inbox_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="to_device",
+                instance_name=self._instance_name,
+                table="device_inbox",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="device_inbox_sequence",
+                writers=hs.config.worker.writers.to_device,
+            )
+        else:
+            self._can_write_to_device = True
+            self._device_inbox_id_gen = StreamIdGenerator(
+                db_conn, "device_inbox", "stream_id"
+            )
+
+        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_inbox",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_inbox_stream_cache = StreamChangeCache(
+            "DeviceInboxStreamChangeCache",
+            min_device_inbox_id,
+            prefilled_cache=device_inbox_prefill,
+        )
+
+        # The federation outbox and the local device inbox uses the same
+        # stream_id generator.
+        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_federation_outbox",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache",
+            min_device_outbox_id,
+            prefilled_cache=device_outbox_prefill,
+        )
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+            for row in rows:
+                if row.entity.startswith("@"):
+                    self._device_inbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+                else:
+                    self._device_federation_outbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
@@ -290,38 +363,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )
 
-
-class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
-    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self.db_pool.updates.register_background_index_update(
-            "device_inbox_stream_index",
-            index_name="device_inbox_stream_id_user_id",
-            table="device_inbox",
-            columns=["stream_id", "user_id"],
-        )
-
-        self.db_pool.updates.register_background_update_handler(
-            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
-        )
-
-    async def _background_drop_index_device_inbox(self, progress, batch_size):
-        def reindex_txn(conn):
-            txn = conn.cursor()
-            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
-            txn.close()
-
-        await self.db_pool.runWithConnection(reindex_txn)
-
-        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
-        return 1
-
-
-class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     @trace
     async def add_messages_to_device_inbox(
         self,
@@ -340,6 +381,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             The new stream_id.
         """
 
+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Add the local messages directly to the local inbox.
             self._add_messages_to_local_device_inbox_txn(
@@ -358,6 +401,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                         "stream_id": stream_id,
                         "queued_ts": now_ms,
                         "messages_json": json_encoder.encode(edu),
+                        "instance_name": self._instance_name,
                     }
                     for destination, edu in remote_messages_by_destination.items()
                 ],
@@ -380,6 +424,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     async def add_messages_from_remote_to_device_inbox(
         self, origin: str, message_id: str, local_messages_by_user_then_device: dict
     ) -> int:
+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
@@ -428,6 +474,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     def _add_messages_to_local_device_inbox_txn(
         self, txn, stream_id, messages_by_user_then_device
     ):
+        assert self._can_write_to_device
+
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
             messages_json_for_user = {}
@@ -481,8 +529,43 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                     "device_id": device_id,
                     "stream_id": stream_id,
                     "message_json": message_json,
+                    "instance_name": self._instance_name,
                 }
                 for user_id, messages_by_device in local_by_user_then_device.items()
                 for device_id, message_json in messages_by_device.items()
             ],
         )
+
+
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_index_update(
+            "device_inbox_stream_index",
+            index_name="device_inbox_stream_id_user_id",
+            table="device_inbox",
+            columns=["stream_id", "user_id"],
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
+        )
+
+    async def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            txn = conn.cursor()
+            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+            txn.close()
+
+        await self.db_pool.runWithConnection(reindex_txn)
+
+        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+    pass
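For orientation, writers allocate ids from the generator configured above via its async context manager; a minimal sketch of the usual `MultiWriterIdGenerator` pattern (the transaction callback name is a placeholder, not a method from this file):

```python
async def _write_with_new_stream_id(self) -> int:
    # get_next() reserves an id from device_inbox_sequence and advances the
    # stream position once the block exits cleanly.
    async with self._device_inbox_id_gen.get_next() as stream_id:
        await self.db_pool.runInteraction(
            "add_messages_to_device_inbox", self._insert_txn, stream_id
        )
    return self._device_inbox_id_gen.get_current_token()
```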
diff --git a/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
new file mode 100644
index 0000000000..d781a92fec
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE device_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
new file mode 100644
index 0000000000..45a845a3a5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence;
+
+-- We need to take the max across both device_inbox and device_federation_outbox
+-- tables as they share the ID generator
+SELECT setval('device_inbox_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox)
+    )
+));
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 24123d5cc4..f4de6b9f54 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -111,7 +111,8 @@ class Measure:
         curr_context = current_context()
         if not curr_context:
             logger.warning(
-                "Starting metrics collection from sentinel context: metrics will be lost"
+                "Starting metrics collection %r from sentinel context: metrics will be lost",
+                name,
             )
             parent_context = None
         else:
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
new file mode 100644
index 0000000000..10363e3765
--- /dev/null
+++ b/tests/config/test_util.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config import ConfigError
+from synapse.config._util import validate_config
+
+from tests.unittest import TestCase
+
+
+class ValidateConfigTestCase(TestCase):
+    """Test cases for synapse.config._util.validate_config"""
+
+    def test_bad_object_in_array(self):
+        """malformed objects within an array should be validated correctly"""
+
+        # consider a structure:
+        #
+        # array_of_objs:
+        #   - r: 1
+        #     foo: 2
+        #
+        #   - r: 2
+        #     bar: 3
+        #
+        # ... where each entry must contain an "r": check that the path
+        # to the required item is correclty reported.
+
+        schema = {
+            "type": "object",
+            "properties": {
+                "array_of_objs": {
+                    "type": "array",
+                    "items": {"type": "object", "required": ["r"]},
+                },
+            },
+        }
+
+        with self.assertRaises(ConfigError) as c:
+            validate_config(schema, {"array_of_objs": [{}]}, ("base",))
+
+        self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 901c72d36a..1d1dc9f8a2 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,22 +1,67 @@
-import json
+# -*- coding: utf-8 -*-
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 import urllib.parse
+from html.parser import HTMLParser
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 from mock import Mock
 
-try:
-    import jwt
-except ImportError:
-    jwt = None
+import pymacaroons
+
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
 
 from tests import unittest
-from tests.unittest import override_config
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.handlers.test_saml import has_saml2
+from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.unittest import override_config, skip_unless
+
+try:
+    import jwt
+
+    HAS_JWT = True
+except ImportError:
+    HAS_JWT = False
+
+
+# public_base_url used in some tests
+BASE_URL = "https://synapse/"
+
+# CAS server used in some tests
+CAS_SERVER = "https://fake.test"
+
+# just enough to tell pysaml2 where to redirect to
+SAML_SERVER = "https://test.saml.server/idp/sso"
+TEST_SAML_METADATA = """
+<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
+  <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
+    <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
+  </md:IDPSSODescriptor>
+</md:EntityDescriptor>
+""" % {
+    "SAML_SERVER": SAML_SERVER,
+}
 
 LOGIN_URL = b"/_matrix/client/r0/login"
 TEST_URL = b"/_matrix/client/r0/account/whoami"
@@ -314,6 +359,184 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
 
+@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
+class MultiSSOTestCase(unittest.HomeserverTestCase):
+    """Tests for homeservers with multiple SSO providers enabled"""
+
+    servlets = [
+        login.register_servlets,
+    ]
+
+    def default_config(self) -> Dict[str, Any]:
+        config = super().default_config()
+
+        config["public_baseurl"] = BASE_URL
+
+        config["cas_config"] = {
+            "enabled": True,
+            "server_url": CAS_SERVER,
+            "service_url": "https://matrix.goodserver.com:8448",
+        }
+
+        config["saml2_config"] = {
+            "sp_config": {
+                "metadata": {"inline": [TEST_SAML_METADATA]},
+                # use the XMLSecurity backend to avoid relying on xmlsec1
+                "crypto_backend": "XMLSecurity",
+            },
+        }
+
+        config["oidc_config"] = TEST_OIDC_CONFIG
+
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
+        return d
+
+    def test_multi_sso_redirect(self):
+        """/login/sso/redirect should redirect to an identity picker"""
+        client_redirect_url = "https://x?<abc>"
+
+        # first hit the redirect url, which should redirect to our idp picker
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        uri = channel.headers.getRawHeaders("Location")[0]
+
+        # hitting that picker should give us some HTML
+        channel = self.make_request("GET", uri)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # parse the form to check it has fields assumed elsewhere in this class
+        class FormPageParser(HTMLParser):
+            def __init__(self):
+                super().__init__()
+
+                # the values of the hidden inputs: map from name to value
+                self.hiddens = {}  # type: Dict[str, Optional[str]]
+
+                # the values of the radio buttons
+                self.radios = []  # type: List[Optional[str]]
+
+            def handle_starttag(
+                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+            ) -> None:
+                attr_dict = dict(attrs)
+                if tag == "input":
+                    if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
+                        self.radios.append(attr_dict["value"])
+                    elif attr_dict["type"] == "hidden":
+                        input_name = attr_dict["name"]
+                        assert input_name
+                        self.hiddens[input_name] = attr_dict["value"]
+
+            def error(_, message):
+                self.fail(message)
+
+        p = FormPageParser()
+        p.feed(channel.result["body"].decode("utf-8"))
+        p.close()
+
+        self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
+
+        self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
+
+    def test_multi_sso_redirect_to_cas(self):
+        """If CAS is chosen, should redirect to the CAS server"""
+        client_redirect_url = "https://x?<abc>"
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        cas_uri = channel.headers.getRawHeaders("Location")[0]
+        cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
+
+        # it should redirect us to the login page of the cas server
+        self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
+
+        # check that the redirectUrl is correctly encoded in the service param - ie, the
+        # place that CAS will redirect to
+        cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
+        service_uri = cas_uri_params["service"][0]
+        _, service_uri_query = service_uri.split("?", 1)
+        service_uri_params = urllib.parse.parse_qs(service_uri_query)
+        self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
+
+    def test_multi_sso_redirect_to_saml(self):
+        """If SAML is chosen, should redirect to the SAML server"""
+        client_redirect_url = "https://x?<abc>"
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + client_redirect_url
+            + "&idp=saml",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        saml_uri = channel.headers.getRawHeaders("Location")[0]
+        saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
+
+        # it should redirect us to the login page of the SAML server
+        self.assertEqual(saml_uri_path, SAML_SERVER)
+
+        # the RelayState is used to carry the client redirect url
+        saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
+        relay_state_param = saml_uri_params["RelayState"][0]
+        self.assertEqual(relay_state_param, client_redirect_url)
+
+    def test_multi_sso_redirect_to_oidc(self):
+        """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
+        client_redirect_url = "https://x?<abc>"
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + client_redirect_url
+            + "&idp=oidc",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+        # ... and should have set a cookie including the redirect url
+        cookies = dict(
+            h.split(";")[0].split("=", maxsplit=1)
+            for h in channel.headers.getRawHeaders("Set-Cookie")
+        )
+
+        oidc_session_cookie = cookies["oidc_session"]
+        macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
+        self.assertEqual(
+            self._get_value_from_macaroon(macaroon, "client_redirect_url"),
+            client_redirect_url,
+        )
+
+    def test_multi_sso_redirect_to_unknown(self):
+        """An unknown IdP should cause a 400"""
+        channel = self.make_request(
+            "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+    @staticmethod
+    def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+
 class CASTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -327,7 +550,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         config = self.default_config()
         config["cas_config"] = {
             "enabled": True,
-            "server_url": "https://fake.test",
+            "server_url": CAS_SERVER,
             "service_url": "https://matrix.goodserver.com:8448",
         }
 
@@ -413,8 +636,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_cas_redirect_whitelisted(self):
-        """Tests that the SSO login flow serves a redirect to a whitelisted url
-        """
+        """Tests that the SSO login flow serves a redirect to a whitelisted url"""
         self._test_redirect("https://legit-site.com/")
 
     @override_config({"public_baseurl": "https://example.com"})
@@ -462,10 +684,8 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.assertIn(b"SSO account deactivated", channel.result["body"])
 
 
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTTestCase(unittest.HomeserverTestCase):
-    if not jwt:
-        skip = "requires jwt"
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -481,17 +701,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs
 
-    def jwt_encode(self, token: str, secret: str = jwt_secret) -> str:
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(token, secret, self.jwt_algorithm)
+        result = jwt.encode(
+            payload, secret, self.jwt_algorithm
+        )  # type: Union[str, bytes]
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
@@ -623,7 +843,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
     )
 
     def test_login_no_token(self):
-        params = json.dumps({"type": "org.matrix.login.jwt"})
+        params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -633,10 +853,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
 # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
 # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
 # signed by the private key.
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTPubKeyTestCase(unittest.HomeserverTestCase):
-    if not jwt:
-        skip = "requires jwt"
-
     servlets = [
         login.register_servlets,
     ]
@@ -693,17 +911,15 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = "RS256"
         return self.hs
 
-    def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str:
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(token, secret, "RS256")
+        result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes,str]
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
        channel = self.make_request(b"POST", LOGIN_URL, params)
        return channel
 
@@ -773,8 +989,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         return self.hs
 
     def test_login_appservice_user(self):
-        """Test that an appservice user can use /login
-        """
+        """Test that an appservice user can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -788,8 +1003,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_user_bot(self):
-        """Test that the appservice bot can use /login
-        """
+        """Test that the appservice bot can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -803,8 +1017,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_wrong_user(self):
-        """Test that non-as users cannot login with the as token
-        """
+        """Test that non-as users cannot login with the as token"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -818,8 +1031,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
     def test_login_appservice_wrong_as(self):
-        """Test that as users cannot login with wrong as token
-        """
+        """Test that as users cannot login with wrong as token"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -834,7 +1046,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
 
     def test_login_appservice_no_token(self):
         """Test that users must provide a token when using the appservice
-           login method
+        login method
         """
         self.register_as_user(AS_USER)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index dbc27893b5..81b7f84360 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -444,6 +444,7 @@ class RestHelper:
 
 # an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
 TEST_OIDC_CONFIG = {
     "enabled": True,
     "discover": False,
@@ -451,7 +452,7 @@ TEST_OIDC_CONFIG = {
     "client_id": "test-client-id",
     "client_secret": "test-client-secret",
     "scopes": ["profile"],
-    "authorization_endpoint": "https://z",
+    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
     "token_endpoint": "https://issuer.test/token",
     "userinfo_endpoint": "https://issuer.test/userinfo",
     "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
diff --git a/tox.ini b/tox.ini
index dec732f76e..3cf68a47a6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -26,6 +26,20 @@ deps =
     pip>=10 ; python_version >= '3.6'
     pip>=10,<21.0 ; python_version < '3.6'
 
+# directories/files we run the linters on
+lint_targets =
+    setup.py
+    synapse
+    tests
+    scripts
+    scripts-dev
+    stubs
+    contrib
+    synctl
+    synmark
+    .buildkite
+    docker
+
 # default settings for all tox environments
 [testenv]
 deps =
@@ -130,13 +144,13 @@ commands =
 [testenv:check_codestyle]
 extras = lint
 commands =
-    python -m black --check --diff .
-    /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}"
+    python -m black --check --diff {[base]lint_targets}
+    flake8 {[base]lint_targets} {env:PEP8SUFFIX:}
     {toxinidir}/scripts-dev/config-lint.sh
 
 [testenv:check_isort]
 extras = lint
-commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts"
+commands = isort -c --df --sp setup.cfg {[base]lint_targets}
 
 [testenv:check-newsfragment]
 skip_install = True