From f70f44abc73689a66d0a05dc703ca38241092174 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 28 Feb 2020 11:45:35 +0000 Subject: Remove handling of multiple rows per ID --- tests/storage/test_devices.py | 45 ------------------------------------------- 1 file changed, 45 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 6f8d990959..c2539b353a 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) - @defer.inlineCallbacks - def test_get_device_updates_by_remote_limited(self): - # Test breaking the update limit in 1, 101, and 1 device_id segments - - # first add one device - device_ids1 = ["device_id0"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"] - ) - - # then add 101 - device_ids2 = ["device_id" + str(i + 1) for i in range(101)] - yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"] - ) - - # then one more - device_ids3 = ["newdevice"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"] - ) - - # - # now read them back. - # - - # first we should get a single update - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", -1, limit=100 - ) - self._check_devices_in_updates(device_ids1, device_updates) - - # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self.assertEqual(len(device_updates), 0) - - # The 101 devices should've been cleared, so we should now just get one device - # update - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self._check_devices_in_updates(device_ids3, device_updates) - def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) -- cgit 1.5.1 From c165c1233b8ef244fadca97c7d465fdcf473d077 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 20 Mar 2020 16:24:22 +0100 Subject: Improve database configuration docs (#6988) Attempts to clarify the sample config for databases, and add some stuff about tcp keepalives to `postgres.md`. --- changelog.d/6988.doc | 1 + docs/postgres.md | 42 ++++++++++++++----- docs/sample_config.yaml | 43 +++++++++++++++++--- synapse/config/_base.py | 2 - synapse/config/database.py | 93 +++++++++++++++++++++++++++---------------- tests/config/test_database.py | 22 +--------- 6 files changed, 132 insertions(+), 71 deletions(-) create mode 100644 changelog.d/6988.doc (limited to 'tests') diff --git a/changelog.d/6988.doc b/changelog.d/6988.doc new file mode 100644 index 0000000000..b6f71bb966 --- /dev/null +++ b/changelog.d/6988.doc @@ -0,0 +1 @@ +Improve the documentation for database configuration. diff --git a/docs/postgres.md b/docs/postgres.md index e0793ecee8..16a630c3d1 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -105,19 +105,41 @@ of free memory the database host has available. When you are ready to start using PostgreSQL, edit the `database` section in your config file to match the following lines: - database: - name: psycopg2 - args: - user: - password: - database: - host: - cp_min: 5 - cp_max: 10 +```yaml +database: + name: psycopg2 + args: + user: + password: + database: + host: + cp_min: 5 + cp_max: 10 +``` All key, values in `args` are passed to the `psycopg2.connect(..)` function, except keys beginning with `cp_`, which are consumed by the -twisted adbapi connection pool. +twisted adbapi connection pool. See the [libpq +documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS) +for a list of options which can be passed. + +You should consider tuning the `args.keepalives_*` options if there is any danger of +the connection between your homeserver and database dropping, otherwise Synapse +may block for an extended period while it waits for a response from the +database server. Example values might be: + +```yaml +# seconds of inactivity after which TCP should send a keepalive message to the server +keepalives_idle: 10 + +# the number of seconds after which a TCP keepalive message that is not +# acknowledged by the server should be retransmitted +keepalives_interval: 10 + +# the number of TCP keepalives that can be lost before the client's connection +# to the server is considered dead +keepalives_count: 3 +``` ## Porting from SQLite diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2ff0dd05a2..276e43b732 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -578,13 +578,46 @@ acme: ## Database ## +# The 'database' setting defines the database that synapse uses to store all of +# its data. +# +# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or +# 'psycopg2' (for PostgreSQL). +# +# 'args' gives options which are passed through to the database engine, +# except for options starting 'cp_', which are used to configure the Twisted +# connection pool. For a reference to valid arguments, see: +# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect +# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS +# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__ +# +# +# Example SQLite configuration: +# +#database: +# name: sqlite3 +# args: +# database: /path/to/homeserver.db +# +# +# Example Postgres configuration: +# +#database: +# name: psycopg2 +# args: +# user: synapse +# password: secretpassword +# database: synapse +# host: localhost +# cp_min: 5 +# cp_max: 10 +# +# For more information on using Synapse with Postgres, see `docs/postgres.md`. +# database: - # The database engine name - name: "sqlite3" - # Arguments to pass to the engine + name: sqlite3 args: - # Path to the database - database: "DATADIR/homeserver.db" + database: DATADIR/homeserver.db # Number of events to cache in memory. # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index ba846042c4..efe2af5504 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -294,7 +294,6 @@ class RootConfig(object): report_stats=None, open_private_ports=False, listeners=None, - database_conf=None, tls_certificate_path=None, tls_private_key_path=None, acme_domain=None, @@ -367,7 +366,6 @@ class RootConfig(object): report_stats=report_stats, open_private_ports=open_private_ports, listeners=listeners, - database_conf=database_conf, tls_certificate_path=tls_certificate_path, tls_private_key_path=tls_private_key_path, acme_domain=acme_domain, diff --git a/synapse/config/database.py b/synapse/config/database.py index 219b32f670..b8ab2f86ac 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +15,60 @@ # limitations under the License. import logging import os -from textwrap import indent - -import yaml from synapse.config._base import Config, ConfigError logger = logging.getLogger(__name__) +DEFAULT_CONFIG = """\ +## Database ## + +# The 'database' setting defines the database that synapse uses to store all of +# its data. +# +# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or +# 'psycopg2' (for PostgreSQL). +# +# 'args' gives options which are passed through to the database engine, +# except for options starting 'cp_', which are used to configure the Twisted +# connection pool. For a reference to valid arguments, see: +# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect +# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS +# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__ +# +# +# Example SQLite configuration: +# +#database: +# name: sqlite3 +# args: +# database: /path/to/homeserver.db +# +# +# Example Postgres configuration: +# +#database: +# name: psycopg2 +# args: +# user: synapse +# password: secretpassword +# database: synapse +# host: localhost +# cp_min: 5 +# cp_max: 10 +# +# For more information on using Synapse with Postgres, see `docs/postgres.md`. +# +database: + name: sqlite3 + args: + database: %(database_path)s + +# Number of events to cache in memory. +# +#event_cache_size: 10K +""" + class DatabaseConnectionConfig: """Contains the connection config for a particular database. @@ -36,10 +83,12 @@ class DatabaseConnectionConfig: """ def __init__(self, name: str, db_config: dict): - if db_config["name"] not in ("sqlite3", "psycopg2"): - raise ConfigError("Unsupported database type %r" % (db_config["name"],)) + db_engine = db_config.get("name", "sqlite3") - if db_config["name"] == "sqlite3": + if db_engine not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_engine,)) + + if db_engine == "sqlite3": db_config.setdefault("args", {}).update( {"cp_min": 1, "cp_max": 1, "check_same_thread": False} ) @@ -97,34 +146,10 @@ class DatabaseConfig(Config): self.set_databasepath(config.get("database_path")) - def generate_config_section(self, data_dir_path, database_conf, **kwargs): - if not database_conf: - database_path = os.path.join(data_dir_path, "homeserver.db") - database_conf = ( - """# The database engine name - name: "sqlite3" - # Arguments to pass to the engine - args: - # Path to the database - database: "%(database_path)s" - """ - % locals() - ) - else: - database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip() - - return ( - """\ - ## Database ## - - database: - %(database_conf)s - # Number of events to cache in memory. - # - #event_cache_size: 10K - """ - % locals() - ) + def generate_config_section(self, data_dir_path, **kwargs): + return DEFAULT_CONFIG % { + "database_path": os.path.join(data_dir_path, "homeserver.db") + } def read_arguments(self, args): self.set_databasepath(args.database_path) diff --git a/tests/config/test_database.py b/tests/config/test_database.py index 151d3006ac..f675bde68e 100644 --- a/tests/config/test_database.py +++ b/tests/config/test_database.py @@ -21,9 +21,9 @@ from tests import unittest class DatabaseConfigTestCase(unittest.TestCase): - def test_database_configured_correctly_no_database_conf_param(self): + def test_database_configured_correctly(self): conf = yaml.safe_load( - DatabaseConfig().generate_config_section("/data_dir_path", None) + DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path") ) expected_database_conf = { @@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase): } self.assertEqual(conf["database"], expected_database_conf) - - def test_database_configured_correctly_database_conf_param(self): - - database_conf = { - "name": "my super fast datastore", - "args": { - "user": "matrix", - "password": "synapse_database_password", - "host": "synapse_database_host", - "database": "matrix", - }, - } - - conf = yaml.safe_load( - DatabaseConfig().generate_config_section("/data_dir_path", database_conf) - ) - - self.assertEqual(conf["database"], database_conf) -- cgit 1.5.1 From a564b92d37625855940fe599c730a9958c33f973 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 23 Mar 2020 13:59:11 +0000 Subject: Convert `*StreamRow` classes to inner classes (#7116) This just helps keep the rows closer to their streams, so that it's easier to see what the format of each stream is. --- changelog.d/7116.misc | 1 + synapse/app/generic_worker.py | 2 +- synapse/federation/send_queue.py | 2 +- synapse/replication/tcp/streams/_base.py | 181 +++++++++++++------------ synapse/replication/tcp/streams/federation.py | 16 +-- tests/replication/tcp/streams/test_receipts.py | 4 +- 6 files changed, 106 insertions(+), 100 deletions(-) create mode 100644 changelog.d/7116.misc (limited to 'tests') diff --git a/changelog.d/7116.misc b/changelog.d/7116.misc new file mode 100644 index 0000000000..89d90bd49e --- /dev/null +++ b/changelog.d/7116.misc @@ -0,0 +1 @@ +Convert `*StreamRow` classes to inner classes. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 136babe6ce..c8fd8909a4 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -804,7 +804,7 @@ class FederationSenderHandler(object): async def _on_new_receipts(self, rows): """ Args: - rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]): + rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]): new receipts to be processed """ for receipt in rows: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 876fb0e245..e1700ca8aa 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -477,7 +477,7 @@ def process_rows_for_federation(transaction_queue, rows): Args: transaction_queue (FederationSender) - rows (list(synapse.replication.tcp.streams.FederationStreamRow)) + rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow)) """ # The federation stream contains a bunch of different types of diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index abf5c6c6a8..32d9514883 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -28,94 +28,6 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 500000 -BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), -) -PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), -) -TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) -) -ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), -) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str -PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool -) - - -@attr.s -class CachesStreamRow: - """Stream to inform workers they should invalidate their cache. - - Attributes: - cache_func: Name of the cached function. - keys: The entry in the cache to invalidate. If None then will - invalidate all. - invalidation_ts: Timestamp of when the invalidation took place. - """ - - cache_func = attr.ib(type=str) - keys = attr.ib(type=Optional[List[Any]]) - invalidation_ts = attr.ib(type=int) - - -PublicRoomsStreamRow = namedtuple( - "PublicRoomsStreamRow", - ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional - ), -) - - -@attr.s -class DeviceListsStreamRow: - entity = attr.ib(type=str) - - -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str -TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict -) -AccountDataStreamRow = namedtuple( - "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str -) -GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict -) -UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str - class Stream(object): """Base class for the streams. @@ -234,6 +146,18 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ + BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), + ) + NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -246,6 +170,19 @@ class BackfillStream(Stream): class PresenceStream(Stream): + PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), + ) + NAME = "presence" ROW_TYPE = PresenceStreamRow @@ -260,6 +197,10 @@ class PresenceStream(Stream): class TypingStream(Stream): + TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) + ) + NAME = "typing" ROW_TYPE = TypingStreamRow @@ -273,6 +214,17 @@ class TypingStream(Stream): class ReceiptsStream(Stream): + ReceiptsStreamRow = namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), + ) + NAME = "receipts" ROW_TYPE = ReceiptsStreamRow @@ -289,6 +241,8 @@ class PushRulesStream(Stream): """A user has changed their push rules """ + PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -309,6 +263,11 @@ class PushersStream(Stream): """A user has added/changed/removed a pusher """ + PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool + ) + NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -326,6 +285,21 @@ class CachesStream(Stream): the cache on the workers """ + @attr.s + class CachesStreamRow: + """Stream to inform workers they should invalidate their cache. + + Attributes: + cache_func: Name of the cached function. + keys: The entry in the cache to invalidate. If None then will + invalidate all. + invalidation_ts: Timestamp of when the invalidation took place. + """ + + cache_func = attr.ib(type=str) + keys = attr.ib(type=Optional[List[Any]]) + invalidation_ts = attr.ib(type=int) + NAME = "caches" ROW_TYPE = CachesStreamRow @@ -342,6 +316,16 @@ class PublicRoomsStream(Stream): """The public rooms list changed """ + PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), + ) + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow @@ -359,6 +343,10 @@ class DeviceListsStream(Stream): told about a device update. """ + @attr.s + class DeviceListsStreamRow: + entity = attr.ib(type=str) + NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow @@ -375,6 +363,8 @@ class ToDeviceStream(Stream): """New to_device messages for a client """ + ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -391,6 +381,10 @@ class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict + ) + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -407,6 +401,10 @@ class AccountDataStream(Stream): """Global or per room account data was changed """ + AccountDataStreamRow = namedtuple( + "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str + ) + NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -432,6 +430,11 @@ class AccountDataStream(Stream): class GroupServerStream(Stream): + GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict + ) + NAME = "groups" ROW_TYPE = GroupsStreamRow @@ -448,6 +451,8 @@ class UserSignatureStream(Stream): """A user has signed their own device with their user-signing key """ + UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 615f3dc9ac..f5f9336430 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -17,20 +17,20 @@ from collections import namedtuple from ._base import Stream -FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), -) - class FederationStream(Stream): """Data to be sent over federation. Only available when master has federation sending disabled. """ + FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), + ) + NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index d5a99f6caa..fa2493cad6 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.tcp.streams._base import ReceiptsStreamRow +from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -38,7 +38,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): rdata_rows = self.test_handler.received_rdata_rows self.assertEqual(1, len(rdata_rows)) self.assertEqual(rdata_rows[0][0], "receipts") - row = rdata_rows[0][2] # type: ReceiptsStreamRow + row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow self.assertEqual(ROOM_ID, row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) -- cgit 1.5.1 From 39230d217104f3cd7aba9065dc478f935ce1e614 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 24 Mar 2020 14:45:33 +0000 Subject: Clean up some LoggingContext stuff (#7120) * Pull Sentinel out of LoggingContext ... and drop a few unnecessary references to it * Factor out LoggingContext.current_context move `current_context` and `set_context` out to top-level functions. Mostly this means that I can more easily trace what's actually referring to LoggingContext, but I think it's generally neater. * move copy-to-parent into `stop` this really just makes `start` and `stop` more symetric. It also means that it behaves correctly if you manually `set_log_context` rather than using the context manager. * Replace `LoggingContext.alive` with `finished` Turn `alive` into `finished` and make it a bit better defined. --- changelog.d/7120.misc | 1 + docs/log_contexts.md | 5 +- synapse/crypto/keyring.py | 4 +- synapse/federation/federation_base.py | 4 +- synapse/handlers/sync.py | 4 +- synapse/http/request_metrics.py | 6 +- synapse/logging/_structured.py | 4 +- synapse/logging/context.py | 234 +++++++++++---------- synapse/logging/scopecontextmanager.py | 13 +- synapse/storage/data_stores/main/events_worker.py | 4 +- synapse/storage/database.py | 11 +- synapse/util/metrics.py | 4 +- synapse/util/patch_inline_callbacks.py | 36 ++-- tests/crypto/test_keyring.py | 7 +- .../federation/test_matrix_federation_agent.py | 6 +- tests/http/federation/test_srv_resolver.py | 6 +- tests/http/test_fedclient.py | 6 +- tests/rest/client/test_transactions.py | 16 +- tests/unittest.py | 12 +- tests/util/caches/test_descriptors.py | 22 +- tests/util/test_async_utils.py | 15 +- tests/util/test_linearizer.py | 6 +- tests/util/test_logcontext.py | 22 +- tests/utils.py | 6 +- 24 files changed, 232 insertions(+), 222 deletions(-) create mode 100644 changelog.d/7120.misc (limited to 'tests') diff --git a/changelog.d/7120.misc b/changelog.d/7120.misc new file mode 100644 index 0000000000..731f4dcb52 --- /dev/null +++ b/changelog.d/7120.misc @@ -0,0 +1 @@ +Clean up some LoggingContext code. diff --git a/docs/log_contexts.md b/docs/log_contexts.md index 5331e8c88b..fe30ca2791 100644 --- a/docs/log_contexts.md +++ b/docs/log_contexts.md @@ -29,14 +29,13 @@ from synapse.logging import context # omitted from future snippets def handle_request(request_id): request_context = context.LoggingContext() - calling_context = context.LoggingContext.current_context() - context.LoggingContext.set_current_context(request_context) + calling_context = context.set_current_context(request_context) try: request_context.request = request_id do_request_handling() logger.debug("finished") finally: - context.LoggingContext.set_current_context(calling_context) + context.set_current_context(calling_context) def do_request_handling(): logger.debug("phew") # this will be logged against request_id diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 983f0ead8c..a9f4025bfe 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -43,8 +43,8 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.context import ( - LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, preserve_fn, run_in_background, @@ -236,7 +236,7 @@ class Keyring(object): """ try: - ctx = LoggingContext.current_context() + ctx = current_context() # map from server name to a set of outstanding request ids server_to_request_ids = {} diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index b0b0eba41e..4b115aac04 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -32,8 +32,8 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( - LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.types import JsonDict, get_domain_from_id @@ -78,7 +78,7 @@ class FederationBase(object): """ deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) - ctx = LoggingContext.current_context() + ctx = current_context() def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 669dbc8a48..5746fdea14 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -26,7 +26,7 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.events import EventBase -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter @@ -301,7 +301,7 @@ class SyncHandler(object): else: sync_type = "incremental_sync" - context = LoggingContext.current_context() + context = current_context() if context: context.tag = sync_type diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 58f9cc61c8..b58ae3d9db 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -19,7 +19,7 @@ import threading from prometheus_client.core import Counter, Histogram -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context from synapse.metrics import LaterGauge logger = logging.getLogger(__name__) @@ -148,7 +148,7 @@ LaterGauge( class RequestMetrics(object): def start(self, time_sec, name, method): self.start = time_sec - self.start_context = LoggingContext.current_context() + self.start_context = current_context() self.name = name self.method = method @@ -163,7 +163,7 @@ class RequestMetrics(object): with _in_flight_requests_lock: _in_flight_requests.discard(self) - context = LoggingContext.current_context() + context = current_context() tag = "" if context: diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index ffa7b20ca8..7372450b45 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -42,7 +42,7 @@ from synapse.logging._terse_json import ( TerseJSONToConsoleLogObserver, TerseJSONToTCPLogObserver, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context def stdlib_log_level_to_twisted(level: str) -> LogLevel: @@ -86,7 +86,7 @@ class LogContextObserver(object): ].startswith("Timing out client"): return - context = LoggingContext.current_context() + context = current_context() # Copy the context information to the log event. if context is not None: diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 860b99a4c6..a8eafb1c7c 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -175,7 +175,54 @@ class ContextResourceUsage(object): return res -LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"] +LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] + + +class _Sentinel(object): + """Sentinel to represent the root context""" + + __slots__ = ["previous_context", "finished", "request", "scope", "tag"] + + def __init__(self) -> None: + # Minimal set for compatibility with LoggingContext + self.previous_context = None + self.finished = False + self.request = None + self.scope = None + self.tag = None + + def __str__(self): + return "sentinel" + + def copy_to(self, record): + pass + + def copy_to_twisted_log_entry(self, record): + record["request"] = None + record["scope"] = None + + def start(self): + pass + + def stop(self): + pass + + def add_database_transaction(self, duration_sec): + pass + + def add_database_scheduled(self, sched_sec): + pass + + def record_event_fetch(self, event_count): + pass + + def __nonzero__(self): + return False + + __bool__ = __nonzero__ # python3 + + +SENTINEL_CONTEXT = _Sentinel() class LoggingContext(object): @@ -199,76 +246,33 @@ class LoggingContext(object): "_resource_usage", "usage_start", "main_thread", - "alive", + "finished", "request", "tag", "scope", ] - thread_local = threading.local() - - class Sentinel(object): - """Sentinel to represent the root context""" - - __slots__ = ["previous_context", "alive", "request", "scope", "tag"] - - def __init__(self) -> None: - # Minimal set for compatibility with LoggingContext - self.previous_context = None - self.alive = None - self.request = None - self.scope = None - self.tag = None - - def __str__(self): - return "sentinel" - - def copy_to(self, record): - pass - - def copy_to_twisted_log_entry(self, record): - record["request"] = None - record["scope"] = None - - def start(self): - pass - - def stop(self): - pass - - def add_database_transaction(self, duration_sec): - pass - - def add_database_scheduled(self, sched_sec): - pass - - def record_event_fetch(self, event_count): - pass - - def __nonzero__(self): - return False - - __bool__ = __nonzero__ # python3 - - sentinel = Sentinel() - def __init__(self, name=None, parent_context=None, request=None) -> None: - self.previous_context = LoggingContext.current_context() + self.previous_context = current_context() self.name = name # track the resources used by this context so far self._resource_usage = ContextResourceUsage() - # If alive has the thread resource usage when the logcontext last - # became active. + # The thread resource usage when the logcontext became active. None + # if the context is not currently active. self.usage_start = None self.main_thread = get_thread_id() self.request = None self.tag = "" - self.alive = True self.scope = None # type: Optional[_LogContextScope] + # keep track of whether we have hit the __exit__ block for this context + # (suggesting that the the thing that created the context thinks it should + # be finished, and that re-activating it would suggest an error). + self.finished = False + self.parent_context = parent_context if self.parent_context is not None: @@ -283,44 +287,15 @@ class LoggingContext(object): return str(self.request) return "%s@%x" % (self.name, id(self)) - @classmethod - def current_context(cls) -> LoggingContextOrSentinel: - """Get the current logging context from thread local storage - - Returns: - LoggingContext: the current logging context - """ - return getattr(cls.thread_local, "current_context", cls.sentinel) - - @classmethod - def set_current_context( - cls, context: LoggingContextOrSentinel - ) -> LoggingContextOrSentinel: - """Set the current logging context in thread local storage - Args: - context(LoggingContext): The context to activate. - Returns: - The context that was previously active - """ - current = cls.current_context() - - if current is not context: - current.stop() - cls.thread_local.current_context = context - context.start() - return current - def __enter__(self) -> "LoggingContext": """Enters this logging context into thread local storage""" - old_context = self.set_current_context(self) + old_context = set_current_context(self) if self.previous_context != old_context: logger.warning( "Expected previous context %r, found %r", self.previous_context, old_context, ) - self.alive = True - return self def __exit__(self, type, value, traceback) -> None: @@ -329,24 +304,19 @@ class LoggingContext(object): Returns: None to avoid suppressing any exceptions that were thrown. """ - current = self.set_current_context(self.previous_context) + current = set_current_context(self.previous_context) if current is not self: - if current is self.sentinel: + if current is SENTINEL_CONTEXT: logger.warning("Expected logging context %s was lost", self) else: logger.warning( "Expected logging context %s but found %s", self, current ) - self.alive = False - - # if we have a parent, pass our CPU usage stats on - if self.parent_context is not None and hasattr( - self.parent_context, "_resource_usage" - ): - self.parent_context._resource_usage += self._resource_usage - # reset them in case we get entered again - self._resource_usage.reset() + # the fact that we are here suggests that the caller thinks that everything + # is done and dusted for this logcontext, and further activity will not get + # recorded against the correct metrics. + self.finished = True def copy_to(self, record) -> None: """Copy logging fields from this context to a log record or @@ -371,9 +341,14 @@ class LoggingContext(object): logger.warning("Started logcontext %s on different thread", self) return + if self.finished: + logger.warning("Re-starting finished log context %s", self) + # If we haven't already started record the thread resource usage so # far - if not self.usage_start: + if self.usage_start: + logger.warning("Re-starting already-active log context %s", self) + else: self.usage_start = get_thread_resource_usage() def stop(self) -> None: @@ -396,6 +371,15 @@ class LoggingContext(object): self.usage_start = None + # if we have a parent, pass our CPU usage stats on + if self.parent_context is not None and hasattr( + self.parent_context, "_resource_usage" + ): + self.parent_context._resource_usage += self._resource_usage + + # reset them in case we get entered again + self._resource_usage.reset() + def get_resource_usage(self) -> ContextResourceUsage: """Get resources used by this logcontext so far. @@ -409,7 +393,7 @@ class LoggingContext(object): # If we are on the correct thread and we're currently running then we # can include resource usage so far. is_main_thread = get_thread_id() == self.main_thread - if self.alive and self.usage_start and is_main_thread: + if self.usage_start and is_main_thread: utime_delta, stime_delta = self._get_cputime() res.ru_utime += utime_delta res.ru_stime += stime_delta @@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter): Returns: True to include the record in the log output. """ - context = LoggingContext.current_context() + context = current_context() for key, value in self.defaults.items(): setattr(record, key, value) @@ -512,27 +496,24 @@ class PreserveLoggingContext(object): __slots__ = ["current_context", "new_context", "has_parent"] - def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None: - if new_context is None: - self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel - else: - self.new_context = new_context + def __init__( + self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT + ) -> None: + self.new_context = new_context def __enter__(self) -> None: """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context(self.new_context) + self.current_context = set_current_context(self.new_context) if self.current_context: self.has_parent = self.current_context.previous_context is not None - if not self.current_context.alive: - logger.debug("Entering dead context: %s", self.current_context) def __exit__(self, type, value, traceback) -> None: """Restores the current logging context""" - context = LoggingContext.set_current_context(self.current_context) + context = set_current_context(self.current_context) if context != self.new_context: - if context is LoggingContext.sentinel: + if not context: logger.warning("Expected logging context %s was lost", self.new_context) else: logger.warning( @@ -541,9 +522,30 @@ class PreserveLoggingContext(object): context, ) - if self.current_context is not LoggingContext.sentinel: - if not self.current_context.alive: - logger.debug("Restoring dead context: %s", self.current_context) + +_thread_local = threading.local() +_thread_local.current_context = SENTINEL_CONTEXT + + +def current_context() -> LoggingContextOrSentinel: + """Get the current logging context from thread local storage""" + return getattr(_thread_local, "current_context", SENTINEL_CONTEXT) + + +def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: + """Set the current logging context in thread local storage + Args: + context(LoggingContext): The context to activate. + Returns: + The context that was previously active + """ + current = current_context() + + if current is not context: + current.stop() + _thread_local.current_context = context + context.start() + return current def nested_logging_context( @@ -572,7 +574,7 @@ def nested_logging_context( if parent_context is not None: context = parent_context # type: LoggingContextOrSentinel else: - context = LoggingContext.current_context() + context = current_context() return LoggingContext( parent_context=context, request=str(context.request) + "-" + suffix ) @@ -604,7 +606,7 @@ def run_in_background(f, *args, **kwargs): CRITICAL error about an unhandled error will be logged without much indication about where it came from. """ - current = LoggingContext.current_context() + current = current_context() try: res = f(*args, **kwargs) except: # noqa: E722 @@ -625,7 +627,7 @@ def run_in_background(f, *args, **kwargs): # The function may have reset the context before returning, so # we need to restore it now. - ctx = LoggingContext.set_current_context(current) + ctx = set_current_context(current) # The original context will be restored when the deferred # completes, but there is nothing waiting for it, so it will @@ -674,7 +676,7 @@ def make_deferred_yieldable(deferred): # ok, we can't be sure that a yield won't block, so let's reset the # logcontext, and add a callback to the deferred to restore it. - prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + prev_context = set_current_context(SENTINEL_CONTEXT) deferred.addBoth(_set_context_cb, prev_context) return deferred @@ -684,7 +686,7 @@ ResultT = TypeVar("ResultT") def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: """A callback function which just sets the logging context""" - LoggingContext.set_current_context(context) + set_current_context(context) return result @@ -752,7 +754,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): Deferred: A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ - logcontext = LoggingContext.current_context() + logcontext = current_context() def g(): with LoggingContext(parent_context=logcontext): diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index 4eed4f2338..dc3ab00cbb 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager import twisted -from synapse.logging.context import LoggingContext, nested_logging_context +from synapse.logging.context import current_context, nested_logging_context logger = logging.getLogger(__name__) @@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager): (Scope) : the Scope that is active, or None if not available. """ - ctx = LoggingContext.current_context() - if ctx is LoggingContext.sentinel: - return None - else: - return ctx.scope + ctx = current_context() + return ctx.scope def activate(self, span, finish_on_close): """ @@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager): """ enter_logcontext = False - ctx = LoggingContext.current_context() + ctx = current_context() - if ctx is LoggingContext.sentinel: + if not ctx: # We don't want this scope to affect. logger.error("Tried to activate scope outside of loggingcontext") return Scope(None, span) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ca237c6f12..3013f49d32 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -35,7 +35,7 @@ from synapse.api.room_versions import ( ) from synapse.events import make_event_from_dict from synapse.events.utils import prune_event -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database @@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore): missing_events_ids = [e for e in event_ids if e not in event_entry_map] if missing_events_ids: - log_ctx = LoggingContext.current_context() + log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) # Note that _get_events_from_db is also responsible for turning db rows diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e61595336c..715c0346dd 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import ( LoggingContext, LoggingContextOrSentinel, + current_context, make_deferred_yieldable, ) from synapse.metrics.background_process_metrics import run_as_background_process @@ -483,7 +484,7 @@ class Database(object): end = monotonic_time() duration = end - start - LoggingContext.current_context().add_database_transaction(duration) + current_context().add_database_transaction(duration) transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) @@ -510,7 +511,7 @@ class Database(object): after_callbacks = [] # type: List[_CallbackListEntry] exception_callbacks = [] # type: List[_CallbackListEntry] - if LoggingContext.current_context() == LoggingContext.sentinel: + if not current_context(): logger.warning("Starting db txn '%s' from sentinel context", desc) try: @@ -547,10 +548,8 @@ class Database(object): Returns: Deferred: The result of func """ - parent_context = ( - LoggingContext.current_context() - ) # type: Optional[LoggingContextOrSentinel] - if parent_context == LoggingContext.sentinel: + parent_context = current_context() # type: Optional[LoggingContextOrSentinel] + if not parent_context: logger.warning( "Starting db connection from sentinel context: metrics will be lost" ) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 7b18455469..ec61e14423 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -21,7 +21,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.metrics import InFlightGauge logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ class Measure(object): raise RuntimeError("Measure() objects cannot be re-used") self.start = self.clock.time() - parent_context = LoggingContext.current_context() + parent_context = current_context() self._logging_context = LoggingContext( "Measure[%s]" % (self.name,), parent_context ) diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 3925927f9f..fdff195771 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -32,7 +32,7 @@ def do_patch(): Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit """ - from synapse.logging.context import LoggingContext + from synapse.logging.context import current_context global _already_patched @@ -43,35 +43,35 @@ def do_patch(): def new_inline_callbacks(f): @functools.wraps(f) def wrapped(*args, **kwargs): - start_context = LoggingContext.current_context() + start_context = current_context() changes = [] # type: List[str] orig = orig_inline_callbacks(_check_yield_points(f, changes)) try: res = orig(*args, **kwargs) except Exception: - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "%s changed context from %s to %s on exception" % ( f, start_context, - LoggingContext.current_context(), + current_context(), ) print(err, file=sys.stderr) raise Exception(err) raise if not isinstance(res, Deferred) or res.called: - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "Completed %s changed context from %s to %s" % ( f, start_context, - LoggingContext.current_context(), + current_context(), ) # print the error to stderr because otherwise all we # see in travis-ci is the 500 error @@ -79,23 +79,23 @@ def do_patch(): raise Exception(err) return res - if LoggingContext.current_context() != LoggingContext.sentinel: + if current_context(): err = ( "%s returned incomplete deferred in non-sentinel context " "%s (start was %s)" - ) % (f, LoggingContext.current_context(), start_context) + ) % (f, current_context(), start_context) print(err, file=sys.stderr) raise Exception(err) def check_ctx(r): - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "%s completion of %s changed context from %s to %s" % ( "Failure" if isinstance(r, Failure) else "Success", f, start_context, - LoggingContext.current_context(), + current_context(), ) print(err, file=sys.stderr) raise Exception(err) @@ -127,7 +127,7 @@ def _check_yield_points(f: Callable, changes: List[str]): function """ - from synapse.logging.context import LoggingContext + from synapse.logging.context import current_context @functools.wraps(f) def check_yield_points_inner(*args, **kwargs): @@ -136,7 +136,7 @@ def _check_yield_points(f: Callable, changes: List[str]): last_yield_line_no = gen.gi_frame.f_lineno result = None # type: Any while True: - expected_context = LoggingContext.current_context() + expected_context = current_context() try: isFailure = isinstance(result, Failure) @@ -145,7 +145,7 @@ def _check_yield_points(f: Callable, changes: List[str]): else: d = gen.send(result) except (StopIteration, defer._DefGen_Return) as e: - if LoggingContext.current_context() != expected_context: + if current_context() != expected_context: # This happens when the context is lost sometime *after* the # final yield and returning. E.g. we forgot to yield on a # function that returns a deferred. @@ -159,7 +159,7 @@ def _check_yield_points(f: Callable, changes: List[str]): % ( f.__qualname__, expected_context, - LoggingContext.current_context(), + current_context(), f.__code__.co_filename, last_yield_line_no, ) @@ -173,13 +173,13 @@ def _check_yield_points(f: Callable, changes: List[str]): # This happens if we yield on a deferred that doesn't follow # the log context rules without wrapping in a `make_deferred_yieldable`. # We raise here as this should never happen. - if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): err = ( "%s yielded with context %s rather than sentinel," " yielded on line %d in %s" % ( frame.f_code.co_name, - LoggingContext.current_context(), + current_context(), frame.f_lineno, frame.f_code.co_filename, ) @@ -191,7 +191,7 @@ def _check_yield_points(f: Callable, changes: List[str]): except Exception as e: result = Failure(e) - if LoggingContext.current_context() != expected_context: + if current_context() != expected_context: # This happens because the context is lost sometime *after* the # previous yield and *after* the current yield. E.g. the @@ -206,7 +206,7 @@ def _check_yield_points(f: Callable, changes: List[str]): % ( frame.f_code.co_name, expected_context, - LoggingContext.current_context(), + current_context(), last_yield_line_no, frame.f_lineno, frame.f_code.co_filename, diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 34d5895f18..70c8e72303 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -34,6 +34,7 @@ from synapse.crypto.keyring import ( from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.storage.keys import FetchKeyResult @@ -83,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) def check_context(self, _, expected): - self.assertEquals( - getattr(LoggingContext.current_context(), "request", None), expected - ) + self.assertEquals(getattr(current_context(), "request", None), expected) def test_verify_json_objects_for_server_awaits_previous_requests(self): key1 = signedjson.key.generate_signing_key(1) @@ -105,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): - self.assertEquals(LoggingContext.current_context().request, "11") + self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): yield persp_deferred return persp_resp diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index fdc1d918ff..562397cdda 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -38,7 +38,7 @@ from synapse.http.federation.well_known_resolver import ( WellKnownResolver, _cache_period_from_headers, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.util.caches.ttlcache import TTLCache from tests import unittest @@ -155,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - _check_logcontext(LoggingContext.sentinel) + _check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d @@ -1197,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase): def _check_logcontext(context): - current = LoggingContext.current_context() + current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index df034ab237..babc201643 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase): self.assertNoResult(resolve_d) # should have reset to the sentinel context - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) result = yield resolve_d # should have restored our context - self.assertIs(LoggingContext.current_context(), ctx) + self.assertIs(current_context(), ctx) return result diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 2b01f40a42..fff4f0cbf4 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import ( MatrixFederationHttpClient, MatrixFederationRequest, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests.server import FakeTransport from tests.unittest import HomeserverTestCase def check_logcontext(context): - current = LoggingContext.current_context() + current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) @@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - check_logcontext(LoggingContext.sentinel) + check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index a3d7e3c046..171632e195 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -2,7 +2,7 @@ from mock import Mock, call from twisted.internet import defer, reactor -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.util import Clock @@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def test(): with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) - self.assertIs(LoggingContext.current_context(), c1) + self.assertIs(current_context(), c1) self.assertEqual(res, "yay") # run the test twice in parallel d = defer.gatherResults([test(), test()]) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) yield d - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks def test_does_not_cache_exceptions(self): @@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_does_not_cache_failures(self): @@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_cleans_up(self): diff --git a/tests/unittest.py b/tests/unittest.py index 8816a4d152..439174dbfc 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -38,7 +38,11 @@ from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite -from synapse.logging.context import LoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + current_context, + set_current_context, +) from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter @@ -97,10 +101,10 @@ class TestCase(unittest.TestCase): def setUp(orig): # if we're not starting in the sentinel logcontext, then to be honest # all future bets are off. - if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): self.fail( "Test starting with non-sentinel logging context %s" - % (LoggingContext.current_context(),) + % (current_context(),) ) old_level = logging.getLogger().level @@ -122,7 +126,7 @@ class TestCase(unittest.TestCase): # force a GC to workaround problems with deferreds leaking logcontexts when # they are GCed (see the logcontext docs) gc.collect() - LoggingContext.set_current_context(LoggingContext.sentinel) + set_current_context(SENTINEL_CONTEXT) return ret diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 39e360fe24..4d2b9e0d64 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -22,8 +22,10 @@ from twisted.internet import defer, reactor from synapse.api.errors import SynapseError from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.util.caches import descriptors @@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase): with LoggingContext() as c1: c1.name = "c1" r = yield obj.fn(1) - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) return r def check_result(r): @@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d1.addCallback(check_result) # and another d2 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d2.addCallback(check_result) # let the lookup complete @@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase): try: d = obj.fn(1) self.assertEqual( - LoggingContext.current_context(), LoggingContext.sentinel + current_context(), SENTINEL_CONTEXT, ) yield d self.fail("No exception thrown") except SynapseError: pass - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) # the cache should now be empty self.assertEqual(len(obj.fn.cache.cache), 0) @@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) return d1 @@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList("fn", "args1", inlineCallbacks=True) def list_fn(self, args1, arg2): - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" # we want this to behave like an asynchronous function yield run_on_reactor() - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" return self.mock(args1, arg2) with LoggingContext() as c1: @@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index f60918069a..17fd86d02d 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -16,7 +16,12 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred from twisted.internet.task import Clock -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + PreserveLoggingContext, + current_context, +) from synapse.util.async_helpers import timeout_deferred from tests.unittest import TestCase @@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase): # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( - LoggingContext.current_context(), + current_context(), context_one, "errback %s run in unexpected logcontext %s" - % (deferred_name, LoggingContext.current_context()), + % (deferred_name, current_context()), ) return res @@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase): original_deferred.addErrback(errback, "orig") timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) self.assertNoResult(timing_out_d) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) timing_out_d.addErrback(errback, "timingout") self.clock.pump((1.0,)) @@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase): blocking_was_cancelled[0], "non-completing deferred was not cancelled" ) self.failureResultOf(timing_out_d, defer.TimeoutError) - self.assertIs(LoggingContext.current_context(), context_one) + self.assertIs(current_context(), context_one) diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 0ec8ef90ce..852ef23185 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -19,7 +19,7 @@ from six.moves import range from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.util import Clock from synapse.util.async_helpers import Linearizer @@ -54,11 +54,11 @@ class LinearizerTestCase(unittest.TestCase): def func(i, sleep=False): with LoggingContext("func(%s)" % i) as lc: with (yield linearizer.queue("")): - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) if sleep: yield Clock(reactor).sleep(0) - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) func(0, sleep=True) for i in range(1, 100): diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 281b32c4b8..95301c013c 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -2,8 +2,10 @@ import twisted.python.failure from twisted.internet import defer, reactor from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, nested_logging_context, run_in_background, @@ -15,7 +17,7 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(LoggingContext.current_context().request, value) + self.assertEquals(current_context().request, value) def test_with_context(self): with LoggingContext() as context_one: @@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") def _test_run_in_background(self, function): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() callback_completed = [False] @@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase): # make sure that the context was reset before it got thrown back # into the reactor try: - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) d2.callback(None) except BaseException: d2.errback(twisted.python.failure.Failure()) @@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase): async def testfunc(): self._check_test_key("one") d = Clock(reactor).sleep(0) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) await d self._check_test_key("one") @@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase): reactor.callLater(0, d.callback, None) return d - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable_with_chained_deferreds(self): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(_chained_deferred_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -189,14 +191,14 @@ class LoggingContextTestCase(unittest.TestCase): reactor.callLater(0, d.callback, None) await d - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 diff --git a/tests/utils.py b/tests/utils.py index 513f358f4f..968d109f77 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,7 +35,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -493,10 +493,10 @@ class MockClock(object): return self.time() * 1000 def call_later(self, delay, callback, *args, **kwargs): - current_context = LoggingContext.current_context() + ctx = current_context() def wrapped_callback(): - LoggingContext.thread_local.current_context = current_context + set_current_context(ctx) callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] -- cgit 1.5.1 From 28d9d6e8a9d6a6d5162de41cada1b6d6d4b0f941 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 24 Mar 2020 18:33:49 +0000 Subject: Remove spurious "name" parameter to `default_config` this is never set to anything other than "test", and is a source of unnecessary boilerplate. --- tests/app/test_frontend_proxy.py | 4 ++-- tests/app/test_openid_listener.py | 4 ++-- tests/federation/test_complexity.py | 4 ++-- tests/handlers/test_register.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 4 ++-- tests/rest/key/v2/test_remote_key_resource.py | 4 ++-- tests/server_notices/test_resource_limits_server_notices.py | 2 +- tests/test_terms_auth.py | 4 ++-- tests/unittest.py | 7 ++----- 9 files changed, 16 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index d3feafa1b7..be20a89682 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -27,8 +27,8 @@ class FrontendProxyTests(HomeserverTestCase): return hs - def default_config(self, name="test"): - c = super().default_config(name) + def default_config(self): + c = super().default_config() c["worker_app"] = "synapse.app.frontend_proxy" return c diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 89fcc3889a..7364f9f1ec 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -29,8 +29,8 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): ) return hs - def default_config(self, name="test"): - conf = super().default_config(name) + def default_config(self): + conf = super().default_config() # we're using FederationReaderServer, which uses a SlavedStore, so we # have to tell the FederationHandler not to try to access stuff that is only # in the primary store. diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 24fa8dbb45..94980733c4 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -33,8 +33,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): login.register_servlets, ] - def default_config(self, name="test"): - config = super().default_config(name=name) + def default_config(self): + config = super().default_config() config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e2915eb7b1..e7b638dbfe 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() # some of the tests rely on us having a user consent version hs_config["user_consent"] = { diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index d0c997e385..b6ed06e02d 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -36,8 +36,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] url = b"/_matrix/client/r0/register" - def default_config(self, name="test"): - config = super().default_config(name) + def default_config(self): + config = super().default_config() config["allow_guest_access"] = True return config diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 6776a56cad..99eb477149 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -143,8 +143,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): endpoint, to check that the two implementations are compatible. """ - def default_config(self, *args, **kwargs): - config = super().default_config(*args, **kwargs) + def default_config(self): + config = super().default_config() # replace the signing key with our own self.hs_signing_key = signedjson.key.generate_signing_key("kssk") diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index eb540e34f6..0d27b92a86 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -28,7 +28,7 @@ from tests import unittest class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() hs_config["server_notices"] = { "system_mxid_localpart": "server", "system_mxid_display_name": "test display name", diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 5ec5d2b358..81d796f3f3 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -28,8 +28,8 @@ from tests import unittest class TermsTestCase(unittest.HomeserverTestCase): servlets = [register_servlets] - def default_config(self, name="test"): - config = super().default_config(name) + def default_config(self): + config = super().default_config() config.update( { "public_baseurl": "https://example.org/", diff --git a/tests/unittest.py b/tests/unittest.py index 8816a4d152..23b59bea22 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -311,14 +311,11 @@ class HomeserverTestCase(TestCase): return resource - def default_config(self, name="test"): + def default_config(self): """ Get a default HomeServer config dict. - - Args: - name (str): The homeserver name/domain. """ - config = default_config(name) + config = default_config("test") # apply any additional config which was specified via the override_config # decorator. -- cgit 1.5.1 From 4cff617df1ba6f241fee6957cc44859f57edcc0e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Mar 2020 14:54:01 +0000 Subject: Move catchup of replication streams to worker. (#7024) This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date. --- changelog.d/7024.misc | 1 + docs/tcp_replication.md | 46 ++--- synapse/app/generic_worker.py | 3 + synapse/federation/sender/__init__.py | 9 + synapse/replication/http/__init__.py | 2 + synapse/replication/http/streams.py | 78 ++++++++ synapse/replication/slave/storage/_base.py | 14 +- synapse/replication/slave/storage/pushers.py | 3 + synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/commands.py | 34 +--- synapse/replication/tcp/protocol.py | 206 ++++++++-------------- synapse/replication/tcp/resource.py | 19 +- synapse/replication/tcp/streams/__init__.py | 8 +- synapse/replication/tcp/streams/_base.py | 160 +++++++++++------ synapse/replication/tcp/streams/events.py | 5 +- synapse/replication/tcp/streams/federation.py | 19 +- synapse/server.py | 5 + synapse/storage/data_stores/main/cache.py | 44 ++--- synapse/storage/data_stores/main/deviceinbox.py | 88 ++++----- synapse/storage/data_stores/main/events.py | 114 ------------ synapse/storage/data_stores/main/events_worker.py | 114 ++++++++++++ synapse/storage/data_stores/main/room.py | 40 ++--- tests/replication/tcp/streams/_base.py | 55 ++++-- tests/replication/tcp/streams/test_receipts.py | 52 +++++- 24 files changed, 635 insertions(+), 487 deletions(-) create mode 100644 changelog.d/7024.misc create mode 100644 synapse/replication/http/streams.py (limited to 'tests') diff --git a/changelog.d/7024.misc b/changelog.d/7024.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7024.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index e3a4634b14..d4f7d9ec18 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and '<' worker to master flows): > SERVER example.com - < REPLICATE events 53 + < REPLICATE + > POSITION events 53 > RDATA events 54 ["$foo1:bar.com", ...] > RDATA events 55 ["$foo4:bar.com", ...] -The example shows the server accepting a new connection and sending its -identity with the `SERVER` command, followed by the client asking to -subscribe to the `events` stream from the token `53`. The server then -periodically sends `RDATA` commands which have the format -`RDATA `, where the format of `` is -defined by the individual streams. +The example shows the server accepting a new connection and sending its identity +with the `SERVER` command, followed by the client server to respond with the +position of all streams. The server then periodically sends `RDATA` commands +which have the format `RDATA `, where the format of +`` is defined by the individual streams. Error reporting happens by either the client or server sending an ERROR command, and usually the connection will be closed. @@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually connect to the server using a tool like netcat. A few things should be noted when manually using the protocol: -- When subscribing to a stream using `REPLICATE`, the special token - `NOW` can be used to get all future updates. The special stream name - `ALL` can be used with `NOW` to subscribe to all available streams. - The federation stream is only available if federation sending has been disabled on the main process. - The server will only time connections out that have sent a `PING` @@ -91,9 +88,7 @@ The client: - Sends a `NAME` command, allowing the server to associate a human friendly name with the connection. This is optional. - Sends a `PING` as above -- For each stream the client wishes to subscribe to it sends a - `REPLICATE` with the `stream_name` and token it wants to subscribe - from. +- Sends a `REPLICATE` to get the current position of all streams. - On receipt of a `SERVER` command, checks that the server name matches the expected server name. @@ -140,9 +135,7 @@ the wire: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -181,9 +174,9 @@ client (C): #### POSITION (S) - The position of the stream has been updated. Sent to the client - after all missing updates for a stream have been sent to the client - and they're now up to date. + On receipt of a POSITION command clients should check if they have missed any + updates, and if so then fetch them out of band. Sent in response to a + REPLICATE command (but can happen at any time). #### ERROR (S, C) @@ -199,20 +192,7 @@ client (C): #### REPLICATE (C) -Asks the server to replicate a given stream. The syntax is: - -``` - REPLICATE -``` - -Where `` may be either: - * a numeric stream_id to stream updates since (exclusive) - * `NOW` to stream all subsequent updates. - -The `` is the name of a replication stream to subscribe -to (see [here](../synapse/replication/tcp/streams/_base.py) for a list -of streams). It can also be `ALL` to subscribe to all known streams, -in which case the `` must be set to `NOW`. +Asks the server for the current position of all streams. #### USER_SYNC (C) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index bd1733573b..fba7ad9551 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -401,6 +401,9 @@ class GenericWorkerTyping(object): self._room_serials[row.room_id] = token self._room_typing[row.room_id] = row.user_ids + def get_current_token(self) -> int: + return self._latest_room_serial + class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 233cb33daf..a477578e44 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -499,4 +499,13 @@ class FederationSender(object): self._get_per_destination_queue(destination).attempt_new_transaction() def get_current_token(self) -> int: + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. return 0 + + async def get_replication_rows( + self, from_token, to_token, limit, federation_ack=None + ): + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. + return [] diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 28dbc6fcba..4613b2538c 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -21,6 +21,7 @@ from synapse.replication.http import ( membership, register, send_event, + streams, ) REPLICATION_PREFIX = "/_synapse/replication" @@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource): login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) + streams.register_servlets(hs, self) diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py new file mode 100644 index 0000000000..ffd4c61993 --- /dev/null +++ b/synapse/replication/http/streams.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import parse_integer +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationGetStreamUpdates(ReplicationEndpoint): + """Fetches stream updates from a server. Used for streams not persisted to + the database, e.g. typing notifications. + + The API looks like: + + GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100 + + 200 OK + + { + updates: [ ... ], + upto_token: 10, + limited: False, + } + + """ + + NAME = "get_repl_stream_updates" + PATH_ARGS = ("stream_name",) + METHOD = "GET" + + def __init__(self, hs): + super().__init__(hs) + + # We pull the streams from the replication steamer (if we try and make + # them ourselves we end up in an import loop). + self.streams = hs.get_replication_streamer().get_streams() + + @staticmethod + def _serialize_payload(stream_name, from_token, upto_token, limit): + return {"from_token": from_token, "upto_token": upto_token, "limit": limit} + + async def _handle_request(self, request, stream_name): + stream = self.streams.get(stream_name) + if stream is None: + raise SynapseError(400, "Unknown stream") + + from_token = parse_integer(request, "from_token", required=True) + upto_token = parse_integer(request, "upto_token", required=True) + limit = parse_integer(request, "limit", required=True) + + updates, upto_token, limited = await stream.get_updates_since( + from_token, upto_token, limit + ) + + return ( + 200, + {"updates": updates, "upto_token": upto_token, "limited": limited}, + ) + + +def register_servlets(hs, http_server): + ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f45cbd37a0..751c799d94 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,8 +18,10 @@ from typing import Dict, Optional import six -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.data_stores.main.cache import ( + CURRENT_STATE_CACHE_NAME, + CacheInvalidationWorkerStore, +) from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine @@ -35,7 +37,7 @@ def __func__(inp): return inp.__func__ -class BaseSlavedStore(SQLBaseStore): +class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): @@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore): pos["caches"] = self._cache_id_gen.get_current_token() return pos + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 + def process_replication_rows(self, stream_name, token, rows): if stream_name == "caches": if self._cache_id_gen: diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index f22c2d44a3..bce8a3d115 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): result["pushers"] = self._pushers_id_gen.get_current_token() return result + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + def process_replication_rows(self, stream_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 02ab5b66ea..7e7ad0f798 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name + self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) @@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.client_name, self.server_name, self._clock, self.handler + self.hs, self.client_name, self.server_name, self._clock, self.handler, ) def clientConnectionLost(self, connector, reason): diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 451671412d..5a6b734094 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -136,8 +136,8 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. - Sent to the client after all missing updates for a stream have been sent - to the client and they're now up to date. + On receipt of a POSITION command clients should check if they have missed + any updates, and if so then fetch them out of band. """ NAME = "POSITION" @@ -179,42 +179,24 @@ class NameCommand(Command): class ReplicateCommand(Command): - """Sent by the client to subscribe to the stream. + """Sent by the client to subscribe to streams. Format:: - REPLICATE - - Where may be either: - * a numeric stream_id to stream updates from - * "NOW" to stream all subsequent updates. - - The can be "ALL" to subscribe to all known streams, in which - case the must be set to "NOW", i.e.:: - - REPLICATE ALL NOW + REPLICATE """ NAME = "REPLICATE" - def __init__(self, stream_name, token): - self.stream_name = stream_name - self.token = token + def __init__(self): + pass @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - if token in ("NOW", "now"): - token = "NOW" - else: - token = int(token) - return cls(stream_name, token) + return cls() def to_line(self): - return " ".join((self.stream_name, str(self.token))) - - def get_logcontext_id(self): - return "REPLICATE-" + self.stream_name + return "" class UserSyncCommand(Command): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index bc1482a9bb..f81d2e2442 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire:: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -53,17 +51,15 @@ import fcntl import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set, Tuple +from typing import Any, DefaultDict, Dict, List, Set -from six import iteritems, iterkeys +from six import iteritems from prometheus_client import Counter -from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( @@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string +MYPY = False +if MYPY: + from synapse.server import HomeServer + + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.streamer = streamer - # The streams the client has subscribed to and is up to date with - self.replication_streams = set() # type: Set[str] - - # The streams the client is currently subscribing to. - self.connecting_streams = set() # type: Set[str] - - # Map from stream name to list of updates to send once we've finished - # subscribing the client to the stream. - self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] - def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) @@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): ) async def on_REPLICATE(self, cmd): - stream_name = cmd.stream_name - token = cmd.token - - if stream_name == "ALL": - # Subscribe to all streams we're publishing to. - deferreds = [ - run_in_background(self.subscribe_to_stream, stream, token) - for stream in iterkeys(self.streamer.streams_by_name) - ] - - await make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - else: - await self.subscribe_to_stream(stream_name, token) + # Subscribe to all streams we're publishing to. + for stream_name in self.streamer.streams_by_name: + current_token = self.streamer.get_stream_token(stream_name) + self.send_command(PositionCommand(stream_name, current_token)) async def on_FEDERATION_ACK(self, cmd): self.streamer.federation_ack(cmd.token) @@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - async def subscribe_to_stream(self, stream_name, token): - """Subscribe the remote to a stream. - - This invloves checking if they've missed anything and sending those - updates down if they have. During that time new updates for the stream - are queued and sent once we've sent down any missed updates. - """ - self.replication_streams.discard(stream_name) - self.connecting_streams.add(stream_name) - - try: - # Get missing updates - updates, current_token = await self.streamer.get_stream_updates( - stream_name, token - ) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) - - # We send a POSITION command to ensure that they have an up to - # date token (especially useful if we didn't send any updates - # above) - self.send_command(PositionCommand(stream_name, current_token)) - - # Now we can send any updates that came in while we were subscribing - pending_rdata = self.pending_rdata.pop(stream_name, []) - updates = [] - for token, update in pending_rdata: - # If the token is null, it is part of a batch update. Batches - # are multiple updates that share a single token. To denote - # this, the token is set to None for all tokens in the batch - # except for the last. If we find a None token, we keep looking - # through tokens until we find one that is not None and then - # process all previous updates in the batch as if they had the - # final token. - if token is None: - # Store this update as part of a batch - updates.append(update) - continue - - if token <= current_token: - # This update or batch of updates is older than - # current_token, dismiss it - updates = [] - continue - - updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear stored updates - updates = [] - - # They're now fully subscribed - self.replication_streams.add(stream_name) - except Exception as e: - logger.exception("[%s] Failed to handle REPLICATE command", self.id()) - self.send_error("failed to handle replicate: %r", e) - finally: - self.connecting_streams.discard(stream_name) - def stream_update(self, stream_name, token, data): """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not """ - if stream_name in self.replication_streams: - # The client is subscribed to the stream - self.send_command(RdataCommand(stream_name, token, data)) - elif stream_name in self.connecting_streams: - # The client is being subscribed to the stream - logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) - self.pending_rdata.setdefault(stream_name, []).append((token, data)) - else: - # The client isn't subscribed - logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + self.send_command(RdataCommand(stream_name, token, data)) def send_sync(self, data): self.send_command(SyncCommand(data)) @@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, + hs: "HomeServer", client_name: str, server_name: str, clock: Clock, @@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.handler = handler + self.streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. - self.streams_connecting = set() # type: Set[str] + self.streams_connecting = set(STREAMS_MAP) # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. - self.pending_batches = {} # type: Dict[str, Any] + self.pending_batches = {} # type: Dict[str, List[Any]] def connectionMade(self): self.send_command(NameCommand(self.client_name)) BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): - self.replicate(stream_name, token) + self.replicate() # Tell the server if we have any users currently syncing (should only # happen on synchrotrons) @@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) - # This will happen if we don't actually subscribe to any streams - if not self.streams_connecting: - self.handler.finished_connecting() - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): ) raise - if cmd.token is None: + if cmd.token is None or stream_name in self.streams_connecting: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token self.pending_batches.setdefault(stream_name, []).append(row) @@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): rows.append(row) await self.handler.on_rdata(stream_name, cmd.token, rows) - async def on_POSITION(self, cmd): - # When we get a `POSITION` command it means we've finished getting - # missing updates for the given stream, and are now up to date. + async def on_POSITION(self, cmd: PositionCommand): + stream = self.streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # Find where we previously streamed up to. + current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + if updates: + await self.handler.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self.handler.on_position(cmd.stream_name, cmd.token) + self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - await self.handler.on_position(cmd.stream_name, cmd.token) + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + # Handle any RDATA that came in while we were catching up. + rows = self.pending_batches.pop(cmd.stream_name, []) + if rows: + await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) @@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): self.handler.on_remote_server_up(cmd.data) - def replicate(self, stream_name, token): + def replicate(self): """Send the subscription request to the server """ - if stream_name not in STREAMS_MAP: - raise Exception("Invalid stream name %r" % (stream_name,)) - - logger.info( - "[%s] Subscribing to replication stream: %r from %r", - self.id(), - stream_name, - token, - ) - - self.streams_connecting.add(stream_name) + logger.info("[%s] Subscribing to replication streams", self.id()) - self.send_command(ReplicateCommand(stream_name, token)) + self.send_command(ReplicateCommand()) def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 6e2ebaf614..4374e99e32 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import Any, List +from typing import Any, Dict, List from six import itervalues @@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.metrics import Measure, measure_func from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP +from .streams import STREAMS_MAP, Stream from .streams.federation import FederationStream stream_updates_counter = Counter( @@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = ReplicationStreamer(hs) + self.streamer = hs.get_replication_streamer() self.clock = hs.get_clock() self.server_name = hs.config.server_name @@ -133,6 +133,11 @@ class ReplicationStreamer(object): for conn in self.connections: conn.send_error("server shutting down") + def get_streams(self) -> Dict[str, Stream]: + """Get a mapp from stream name to stream instance. + """ + return self.streams_by_name + def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. @@ -190,7 +195,8 @@ class ReplicationStreamer(object): stream.current_token(), ) try: - updates, current_token = await stream.get_updates() + updates, current_token, limited = await stream.get_updates() + self.pending_updates |= limited except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -226,8 +232,7 @@ class ReplicationStreamer(object): self.pending_updates = False self.is_looping = False - @measure_func("repl.get_stream_updates") - async def get_stream_updates(self, stream_name, token): + def get_stream_token(self, stream_name): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -235,7 +240,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return await stream.get_updates_since(token) + return stream.current_token() @measure_func("repl.federation_ack") def federation_ack(self, token): diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 29199f5b46..37bcd3de66 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -24,6 +24,9 @@ Each stream is defined by the following information: current_token: The function that returns the current token for the stream update_function: The function that returns a list of updates between two tokens """ + +from typing import Dict, Type + from synapse.replication.tcp.streams._base import ( AccountDataStream, BackfillStream, @@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import ( PushersStream, PushRulesStream, ReceiptsStream, + Stream, TagAccountDataStream, ToDeviceStream, TypingStream, @@ -63,10 +67,12 @@ STREAMS_MAP = { GroupServerStream, UserSignatureStream, ) -} +} # type: Dict[str, Type[Stream]] + __all__ = [ "STREAMS_MAP", + "Stream", "BackfillStream", "PresenceStream", "TypingStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 32d9514883..c14dff6c64 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from collections import namedtuple -from typing import Any, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr +from synapse.replication.http.streams import ReplicationGetStreamUpdates from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -29,6 +29,15 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 500000 +# Some type aliases to make things a bit easier. + +# A stream position token +Token = int + +# A pair of position in stream and args used to create an instance of `ROW_TYPE`. +StreamRow = Tuple[Token, tuple] + + class Stream(object): """Base class for the streams. @@ -56,6 +65,7 @@ class Stream(object): return cls.ROW_TYPE(*row) def __init__(self, hs): + # The token from which we last asked for updates self.last_token = self.current_token() @@ -65,61 +75,46 @@ class Stream(object): """ self.last_token = self.current_token() - async def get_updates(self): + async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before). Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - updates, current_token = await self.get_updates_since(self.last_token) + current_token = self.current_token() + updates, current_token, limited = await self.get_updates_since( + self.last_token, current_token + ) self.last_token = current_token - return updates, current_token + return updates, current_token, limited async def get_updates_since( - self, from_token: int - ) -> Tuple[List[Tuple[int, JsonDict]], int]: + self, from_token: Token, upto_token: Token, limit: int = 100 + ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Like get_updates except allows specifying from when we should stream updates Returns: - Resolves to a pair `(updates, new_last_token)`, where `updates` is - a list of `(token, row)` entries and `new_last_token` is the new - position in stream. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - if from_token in ("NOW", "now"): - return [], self.current_token() - - current_token = self.current_token() - from_token = int(from_token) - if from_token == current_token: - return [], current_token + if from_token == upto_token: + return [], upto_token, False - rows = await self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 + updates, upto_token, limited = await self.update_function( + from_token, upto_token, limit=limit, ) - - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - - updates = [(row[0], row[1:]) for row in rows] - - # check we didn't get more rows than the limit. - # doing it like this allows the update_function to be a generator. - if len(updates) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) - - # The update function didn't hit the limit, so we must have got all - # the updates to `current_token`, and can return that as our new - # stream position. - return updates, current_token + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -141,6 +136,48 @@ class Stream(object): raise NotImplementedError() +def db_query_to_update_function( + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] +) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True + + return updates, upto_token, limited + + return update_function + + +def make_http_update_function( + hs, stream_name: str +) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + """Makes a suitable function for use as an `update_function` that queries + the master process for updates. + """ + + client = ReplicationGetStreamUpdates.make_client(hs) + + async def update_function( + from_token: int, upto_token: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + return await client( + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, + limit=limit, + ) + + return update_function + + class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -164,7 +201,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_backfill_token # type: ignore - self.update_function = store.get_all_new_backfill_event_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore super(BackfillStream, self).__init__(hs) @@ -190,8 +227,15 @@ class PresenceStream(Stream): store = hs.get_datastore() presence_handler = hs.get_presence_handler() + self._is_worker = hs.config.worker_app is not None + self.current_token = store.get_current_presence_token # type: ignore - self.update_function = presence_handler.get_all_presence_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore + else: + # Query master process + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(PresenceStream, self).__init__(hs) @@ -208,7 +252,12 @@ class TypingStream(Stream): typing_handler = hs.get_typing_handler() self.current_token = typing_handler.get_current_token # type: ignore - self.update_function = typing_handler.get_all_typing_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore + else: + # Query master process + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(TypingStream, self).__init__(hs) @@ -232,7 +281,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_receipt_stream_id # type: ignore - self.update_function = store.get_all_updated_receipts # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -256,7 +305,13 @@ class PushRulesStream(Stream): async def update_function(self, from_token, to_token, limit): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], (row[2],)) for row in rows], to_token, limited class PushersStream(Stream): @@ -275,7 +330,7 @@ class PushersStream(Stream): store = hs.get_datastore() self.current_token = store.get_pushers_stream_token # type: ignore - self.update_function = store.get_all_updated_pushers_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore super(PushersStream, self).__init__(hs) @@ -307,7 +362,7 @@ class CachesStream(Stream): store = hs.get_datastore() self.current_token = store.get_cache_stream_token # type: ignore - self.update_function = store.get_all_updated_caches # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore super(CachesStream, self).__init__(hs) @@ -333,7 +388,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() self.current_token = store.get_current_public_room_stream_id # type: ignore - self.update_function = store.get_all_new_public_rooms # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -354,7 +409,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -372,7 +427,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() self.current_token = store.get_to_device_stream_token # type: ignore - self.update_function = store.get_all_new_device_messages # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -392,7 +447,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_account_data_stream_id # type: ignore - self.update_function = store.get_all_updated_tags # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -412,10 +467,11 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() self.current_token = self.store.get_max_account_data_stream_id # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(AccountDataStream, self).__init__(hs) - async def update_function(self, from_token, to_token, limit): + async def _update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) @@ -442,7 +498,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() self.current_token = store.get_group_stream_token # type: ignore - self.update_function = store.get_all_groups_changes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore super(GroupServerStream, self).__init__(hs) @@ -460,6 +516,6 @@ class UserSignatureStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index b3afabb8cd..c6a595629f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import Tuple, Type import attr -from ._base import Stream +from ._base import Stream, db_query_to_update_function """Handling of the 'events' replication stream @@ -117,10 +117,11 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() self.current_token = self._store.get_current_events_token # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(EventsStream, self).__init__(hs) - async def update_function(self, from_token, current_token, limit=None): + async def _update_function(self, from_token, current_token, limit=None): event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index f5f9336430..48c1d45718 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,9 @@ # limitations under the License. from collections import namedtuple -from ._base import Stream +from twisted.internet import defer + +from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function class FederationStream(Stream): @@ -33,11 +35,18 @@ class FederationStream(Stream): NAME = "federation" ROW_TYPE = FederationStreamRow + _QUERY_MASTER = True def __init__(self, hs): - federation_sender = hs.get_federation_sender() - - self.current_token = federation_sender.get_current_token # type: ignore - self.update_function = federation_sender.get_replication_rows # type: ignore + # Not all synapse instances will have a federation sender instance, + # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, + # so we stub the stream out when that is the case. + if hs.config.worker_app is None or hs.should_send_federation(): + federation_sender = hs.get_federation_sender() + self.current_token = federation_sender.get_current_token # type: ignore + self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore + else: + self.current_token = lambda: 0 # type: ignore + self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore super(FederationStream, self).__init__(hs) diff --git a/synapse/server.py b/synapse/server.py index 1b980371de..9426eb1672 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -85,6 +85,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -199,6 +200,7 @@ class HomeServer(object): "saml_handler", "event_client_serializer", "storage", + "replication_streamer", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -536,6 +538,9 @@ class HomeServer(object): def build_storage(self) -> Storage: return Storage(self, self.datastores) + def build_replication_streamer(self) -> ReplicationStreamer: + return ReplicationStreamer(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index d4c44dcc75..4dc5da3fe8 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -32,7 +32,29 @@ logger = logging.getLogger(__name__) CURRENT_STATE_CACHE_NAME = "cs_cache_fake" -class CacheInvalidationStore(SQLBaseStore): +class CacheInvalidationWorkerStore(SQLBaseStore): + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + +class CacheInvalidationStore(CacheInvalidationWorkerStore): async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore): }, ) - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_caches", get_all_updated_caches_txn - ) - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 0613b49f4a..9a1178fb39 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY destination" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() + + return rows + + return self.db.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" @@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((user_id, device_id, stream_id, message_json)) txn.executemany(sql, rows) - - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ - Args: - last_pos(int): - current_pos(int): - limit(int): - Returns: - A deferred list of rows from the device inbox - """ - if last_pos == current_pos: - return defer.succeed([]) - - def get_all_new_device_messages_txn(txn): - # We limit like this as we might have multiple rows per stream_id, and - # we want to make sure we always get all entries for any stream_id - # we return. - upper_pos = min(current_pos, last_pos + limit) - sql = ( - "SELECT max(stream_id), user_id" - " FROM device_inbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY user_id" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() - - sql = ( - "SELECT max(stream_id), destination" - " FROM device_federation_outbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY destination" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) - - # Order by ascending stream ordering - rows.sort() - - return rows - - return self.db.runInteraction( - "get_all_new_device_messages", get_all_new_device_messages_txn - ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d593ef47b8..e71c23541d 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1267,104 +1267,6 @@ class EventsStore( ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - - def get_current_events_token(self): - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - - def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_forward_event_rows(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows - ) - - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_backfill_event_rows(txn): - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows - ) - @cached(num_args=5, max_entries=10) def get_all_new_events( self, @@ -1850,22 +1752,6 @@ class EventsStore( return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): - def get_all_updated_current_state_deltas_txn(txn): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? - """ - txn.execute(sql, (from_token, to_token, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_current_state_deltas", - get_all_updated_current_state_deltas_txn, - ) - def insert_labels_for_event_txn( self, txn, event_id, labels, room_id, topological_ordering ): diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 3013f49d32..16ea8948b1 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore): complexity_v1 = round(state_events / 500, 2) return {"v1": complexity_v1} + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index e6c10c6316..aaebe427d3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore): return total_media_quarantined + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = """ + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (prev_id, current_id, limit)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) + + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() - - if prev_id == current_id: - return defer.succeed([]) - - return self.db.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - @defer.inlineCallbacks def block_room(self, room_id, user_id): """Marks the room as blocked. Can be called multiple times. diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index e96ad4ca4e..a755fe2879 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from mock import Mock from synapse.replication.tcp.commands import ReplicateCommand @@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # build a replication server server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = server_factory.streamer - server = server_factory.buildProtocol(None) + self.server = server_factory.buildProtocol(None) - # build a replication client, with a dummy handler - handler_factory = Mock() - self.test_handler = TestReplicationClientHandler() - self.test_handler.factory = handler_factory + self.test_handler = Mock(wraps=TestReplicationClientHandler()) self.client = ClientReplicationStreamProtocol( - "client", "test", clock, self.test_handler + hs, "client", "test", clock, self.test_handler, ) - # wire them together - self.client.makeConnection(FakeTransport(server, reactor)) - server.makeConnection(FakeTransport(self.client, reactor)) + self._client_transport = None + self._server_transport = None + + def reconnect(self): + if self._client_transport: + self.client.close() + + if self._server_transport: + self.server.close() + + self._client_transport = FakeTransport(self.server, self.reactor) + self.client.makeConnection(self._client_transport) + + self._server_transport = FakeTransport(self.client, self.reactor) + self.server.makeConnection(self._server_transport) + + def disconnect(self): + if self._client_transport: + self._client_transport = None + self.client.close() + + if self._server_transport: + self._server_transport = None + self.server.close() def replicate(self): """Tell the master side of replication that something has happened, and then @@ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self, stream, token="NOW"): + def replicate_stream(self): """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand(stream, token)) + self.client.send_command(ReplicateCommand()) class TestReplicationClientHandler(object): """Drop-in for ReplicationClientHandler which just collects RDATA rows""" def __init__(self): - self.received_rdata_rows = [] + self.streams = set() + self._received_rdata_rows = [] def get_streams_to_replicate(self): - return {} + positions = {s: 0 for s in self.streams} + for stream, token, _ in self._received_rdata_rows: + if stream in self.streams: + positions[stream] = max(token, positions.get(stream, 0)) + return positions def get_currently_syncing_users(self): return [] @@ -73,6 +97,9 @@ class TestReplicationClientHandler(object): def finished_connecting(self): pass + async def on_position(self, stream_name, token): + """Called when we get new position data.""" + async def on_rdata(self, stream_name, token, rows): for r in rows: - self.received_rdata_rows.append((stream_name, token, r)) + self._received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index fa2493cad6..0ec0825a0e 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase USER_ID = "@feeling:blue" -ROOM_ID = "!room:blue" -EVENT_ID = "$event:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): + self.reconnect() + # make the client subscribe to the receipts stream - self.replicate_stream("receipts", "NOW") + self.replicate_stream() + self.test_handler.streams.add("receipts") # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) self.replicate() # there should be one RDATA command - rdata_rows = self.test_handler.received_rdata_rows + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - self.assertEqual(rdata_rows[0][0], "receipts") - row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow - self.assertEqual(ROOM_ID, row.room_id) + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) - self.assertEqual(EVENT_ID, row.event_id) + self.assertEqual("$event:blue", row.event_id) self.assertEqual({"a": 1}, row.data) + + # Now let's disconnect and insert some data. + self.disconnect() + + self.test_handler.on_rdata.reset_mock() + + self.get_success( + self.hs.get_datastore().insert_receipt( + "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + ) + ) + self.replicate() + + # Nothing should have happened as we are disconnected + self.test_handler.on_rdata.assert_not_called() + + self.reconnect() + self.pump(0.1) + + # We should now have caught up and get the missing data + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") + self.assertEqual(token, 3) + self.assertEqual(1, len(rdata_rows)) + + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room2:blue", row.room_id) + self.assertEqual("m.read", row.receipt_type) + self.assertEqual(USER_ID, row.user_id) + self.assertEqual("$event2:foo", row.event_id) + self.assertEqual({"a": 2}, row.data) -- cgit 1.5.1 From 1c1242acba9694a3a4b1eb3b14ec0bac11ee4ff8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Mar 2020 07:39:34 -0400 Subject: Validate that the session is not modified during UI-Auth (#7068) --- changelog.d/7068.bugfix | 1 + synapse/handlers/auth.py | 37 +++++++++++++++-- synapse/rest/client/v2_alpha/account.py | 11 ++++-- synapse/rest/client/v2_alpha/devices.py | 4 +- synapse/rest/client/v2_alpha/keys.py | 2 +- synapse/rest/client/v2_alpha/register.py | 5 ++- tests/rest/client/v2_alpha/test_auth.py | 68 +++++++++++++++++++++++++++++++- tests/test_terms_auth.py | 3 +- 8 files changed, 117 insertions(+), 14 deletions(-) create mode 100644 changelog.d/7068.bugfix (limited to 'tests') diff --git a/changelog.d/7068.bugfix b/changelog.d/7068.bugfix new file mode 100644 index 0000000000..d1693a7f22 --- /dev/null +++ b/changelog.d/7068.bugfix @@ -0,0 +1 @@ +Ensure that a user inteactive authentication session is tied to a single request. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7860f9625e..2ce1425dfa 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -125,7 +125,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def validate_user_via_ui_auth( - self, requester: Requester, request_body: Dict[str, Any], clientip: str + self, + requester: Requester, + request: SynapseRequest, + request_body: Dict[str, Any], + clientip: str, ): """ Checks that the user is who they claim to be, via a UI auth. @@ -137,6 +141,8 @@ class AuthHandler(BaseHandler): Args: requester: The user, as given by the access token + request: The request sent by the client. + request_body: The body of the request sent by the client clientip: The IP address of the client. @@ -172,7 +178,9 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_login_types] try: - result, params, _ = yield self.check_auth(flows, request_body, clientip) + result, params, _ = yield self.check_auth( + flows, request, request_body, clientip + ) except LoginError: # Update the ratelimite to say we failed (`can_do_action` doesn't raise). self._failed_uia_attempts_ratelimiter.can_do_action( @@ -211,7 +219,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def check_auth( - self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str + self, + flows: List[List[str]], + request: SynapseRequest, + clientdict: Dict[str, Any], + clientip: str, ): """ Takes a dictionary sent by the client in the login / registration @@ -231,6 +243,8 @@ class AuthHandler(BaseHandler): strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + request: The request sent by the client. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. @@ -270,13 +284,27 @@ class AuthHandler(BaseHandler): # email auth link on there). It's probably too open to abuse # because it lets unauthenticated clients store arbitrary objects # on a homeserver. - # Revisit: Assumimg the REST APIs do sensible validation, the data + # Revisit: Assuming the REST APIs do sensible validation, the data # isn't arbintrary. session["clientdict"] = clientdict self._save_session(session) elif "clientdict" in session: clientdict = session["clientdict"] + # Ensure that the queried operation does not vary between stages of + # the UI authentication session. This is done by generating a stable + # comparator based on the URI, method, and body (minus the auth dict) + # and storing it during the initial query. Subsequent queries ensure + # that this comparator has not changed. + comparator = (request.uri, request.method, clientdict) + if "ui_auth" not in session: + session["ui_auth"] = comparator + elif session["ui_auth"] != comparator: + raise SynapseError( + 403, + "Requested operation has changed during the UI authentication session.", + ) + if not authdict: raise InteractiveAuthIncompleteError( self._auth_dict_for_flows(flows, session) @@ -322,6 +350,7 @@ class AuthHandler(BaseHandler): creds, list(clientdict), ) + return creds, clientdict, session["id"] ret = self._auth_dict_for_flows(flows, session) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 631cc74cb4..b1249b664c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -234,13 +234,16 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) params = await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) user_id = requester.user.to_string() else: requester = None result, params, _ = await self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request) + [[LoginType.EMAIL_IDENTITY]], + request, + body, + self.hs.get_ip_from_request(request), ) if LoginType.EMAIL_IDENTITY in result: @@ -308,7 +311,7 @@ class DeactivateAccountRestServlet(RestServlet): return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") @@ -656,7 +659,7 @@ class ThreepidAddRestServlet(RestServlet): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) validation_session = await self.identity_handler.validate_threepid_session( diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 94ff73f384..119d979052 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -81,7 +81,7 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) await self.device_handler.delete_devices( @@ -127,7 +127,7 @@ class DeviceRestServlet(RestServlet): raise await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) await self.device_handler.delete_device(requester.user.to_string(), device_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f7ed4daf90..5eb7ef35a4 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -263,7 +263,7 @@ class SigningKeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a09189b1b4..6963d79310 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -499,7 +499,10 @@ class RegisterRestServlet(RestServlet): ) auth_result, params, session_id = await self.auth_handler.check_auth( - self._registration_flows, body, self.hs.get_ip_from_request(request) + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), ) # Check that we're not trying to register a denied 3pid. diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index b6df1396ad..624bf5ada2 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) self.render(request) - # Now we should have fufilled a complete auth flow, including + # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. request, channel = self.make_request( @@ -115,3 +115,69 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") + + def test_cannot_change_operation(self): + """ + The initial requested operation cannot be modified during the user interactive authentication session. + """ + + # Make the initial request to register. (Later on a different password + # will be used.) + request, channel = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + self.render(request) + + # Returns a 401 as per the spec + self.assertEqual(request.code, 401) + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + request, channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) + self.render(request) + self.assertEqual(request.code, 200) + + request, channel = self.make_request( + "POST", + "auth/m.login.recaptcha/fallback/web?session=" + + session + + "&g-recaptcha-response=a", + ) + self.render(request) + self.assertEqual(request.code, 200) + + # The recaptcha handler is called with the response given + attempts = self.recaptcha_checker.recaptcha_attempts + self.assertEqual(len(attempts), 1) + self.assertEqual(attempts[0][0]["response"], "a") + + # also complete the dummy auth + request, channel = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) + self.render(request) + + # Now we should have fulfilled a complete auth flow, including + # the recaptcha fallback step. Make the initial request again, but + # with a different password. This causes the request to fail since the + # operaiton was modified during the ui auth session. + request, channel = self.make_request( + "POST", + "register", + { + "username": "user", + "type": "m.login.password", + "password": "foo", # Note this doesn't match the original request. + "auth": {"session": session}, + }, + ) + self.render(request) + self.assertEqual(channel.code, 403) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 5ec5d2b358..a3f98a1412 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -53,7 +53,8 @@ class TermsTestCase(unittest.HomeserverTestCase): def test_ui_auth(self): # Do a UI auth request - request, channel = self.make_request(b"POST", self.url, b"{}") + request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request, channel = self.make_request(b"POST", self.url, request_data) self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) -- cgit 1.5.1 From e8e2ddb60ae11db488f159901d918cb159695912 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 26 Mar 2020 17:51:13 +0100 Subject: Allow server admins to define and enforce a password policy (MSC2000). (#7118) --- changelog.d/7118.feature | 1 + docs/sample_config.yaml | 35 ++++ synapse/api/errors.py | 21 +++ synapse/config/password.py | 39 +++++ synapse/handlers/password_policy.py | 93 +++++++++++ synapse/handlers/set_password.py | 2 + synapse/rest/__init__.py | 2 + synapse/rest/client/v2_alpha/password_policy.py | 58 +++++++ synapse/rest/client/v2_alpha/register.py | 2 + synapse/server.py | 5 + tests/rest/client/v2_alpha/test_password_policy.py | 179 +++++++++++++++++++++ 11 files changed, 437 insertions(+) create mode 100644 changelog.d/7118.feature create mode 100644 synapse/handlers/password_policy.py create mode 100644 synapse/rest/client/v2_alpha/password_policy.py create mode 100644 tests/rest/client/v2_alpha/test_password_policy.py (limited to 'tests') diff --git a/changelog.d/7118.feature b/changelog.d/7118.feature new file mode 100644 index 0000000000..5cbfd98160 --- /dev/null +++ b/changelog.d/7118.feature @@ -0,0 +1 @@ +Allow server admins to define and enforce a password policy (MSC2000). \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2ef83646b3..1a1d061759 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1482,6 +1482,41 @@ password_config: # #pepper: "EVEN_MORE_SECRET" + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true + # Configuration for sending emails from Synapse. # diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 616942b057..11da016ac5 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -64,6 +64,13 @@ class Codes(object): INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" + PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT" + PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT" + PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE" + PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE" + PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL" + PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY" + WEAK_PASSWORD = "M_WEAK_PASSWORD" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" BAD_ALIAS = "M_BAD_ALIAS" @@ -439,6 +446,20 @@ class IncompatibleRoomVersionError(SynapseError): return cs_error(self.msg, self.errcode, room_version=self._room_version) +class PasswordRefusedError(SynapseError): + """A password has been refused, either during password reset/change or registration. + """ + + def __init__( + self, + msg="This password doesn't comply with the server's policy", + errcode=Codes.WEAK_PASSWORD, + ): + super(PasswordRefusedError, self).__init__( + code=400, msg=msg, errcode=errcode, + ) + + class RequestSendFailed(RuntimeError): """Sending a HTTP request over federation failed due to not being able to talk to the remote server for some reason. diff --git a/synapse/config/password.py b/synapse/config/password.py index 2a634ac751..9c0ea8c30a 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -31,6 +31,10 @@ class PasswordConfig(Config): self.password_localdb_enabled = password_config.get("localdb_enabled", True) self.password_pepper = password_config.get("pepper", "") + # Password policy + self.password_policy = password_config.get("policy") or {} + self.password_policy_enabled = self.password_policy.get("enabled", False) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ password_config: @@ -48,4 +52,39 @@ class PasswordConfig(Config): # DO NOT CHANGE THIS AFTER INITIAL SETUP! # #pepper: "EVEN_MORE_SECRET" + + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true """ diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py new file mode 100644 index 0000000000..d06b110269 --- /dev/null +++ b/synapse/handlers/password_policy.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from synapse.api.errors import Codes, PasswordRefusedError + +logger = logging.getLogger(__name__) + + +class PasswordPolicyHandler(object): + def __init__(self, hs): + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + # Regexps for the spec'd policy parameters. + self.regexp_digit = re.compile("[0-9]") + self.regexp_symbol = re.compile("[^a-zA-Z0-9]") + self.regexp_uppercase = re.compile("[A-Z]") + self.regexp_lowercase = re.compile("[a-z]") + + def validate_password(self, password): + """Checks whether a given password complies with the server's policy. + + Args: + password (str): The password to check against the server's policy. + + Raises: + PasswordRefusedError: The password doesn't comply with the server's policy. + """ + + if not self.enabled: + return + + minimum_accepted_length = self.policy.get("minimum_length", 0) + if len(password) < minimum_accepted_length: + raise PasswordRefusedError( + msg=( + "The password must be at least %d characters long" + % minimum_accepted_length + ), + errcode=Codes.PASSWORD_TOO_SHORT, + ) + + if ( + self.policy.get("require_digit", False) + and self.regexp_digit.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one digit", + errcode=Codes.PASSWORD_NO_DIGIT, + ) + + if ( + self.policy.get("require_symbol", False) + and self.regexp_symbol.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one symbol", + errcode=Codes.PASSWORD_NO_SYMBOL, + ) + + if ( + self.policy.get("require_uppercase", False) + and self.regexp_uppercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one uppercase letter", + errcode=Codes.PASSWORD_NO_UPPERCASE, + ) + + if ( + self.policy.get("require_lowercase", False) + and self.regexp_lowercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one lowercase letter", + errcode=Codes.PASSWORD_NO_LOWERCASE, + ) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 12657ca698..7d1263caf2 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -32,6 +32,7 @@ class SetPasswordHandler(BaseHandler): super(SetPasswordHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() + self._password_policy_handler = hs.get_password_policy_handler() @defer.inlineCallbacks def set_password( @@ -44,6 +45,7 @@ class SetPasswordHandler(BaseHandler): if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) + self._password_policy_handler.validate_password(new_password) password_hash = yield self._auth_handler.hash(new_password) try: diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 4a1fc2ec2b..46e458e95b 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import ( keys, notifications, openid, + password_policy, read_marker, receipts, register, @@ -118,6 +119,7 @@ class ClientRestResource(JsonResource): capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) + password_policy.register_servlets(hs, client_resource) # moving to /_synapse/admin synapse.rest.admin.register_servlets_for_client_rest_resource( diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py new file mode 100644 index 0000000000..968403cca4 --- /dev/null +++ b/synapse/rest/client/v2_alpha/password_policy.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PasswordPolicyServlet(RestServlet): + PATTERNS = client_patterns("/password_policy$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(PasswordPolicyServlet, self).__init__() + + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + def on_GET(self, request): + if not self.enabled or not self.policy: + return (200, {}) + + policy = {} + + for param in [ + "minimum_length", + "require_digit", + "require_symbol", + "require_lowercase", + "require_uppercase", + ]: + if param in self.policy: + policy["m.%s" % param] = self.policy[param] + + return (200, policy) + + +def register_servlets(hs, http_server): + PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6963d79310..66fc8ec179 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -373,6 +373,7 @@ class RegisterRestServlet(RestServlet): self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() + self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_flows = _calculate_registration_flows( @@ -420,6 +421,7 @@ class RegisterRestServlet(RestServlet): or len(body["password"]) > 512 ): raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(body["password"]) desired_username = None if "username" in body: diff --git a/synapse/server.py b/synapse/server.py index 9426eb1672..d0d80e8ac5 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -66,6 +66,7 @@ from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerH from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler +from synapse.handlers.password_policy import PasswordPolicyHandler from synapse.handlers.presence import PresenceHandler from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler @@ -199,6 +200,7 @@ class HomeServer(object): "account_validity_handler", "saml_handler", "event_client_serializer", + "password_policy_handler", "storage", "replication_streamer", ] @@ -535,6 +537,9 @@ class HomeServer(object): def build_event_client_serializer(self): return EventClientSerializer(self) + def build_password_policy_handler(self): + return PasswordPolicyHandler(self) + def build_storage(self) -> Storage: return Storage(self, self.datastores) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py new file mode 100644 index 0000000000..c57072f50c --- /dev/null +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + For each test below that checks whether a password triggers the right error code, + that test provides a password good enough to pass the previous tests, but not the + one it is currently testing (nor any test that comes afterward). + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/password_policy" + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. + """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } + ) + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) -- cgit 1.5.1 From 665630fcaab8f09e83ff77f35d5244a718e20701 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 27 Mar 2020 11:39:43 +0000 Subject: Add tests for outbound device pokes --- changelog.d/7157.misc | 1 + tests/federation/test_federation_sender.py | 303 ++++++++++++++++++++++++++++- tests/unittest.py | 1 + 3 files changed, 302 insertions(+), 3 deletions(-) create mode 100644 changelog.d/7157.misc (limited to 'tests') diff --git a/changelog.d/7157.misc b/changelog.d/7157.misc new file mode 100644 index 0000000000..0eb1128c7a --- /dev/null +++ b/changelog.d/7157.misc @@ -0,0 +1 @@ +Add tests for outbound device pokes. diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index d456267b87..7763b12159 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -12,19 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from mock import Mock +from signedjson import key, sign +from signedjson.types import BaseKey, SigningKey + from twisted.internet import defer -from synapse.types import ReadReceipt +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.types import JsonDict, ReadReceipt from tests.unittest import HomeserverTestCase, override_config -class FederationSenderTestCases(HomeserverTestCase): +class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): - return super(FederationSenderTestCases, self).setup_test_homeserver( + return self.setup_test_homeserver( state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -147,3 +153,294 @@ class FederationSenderTestCases(HomeserverTestCase): } ], ) + + +class FederationSenderDevicesTestCases(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + state_handler=Mock(spec=["get_current_hosts_in_room"]), + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def default_config(self): + c = super().default_config() + c["send_federation"] = True + return c + + def prepare(self, reactor, clock, hs): + # stub out get_current_hosts_in_room + mock_state_handler = hs.get_state_handler() + mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] + + # stub out get_users_who_share_room_with_user so that it claims that + # `@user2:host2` is in the room + def get_users_who_share_room_with_user(user_id): + return defer.succeed({"@user2:host2"}) + + hs.get_datastore().get_users_who_share_room_with_user = ( + get_users_who_share_room_with_user + ) + + # whenever send_transaction is called, record the edu data + self.edus = [] + self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.record_transaction + ) + + def record_transaction(self, txn, json_cb): + data = json_cb() + self.edus.extend(data["edus"]) + return defer.succeed({}) + + def test_send_device_updates(self): + """Basic case: each device update should result in an EDU""" + # create a device + u1 = self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + + # expect one edu + self.assertEqual(len(self.edus), 1) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # a second call should produce no new device EDUs + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + self.assertEqual(self.edus, []) + + # a second device + self.login("user", "pass", device_id="D2") + + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + def test_upload_signatures(self): + """Uploading signatures on some devices should produce updates for that user""" + + e2e_handler = self.hs.get_e2e_keys_handler() + + # register two devices + u1 = self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + self.login(u1, "pass", device_id="D2") + + # expect two edus + self.assertEqual(len(self.edus), 2) + stream_id = None + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload signing keys for each device + device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1") + device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2") + + # expect two more edus + self.assertEqual(len(self.edus), 2) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload master key and self-signing key + master_signing_key = generate_self_id_key() + master_key = { + "user_id": u1, + "usage": ["master"], + "keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)}, + } + + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + selfsigning_signing_key = generate_self_id_key() + selfsigning_key = { + "user_id": u1, + "usage": ["self_signing"], + "keys": { + key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key) + }, + } + sign.sign_json(selfsigning_key, u1, master_signing_key) + + cross_signing_keys = { + "master_key": master_key, + "self_signing_key": selfsigning_key, + } + + self.get_success( + e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys) + ) + + # expect signing key update edu + self.assertEqual(len(self.edus), 1) + self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") + + # sign the devices + d1_json = build_device_dict(u1, "D1", device1_signing_key) + sign.sign_json(d1_json, u1, selfsigning_signing_key) + d2_json = build_device_dict(u1, "D2", device2_signing_key) + sign.sign_json(d2_json, u1, selfsigning_signing_key) + + ret = self.get_success( + e2e_handler.upload_signatures_for_device_keys( + u1, {u1: {"D1": d1_json, "D2": d2_json}}, + ) + ) + self.assertEqual(ret["failures"], {}) + + # expect two edus, in one or two transactions. We don't know what order the + # devices will be updated. + self.assertEqual(len(self.edus), 2) + stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142 + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + if stream_id is not None: + self.assertEqual(c["prev_id"], [stream_id]) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2"}, devices) + + def test_delete_devices(self): + """If devices are deleted, that should result in EDUs too""" + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # expect three edus + self.assertEqual(len(self.edus), 3) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + # expect three edus, in an unknown order + self.assertEqual(len(self.edus), 3) + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertGreaterEqual( + c.items(), + {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(), + ) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def test_unreachable_server(self): + """If the destination server is unreachable, all the updates should get sent on + recovery + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # for each device, there should be a single update + self.assertEqual(len(self.edus), 3) + stream_id = None + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def check_device_update_edu( + self, + edu: JsonDict, + user_id: str, + device_id: str, + prev_stream_id: Optional[int], + ) -> int: + """Check that the given EDU is an update for the given device + Returns the stream_id. + """ + self.assertEqual(edu["edu_type"], "m.device_list_update") + content = edu["content"] + + expected = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_stream_id] if prev_stream_id is not None else [], + } + + self.assertLessEqual(expected.items(), content.items()) + return content["stream_id"] + + def check_signing_key_update_txn(self, txn: JsonDict,) -> None: + """Check that the txn has an EDU with a signing key update. + """ + edus = txn["edus"] + self.assertEqual(len(edus), 1) + + def generate_and_upload_device_signing_key( + self, user_id: str, device_id: str + ) -> SigningKey: + """Generate a signing keypair for the given device, and upload it""" + sk = key.generate_signing_key(device_id) + + device_dict = build_device_dict(user_id, device_id, sk) + + self.get_success( + self.hs.get_e2e_keys_handler().upload_keys_for_user( + user_id, device_id, {"device_keys": device_dict}, + ) + ) + return sk + + +def generate_self_id_key() -> SigningKey: + """generate a signing key whose version is its public key + + ... as used by the cross-signing-keys. + """ + k = key.generate_signing_key("x") + k.version = encode_pubkey(k) + return k + + +def key_id(k: BaseKey) -> str: + return "%s:%s" % (k.alg, k.version) + + +def encode_pubkey(sk: SigningKey) -> str: + """Encode the public key corresponding to the given signing key as base64""" + return key.encode_verify_key_base64(key.get_verify_key(sk)) + + +def build_device_dict(user_id: str, device_id: str, sk: SigningKey): + """Build a dict representing the given device""" + return { + "user_id": user_id, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "curve25519:" + device_id: "curve25519+key", + key_id(sk): encode_pubkey(sk), + }, + } diff --git a/tests/unittest.py b/tests/unittest.py index 23b59bea22..3d57b77a5d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -490,6 +490,7 @@ class HomeserverTestCase(TestCase): "password": password, "admin": admin, "mac": want_mac, + "inhibit_login": True, } ) request, channel = self.make_request( -- cgit 1.5.1 From 8327eb9280cbcb492e05652a96be9f1cd1c0e7c4 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 27 Mar 2020 20:15:23 +0100 Subject: Add options to prevent users from changing their profile. (#7096) --- changelog.d/7096.feature | 1 + docs/sample_config.yaml | 23 +++ synapse/config/registration.py | 27 +++ synapse/handlers/profile.py | 16 ++ synapse/rest/client/v2_alpha/account.py | 16 ++ tests/handlers/test_profile.py | 65 ++++++- tests/rest/client/v2_alpha/test_account.py | 302 +++++++++++++++++++++++++++++ 7 files changed, 449 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7096.feature (limited to 'tests') diff --git a/changelog.d/7096.feature b/changelog.d/7096.feature new file mode 100644 index 0000000000..00f47b2a14 --- /dev/null +++ b/changelog.d/7096.feature @@ -0,0 +1 @@ +Add options to prevent users from changing their profile or associated 3PIDs. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1a1d061759..545226f753 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1086,6 +1086,29 @@ account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process +# Whether users are allowed to change their displayname after it has +# been initially set. Useful when provisioning users based on the +# contents of a third-party directory. +# +# Does not apply to server administrators. Defaults to 'true' +# +#enable_set_displayname: false + +# Whether users are allowed to change their avatar after it has been +# initially set. Useful when provisioning users based on the contents +# of a third-party directory. +# +# Does not apply to server administrators. Defaults to 'true' +# +#enable_set_avatar_url: false + +# Whether users can change the 3PIDs associated with their accounts +# (email address and msisdn). +# +# Defaults to 'true' +# +#enable_3pid_changes: false + # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 9bb3beedbc..e7ea3a01cb 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,6 +129,10 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.enable_set_displayname = config.get("enable_set_displayname", True) + self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) + self.enable_3pid_changes = config.get("enable_3pid_changes", True) + self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -330,6 +334,29 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process + # Whether users are allowed to change their displayname after it has + # been initially set. Useful when provisioning users based on the + # contents of a third-party directory. + # + # Does not apply to server administrators. Defaults to 'true' + # + #enable_set_displayname: false + + # Whether users are allowed to change their avatar after it has been + # initially set. Useful when provisioning users based on the contents + # of a third-party directory. + # + # Does not apply to server administrators. Defaults to 'true' + # + #enable_set_avatar_url: false + + # Whether users can change the 3PIDs associated with their accounts + # (email address and msisdn). + # + # Defaults to 'true' + # + #enable_3pid_changes: false + # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 50ce0c585b..6aa1c0f5e0 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,6 +157,15 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") + if not by_admin and not self.hs.config.enable_set_displayname: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.display_name: + raise SynapseError( + 400, + "Changing display name is disabled on this server", + Codes.FORBIDDEN, + ) + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -218,6 +227,13 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") + if not by_admin and not self.hs.config.enable_set_avatar_url: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.avatar_url: + raise SynapseError( + 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN + ) + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index b1249b664c..f80b5e40ea 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -605,6 +605,11 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -649,6 +654,11 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -744,10 +754,16 @@ class ThreepidDeleteRestServlet(RestServlet): def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d60c124eec..be665262c6 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,6 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() + self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -90,6 +91,33 @@ class ProfileTestCase(unittest.TestCase): "Frank Jr.", ) + # Set displayname again + yield self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank" + ) + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + @defer.inlineCallbacks + def test_set_my_name_if_disabled(self): + self.hs.config.enable_set_displayname = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + # Setting displayname a second time is forbidden + d = self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) + + yield self.assertFailure(d, SynapseError) + @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( @@ -147,3 +175,38 @@ class ProfileTestCase(unittest.TestCase): (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) + + # Set avatar again + yield self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/me.png", + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + @defer.inlineCallbacks + def test_set_my_avatar_if_disabled(self): + self.hs.config.enable_set_avatar_url = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + # Set avatar a second time is forbidden + d = self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) + + yield self.assertFailure(d, SynapseError) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index c3facc00eb..45a9d445f8 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,6 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -325,3 +326,304 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + self.email_attempts = [] + + def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_email(self): + """Test adding an email to profile + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_add_email_if_disabled(self): + """Test adding email to profile when doing so is disallowed + """ + self.hs.config.enable_3pid_changes = False + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test deleting an email from profile + """ + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test deleting an email from profile when disallowed + """ + self.hs.config.enable_3pid_changes = False + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) -- cgit 1.5.1 From fb69690761762092c8e44d509d4f72408c4c67e0 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 27 Mar 2020 20:16:43 +0100 Subject: Admin API to join users to a room. (#7051) --- changelog.d/7051.feature | 1 + docs/admin_api/room_membership.md | 34 +++++ synapse/rest/admin/__init__.py | 7 +- synapse/rest/admin/rooms.py | 79 ++++++++++- tests/rest/admin/test_room.py | 288 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 405 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7051.feature create mode 100644 docs/admin_api/room_membership.md create mode 100644 tests/rest/admin/test_room.py (limited to 'tests') diff --git a/changelog.d/7051.feature b/changelog.d/7051.feature new file mode 100644 index 0000000000..3e36a3f65e --- /dev/null +++ b/changelog.d/7051.feature @@ -0,0 +1 @@ +Admin API `POST /_synapse/admin/v1/join/` to join users to a room like `auto_join_rooms` for creation of users. \ No newline at end of file diff --git a/docs/admin_api/room_membership.md b/docs/admin_api/room_membership.md new file mode 100644 index 0000000000..16736d3d37 --- /dev/null +++ b/docs/admin_api/room_membership.md @@ -0,0 +1,34 @@ +# Edit Room Membership API + +This API allows an administrator to join an user account with a given `user_id` +to a room with a given `room_id_or_alias`. You can only modify the membership of +local users. The server administrator must be in the room and have permission to +invite users. + +## Parameters + +The following parameters are available: + +* `user_id` - Fully qualified user: for example, `@user:server.com`. +* `room_id_or_alias` - The room identifier or alias to join: for example, + `!636q39766251:server.com`. + +## Usage + +``` +POST /_synapse/admin/v1/join/ + +{ + "user_id": "@user:server.com" +} +``` + +Including an `access_token` of a server admin. + +Response: + +``` +{ + "room_id": "!636q39766251:server.com" +} +``` diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 42cc2b062a..ed70d448a1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -29,7 +29,11 @@ from synapse.rest.admin._base import ( from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet -from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet +from synapse.rest.admin.rooms import ( + JoinRoomAliasServlet, + ListRoomRestServlet, + ShutdownRoomRestServlet, +) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, @@ -189,6 +193,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f9b8c0a4f0..659b8a10ee 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Optional -from synapse.api.constants import Membership -from synapse.api.errors import Codes, SynapseError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -29,7 +30,7 @@ from synapse.rest.admin._base import ( historical_admin_path_patterns, ) from synapse.storage.data_stores.main.room import RoomSortOrder -from synapse.types import create_requester +from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -237,3 +238,75 @@ class ListRoomRestServlet(RestServlet): response["prev_batch"] = 0 return 200, response + + +class JoinRoomAliasServlet(RestServlet): + + PATTERNS = admin_patterns("/join/(?P[^/]*)") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.admin_handler = hs.get_handlers().admin_handler + self.state_handler = hs.get_state_handler() + + async def on_POST(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + content = parse_json_object_from_request(request) + + assert_params_in_dict(content, ["user_id"]) + target_user = UserID.from_string(content["user_id"]) + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "This endpoint can only be used with local users") + + if not await self.admin_handler.get_user(target_user): + raise NotFoundError("User not found") + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + try: + remote_room_hosts = [ + x.decode("ascii") for x in request.args[b"server_name"] + ] # type: Optional[List[str]] + except Exception: + remote_room_hosts = None + elif RoomAlias.is_valid(room_identifier): + handler = self.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + fake_requester = create_requester(target_user) + + # send invite if room has "JoinRules.INVITE" + room_state = await self.state_handler.get_current_state(room_id) + join_rules_event = room_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + await self.room_member_handler.update_membership( + requester=requester, + target=fake_requester.user, + room_id=room_id, + action="invite", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + await self.room_member_handler.update_membership( + requester=fake_requester, + target=fake_requester.user, + room_id=room_id, + action="join", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + return 200, {"room_id": room_id} diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py new file mode 100644 index 0000000000..672cc3eac5 --- /dev/null +++ b/tests/rest/admin/test_room.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, room + +from tests import unittest + +"""Tests admin REST events for /rooms paths.""" + + +class JoinAliasRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.second_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If a parameter is missing, return an error + """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_local_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + body = json.dumps({"user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_remote_user(self): + """ + Check that only local user can join rooms. + """ + body = json.dumps({"user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "This endpoint can only be used with local users", + channel.json_body["error"], + ) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/!unknown:test" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No known servers", channel.json_body["error"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/invalidroom" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom was not legal room ID or room alias", + channel.json_body["error"], + ) + + def test_join_public_room(self): + """ + Test joining a local user to a public room with "JoinRules.PUBLIC" + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_not_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE" + when server admin is not member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_join_private_room_if_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + self.helper.invite( + room=private_room_id, + src=self.creator, + targ=self.admin_user, + tok=self.creator_tok, + ) + self.helper.join( + room=private_room_id, user=self.admin_user, tok=self.admin_user_tok + ) + + # Validate if server admin is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + # Join user to room. + + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_owner(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is owner of this room. + """ + private_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) -- cgit 1.5.1 From b7da598a61a1bcea3855edf403bdc5ea32cc9e7a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 27 Mar 2020 20:24:52 +0000 Subject: Always whitelist the login fallback for SSO (#7153) That fallback sets the redirect URL to itself (so it can process the login token then return gracefully to the client). This would make it pointless to ask the user for confirmation, since the URL the confirmation page would be showing wouldn't be the client's. --- changelog.d/7153.feature | 1 + docs/sample_config.yaml | 4 ++++ synapse/config/sso.py | 15 +++++++++++++++ tests/rest/client/v1/test_login.py | 9 ++++++++- 4 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7153.feature (limited to 'tests') diff --git a/changelog.d/7153.feature b/changelog.d/7153.feature new file mode 100644 index 0000000000..414ebe1f69 --- /dev/null +++ b/changelog.d/7153.feature @@ -0,0 +1 @@ +Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 545226f753..743949945a 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1444,6 +1444,10 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # # By default, this list is empty. # #client_whitelist: diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 95762689bc..ec3dca9efc 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -39,6 +39,17 @@ class SSOConfig(Config): self.sso_client_whitelist = sso_config.get("client_whitelist") or [] + # Attempt to also whitelist the server's login fallback, since that fallback sets + # the redirect URL to itself (so it can process the login token then return + # gracefully to the client). This would make it pointless to ask the user for + # confirmation, since the URL the confirmation page would be showing wouldn't be + # the client's. + # public_baseurl is an optional setting, so we only add the fallback's URL to the + # list if it's provided (because we can't figure out what that URL is otherwise). + if self.public_baseurl: + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) + def generate_config_section(self, **kwargs): return """\ # Additional settings to use with single-sign on systems such as SAML2 and CAS. @@ -54,6 +65,10 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # # By default, this list is empty. # #client_whitelist: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index da2c9bfa1e..aed8853d6e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -350,7 +350,14 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): def test_cas_redirect_whitelisted(self): """Tests that the SSO login flow serves a redirect to a whitelisted url """ - redirect_url = "https://legit-site.com/" + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" % (urllib.parse.quote(redirect_url)) -- cgit 1.5.1 From 7966a1cde9d4b598faa06620424844f2b35c94af Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 30 Mar 2020 19:06:52 +0100 Subject: Rewrite prune_old_outbound_device_pokes for efficiency (#7159) make sure we clear out all but one update for the user --- changelog.d/7159.bugfix | 1 + synapse/handlers/federation.py | 25 +------- synapse/storage/data_stores/main/devices.py | 71 ++++++++++++++++++---- synapse/util/stringutils.py | 21 ++++++- tests/federation/test_federation_sender.py | 92 +++++++++++++++++++++++++++++ 5 files changed, 173 insertions(+), 37 deletions(-) create mode 100644 changelog.d/7159.bugfix (limited to 'tests') diff --git a/changelog.d/7159.bugfix b/changelog.d/7159.bugfix new file mode 100644 index 0000000000..1b341b127b --- /dev/null +++ b/changelog.d/7159.bugfix @@ -0,0 +1 @@ +Fix excessive CPU usage by `prune_old_outbound_device_pokes` job. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 38ab6a8fc3..c7aa7acf3b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,6 +49,7 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.handlers._base import BaseHandler from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -69,10 +70,9 @@ from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server -from ._base import BaseHandler - logger = logging.getLogger(__name__) @@ -93,27 +93,6 @@ class _NewEventInfo: auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) -def shortstr(iterable, maxitems=5): - """If iterable has maxitems or fewer, return the stringification of a list - containing those items. - - Otherwise, return the stringification of a a list with the first maxitems items, - followed by "...". - - Args: - iterable (Iterable): iterable to truncate - maxitems (int): number of items to return before truncating - - Returns: - unicode - """ - - items = list(itertools.islice(iterable, maxitems + 1)) - if len(items) <= maxitems: - return str(items) - return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" - - class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 2d47cfd131..3140e1b722 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -41,6 +41,7 @@ from synapse.util.caches.descriptors import ( cachedList, ) from synapse.util.iterutils import batch_iter +from synapse.util.stringutils import shortstr logger = logging.getLogger(__name__) @@ -1092,18 +1093,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) - def _prune_old_outbound_device_pokes(self): + def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. We keep one entry per - (destination, user_id) tuple to ensure that the prev_ids remain correct - if the server does come back. + that we don't fill up due to dead servers. + + Normally, we try to send device updates as a delta since a previous known point: + this is done by setting the prev_id in the m.device_list_update EDU. However, + for that to work, we have to have a complete record of each change to + each device, which can add up to quite a lot of data. + + An alternative mechanism is that, if the remote server sees that it has missed + an entry in the stream_id sequence for a given user, it will request a full + list of that user's devices. Hence, we can reduce the amount of data we have to + store (and transmit in some future transaction), by clearing almost everything + for a given destination out of the database, and having the remote server + resync. + + All we need to do is make sure we keep at least one row for each + (user, destination) pair, to remind us to send a m.device_list_update EDU for + that user when the destination comes back. It doesn't matter which device + we keep. """ - yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + yesterday = self._clock.time_msec() - prune_age def _prune_txn(txn): + # look for (user, destination) pairs which have an update older than + # the cutoff. + # + # For each pair, we also need to know the most recent stream_id, and + # an arbitrary device_id at that stream_id. select_sql = """ - SELECT destination, user_id, max(stream_id) as stream_id - FROM device_lists_outbound_pokes + SELECT + dlop1.destination, + dlop1.user_id, + MAX(dlop1.stream_id) AS stream_id, + (SELECT MIN(dlop2.device_id) AS device_id FROM + device_lists_outbound_pokes dlop2 + WHERE dlop2.destination = dlop1.destination AND + dlop2.user_id=dlop1.user_id AND + dlop2.stream_id=MAX(dlop1.stream_id) + ) + FROM device_lists_outbound_pokes dlop1 GROUP BY destination, user_id HAVING min(ts) < ? AND count(*) > 1 """ @@ -1114,14 +1144,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not rows: return + logger.info( + "Pruning old outbound device list updates for %i users/destinations: %s", + len(rows), + shortstr((row[0], row[1]) for row in rows), + ) + + # we want to keep the update with the highest stream_id for each user. + # + # there might be more than one update (with different device_ids) with the + # same stream_id, so we also delete all but one rows with the max stream id. delete_sql = """ DELETE FROM device_lists_outbound_pokes - WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + WHERE destination = ? AND user_id = ? AND ( + stream_id < ? OR + (stream_id = ? AND device_id != ?) + ) """ - - txn.executemany( - delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows) - ) + count = 0 + for (destination, user_id, stream_id, device_id) in rows: + txn.execute( + delete_sql, (destination, user_id, stream_id, stream_id, device_id) + ) + count += txn.rowcount # Since we've deleted unsent deltas, we need to remove the entry # of last successful sent so that the prev_ids are correctly set. @@ -1131,7 +1176,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): """ txn.executemany(sql, ((row[0], row[1]) for row in rows)) - logger.info("Pruned %d device list outbound pokes", txn.rowcount) + logger.info("Pruned %d device list outbound pokes", count) return run_as_background_process( "prune_old_outbound_device_pokes", diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 2c0dcb5208..6899bcb788 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -13,10 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import itertools import random import re import string +from collections import Iterable import six from six import PY2, PY3 @@ -126,3 +127,21 @@ def assert_valid_client_secret(client_secret): raise SynapseError( 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM ) + + +def shortstr(iterable: Iterable, maxitems: int = 5) -> str: + """If iterable has maxitems or fewer, return the stringification of a list + containing those items. + + Otherwise, return the stringification of a a list with the first maxitems items, + followed by "...". + + Args: + iterable: iterable to truncate + maxitems: number of items to return before truncating + """ + + items = list(itertools.islice(iterable, maxitems + 1)) + if len(items) <= maxitems: + return str(items) + return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 7763b12159..a5fe5c6880 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -370,6 +370,98 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2", "D3"}, devices) + def test_prune_outbound_device_pokes1(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server has never been reachable. + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # there should be a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + + def test_prune_outbound_device_pokes2(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server was reachable, but then goes + offline. + """ + + # create first device + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + + # expect the update EDU + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # now the server goes offline + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 3) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # ... and we should get a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + def check_device_update_edu( self, edu: JsonDict, -- cgit 1.5.1 From cfe8c8ab8e412b6320e5963ced0670fbc7b00d1b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Mar 2020 17:24:06 +0100 Subject: Remove unused `start_background_update` This was only used in a unit test, so let's just inline it in the test. --- synapse/storage/background_updates.py | 21 --------------------- tests/storage/test_background_update.py | 14 +++++++++----- 2 files changed, 9 insertions(+), 26 deletions(-) (limited to 'tests') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index eb1a7e5002..d4e26eab6c 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -400,27 +400,6 @@ class BackgroundUpdater(object): self.register_background_update_handler(update_name, updater) - def start_background_update(self, update_name, progress): - """Starts a background update running. - - Args: - update_name: The update to set running. - progress: The initial state of the progress of the update. - - Returns: - A deferred that completes once the task has been added to the - queue. - """ - # Clear the background update queue so that we will pick up the new - # task on the next iteration of do_background_update. - self._background_update_queue = [] - progress_json = json.dumps(progress) - - return self.db.simple_insert( - "background_updates", - {"update_name": update_name, "progress_json": progress_json}, - ) - def _end_background_update(self, update_name): """Removes a completed background update task from the queue. diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index ae14fb407d..aca41eb215 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -25,12 +25,20 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 50000 + store = self.hs.get_datastore() + self.get_success( + store.db.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.hs.get_datastore().db.runInteraction( + yield store.db.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", @@ -39,10 +47,6 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): return count self.update_handler.side_effect = update - - self.get_success( - self.updates.start_background_update("test_update", {"my_key": 1}) - ) self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update( -- cgit 1.5.1 From 26d17b9bdc0de51d5f1a7526e8ab70e7f7796e4d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Mar 2020 17:25:10 +0100 Subject: Make `has_completed_background_updates` async (Almost) everywhere that uses it is happy with an awaitable. --- synapse/storage/background_updates.py | 7 +++---- tests/storage/test_background_update.py | 4 +++- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index d4e26eab6c..79494bcdf5 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -119,12 +119,11 @@ class BackgroundUpdater(object): self._all_done = True return None - @defer.inlineCallbacks - def has_completed_background_updates(self): + async def has_completed_background_updates(self) -> bool: """Check if all the background updates have completed Returns: - Deferred[bool]: True if all background updates have completed + True if all background updates have completed """ # if we've previously determined that there is nothing left to do, that # is easy @@ -138,7 +137,7 @@ class BackgroundUpdater(object): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self.db.simple_select_onecol( + updates = await self.db.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index aca41eb215..e578de8acf 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -11,7 +11,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater # the base test class should have run the real bg updates for us - self.assertTrue(self.updates.has_completed_background_updates()) + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) self.update_handler = Mock() self.updates.register_background_update_handler( -- cgit 1.5.1 From 51f4d52cb4663a056372d779b78488aeae45f554 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Mar 2020 17:27:56 +0100 Subject: Set a logging context while running the bg updates This mostly just reduces the amount of "running from sentinel context" spam during unittest setup. --- tests/unittest.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index d0406ca2fd..27af5228fe 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -40,6 +40,7 @@ from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import ( SENTINEL_CONTEXT, + LoggingContext, current_context, set_current_context, ) @@ -419,15 +420,17 @@ class HomeserverTestCase(TestCase): config_obj.parse_config_dict(config, "", "") kwargs["config"] = config_obj + async def run_bg_updates(): + with LoggingContext("run_bg_updates", request="run_bg_updates-1"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": - while not self.get_success( - stor.db.updates.has_completed_background_updates() - ): - self.get_success(stor.db.updates.do_next_background_update(1)) + self.get_success(run_bg_updates()) return hs -- cgit 1.5.1 From b4c22342320d7de86c02dfb36415a38c62bec88d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Mar 2020 17:31:32 +0100 Subject: Make do_next_background_update return a bool returning a None or an int that we don't use is confusing. --- synapse/storage/background_updates.py | 12 +++++------- tests/storage/test_background_update.py | 6 +++--- 2 files changed, 8 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 79494bcdf5..4a59132bf3 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -111,7 +111,7 @@ class BackgroundUpdater(object): except Exception: logger.exception("Error doing update") else: - if result is None: + if result: logger.info( "No more background updates to do." " Unscheduling background update task." @@ -169,9 +169,7 @@ class BackgroundUpdater(object): return not update_exists - async def do_next_background_update( - self, desired_duration_ms: float - ) -> Optional[int]: + async def do_next_background_update(self, desired_duration_ms: float) -> bool: """Does some amount of work on the next queued background update Returns once some amount of work is done. @@ -180,7 +178,7 @@ class BackgroundUpdater(object): desired_duration_ms(float): How long we want to spend updating. Returns: - None if there is no more work to do, otherwise an int + True if there is no more work to do, otherwise False """ if not self._background_update_queue: updates = await self.db.simple_select_list( @@ -195,14 +193,14 @@ class BackgroundUpdater(object): if not self._background_update_queue: # no work left to do - return None + return True # pop from the front, and add back to the back update_name = self._background_update_queue.pop(0) self._background_update_queue.append(update_name) res = await self._do_background_update(update_name, desired_duration_ms) - return res + return False async def _do_background_update( self, update_name: str, desired_duration_ms: float diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index e578de8acf..940b166129 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -56,7 +56,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ), by=0.1, ) - self.assertIsNotNone(res) + self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( @@ -79,7 +79,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNotNone(result) + self.assertFalse(result) self.update_handler.assert_called_once() # third step: we don't expect to be called any more @@ -87,5 +87,5 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNone(result) + self.assertTrue(result) self.assertFalse(self.update_handler.called) -- cgit 1.5.1 From 468dcc767bf379ba2b4ed4b6d1c6537473175eab Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Apr 2020 08:27:05 -0400 Subject: Allow admins to create aliases when they are not in the room (#7191) --- changelog.d/7191.feature | 1 + synapse/handlers/directory.py | 6 +++- tests/handlers/test_directory.py | 62 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7191.feature (limited to 'tests') diff --git a/changelog.d/7191.feature b/changelog.d/7191.feature new file mode 100644 index 0000000000..83d5685bb2 --- /dev/null +++ b/changelog.d/7191.feature @@ -0,0 +1 @@ +Admin users are no longer required to be in a room to create an alias for it. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 1d842c369b..53e5f585d9 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -127,7 +127,11 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) else: - if self.require_membership and check_membership: + # Server admins are not subject to the same constraints as normal + # users when creating an alias (e.g. being in the room). + is_admin = yield self.auth.is_server_admin(requester.user) + + if (self.require_membership and check_membership) and not is_admin: rooms_for_user = yield self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5e40adba52..00bb776271 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -102,6 +102,68 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) +class TestCreateAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_handlers().directory_handler + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def test_create_alias_joined_room(self): + """A user can create an alias for a room they're in.""" + self.get_success( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, self.room_id, + ) + ) + + def test_create_alias_other_room(self): + """A user cannot create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.get_failure( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, other_room_id, + ), + synapse.api.errors.SynapseError, + ) + + def test_create_alias_admin(self): + """An admin can create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.test_user, tok=self.test_user_tok + ) + + self.get_success( + self.handler.create_association( + create_requester(self.admin_user), self.room_alias, other_room_id, + ) + ) + + class TestDeleteAlias(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, -- cgit 1.5.1 From daa1ac89a0be4dd3cc941da4caeb2ddcbd701eff Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Apr 2020 10:40:22 +0100 Subject: Fix device list update stream ids going backward (#7158) Occasionally we could get a federation device list update transaction which looked like: ``` [ {'edu_type': 'm.device_list_update', 'content': {'user_id': '@user:test', 'device_id': 'D2', 'prev_id': [], 'stream_id': 12, 'deleted': True}}, {'edu_type': 'm.device_list_update', 'content': {'user_id': '@user:test', 'device_id': 'D1', 'prev_id': [12], 'stream_id': 11, 'deleted': True}}, {'edu_type': 'm.device_list_update', 'content': {'user_id': '@user:test', 'device_id': 'D3', 'prev_id': [11], 'stream_id': 13, 'deleted': True}} ] ``` Having `stream_ids` which are lower than `prev_ids` looks odd. It might work (I'm not actually sure), but in any case it doesn't seem like a reasonable thing to expect other implementations to support. --- changelog.d/7158.misc | 1 + synapse/storage/data_stores/main/devices.py | 10 ++++++++-- tests/federation/test_federation_sender.py | 6 ++++++ 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7158.misc (limited to 'tests') diff --git a/changelog.d/7158.misc b/changelog.d/7158.misc new file mode 100644 index 0000000000..269b8daeb0 --- /dev/null +++ b/changelog.d/7158.misc @@ -0,0 +1 @@ +Fix device list update stream ids going backward. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 20995e1b78..dd3561e9b2 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -165,7 +165,6 @@ class DeviceWorkerStore(SQLBaseStore): # the max stream_id across each set of duplicate entries # # maps (user_id, device_id) -> (stream_id, opentracing_context) - # as long as their stream_id does not match that of the last row # # opentracing_context contains the opentracing metadata for the request # that created the poke @@ -270,7 +269,14 @@ class DeviceWorkerStore(SQLBaseStore): prev_id = yield self._get_last_device_update_for_remote_user( destination, user_id, from_stream_id ) - for device_id, device in iteritems(user_devices): + + # make sure we go through the devices in stream order + device_ids = sorted( + user_devices.keys(), key=lambda i: query_map[(user_id, i)][0], + ) + + for device_id in device_ids: + device = user_devices[device_id] stream_id, opentracing_context = query_map[(user_id, device_id)] result = { "user_id": user_id, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index a5fe5c6880..33105576af 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -297,6 +297,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): c = edu["content"] if stream_id is not None: self.assertEqual(c["prev_id"], [stream_id]) + self.assertGreaterEqual(c["stream_id"], stream_id) stream_id = c["stream_id"] devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2"}, devices) @@ -330,6 +331,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): c.items(), {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(), ) + self.assertGreaterEqual(c["stream_id"], stream_id) stream_id = c["stream_id"] devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2", "D3"}, devices) @@ -366,6 +368,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.assertEqual(edu["edu_type"], "m.device_list_update") c = edu["content"] self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) + if stream_id is not None: + self.assertGreaterEqual(c["stream_id"], stream_id) stream_id = c["stream_id"] devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2", "D3"}, devices) @@ -482,6 +486,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): } self.assertLessEqual(expected.items(), content.items()) + if prev_stream_id is not None: + self.assertGreaterEqual(content["stream_id"], prev_stream_id) return content["stream_id"] def check_signing_key_update_txn(self, txn: JsonDict,) -> None: -- cgit 1.5.1 From d73bf18d13031d9f9c0375b83f2cc5ff6f415251 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Sat, 4 Apr 2020 17:27:45 +0200 Subject: Server notices: Dissociate room creation/lookup from invite (#7199) Fixes #6815 Before figuring out whether we should alert a user on MAU, we call get_notice_room_for_user to get some info on the existing server notices room for this user. This function, if the room doesn't exist, creates it and invites the user in it. This means that, if we decide later that no server notice is needed, the user gets invited in a room with no message in it. This happens at every restart of the server, since the room ID returned by get_notice_room_for_user is cached. This PR fixes that by moving the inviting bit to a dedicated function, that's only called when the server actually needs to send a notice to the user. A potential issue with this approach is that the room that's created by get_notice_room_for_user doesn't match how that same function looks for an existing room (i.e. it creates a room that doesn't have an invite or a join for the current user in it, so it could lead to a new room being created each time a user syncs), but I'm not sure this is a problem given it's cached until the server restarts, so that function won't run very often. It also renames get_notice_room_for_user into get_or_create_notice_room_for_user to make what it does clearer. --- changelog.d/7199.bugfix | 1 + .../resource_limits_server_notices.py | 4 +- synapse/server_notices/server_notices_manager.py | 51 +++++++-- .../test_resource_limits_server_notices.py | 120 ++++++++++++++++++--- 4 files changed, 154 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7199.bugfix (limited to 'tests') diff --git a/changelog.d/7199.bugfix b/changelog.d/7199.bugfix new file mode 100644 index 0000000000..b234163ea8 --- /dev/null +++ b/changelog.d/7199.bugfix @@ -0,0 +1 @@ +Fix a bug that could cause a user to be invited to a server notices (aka System Alerts) room without any notice being sent. diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 9fae2e0afe..ce4a828894 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -80,7 +80,9 @@ class ResourceLimitsServerNotices(object): # In practice, not sure we can ever get here return - room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) + room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user( + user_id + ) if not room_id: logger.warning("Failed to get server notices room") diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index f7432c8d2f..bf0943f265 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, RoomCreationPreset -from synapse.types import create_requester +from synapse.types import UserID, create_requester from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -36,10 +36,12 @@ class ServerNoticesManager(object): self._store = hs.get_datastore() self._config = hs.config self._room_creation_handler = hs.get_room_creation_handler() + self._room_member_handler = hs.get_room_member_handler() self._event_creation_handler = hs.get_event_creation_handler() self._is_mine_id = hs.is_mine_id self._notifier = hs.get_notifier() + self.server_notices_mxid = self._config.server_notices_mxid def is_enabled(self): """Checks if server notices are enabled on this server. @@ -66,7 +68,8 @@ class ServerNoticesManager(object): Returns: Deferred[FrozenEvent] """ - room_id = yield self.get_notice_room_for_user(user_id) + room_id = yield self.get_or_create_notice_room_for_user(user_id) + yield self.maybe_invite_user_to_room(user_id, room_id) system_mxid = self._config.server_notices_mxid requester = create_requester(system_mxid) @@ -89,10 +92,11 @@ class ServerNoticesManager(object): return res @cachedInlineCallbacks() - def get_notice_room_for_user(self, user_id): + def get_or_create_notice_room_for_user(self, user_id): """Get the room for notices for a given user - If we have not yet created a notice room for this user, create it + If we have not yet created a notice room for this user, create it, but don't + invite the user to it. Args: user_id (str): complete user id for the user we want a room for @@ -108,7 +112,6 @@ class ServerNoticesManager(object): rooms = yield self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) - system_mxid = self._config.server_notices_mxid for room in rooms: # it's worth noting that there is an asymmetry here in that we # expect the user to be invited or joined, but the system user must @@ -116,10 +119,14 @@ class ServerNoticesManager(object): # manages to invite the system user to a room, that doesn't make it # the server notices room. user_ids = yield self._store.get_users_in_room(room.room_id) - if system_mxid in user_ids: + if self.server_notices_mxid in user_ids: # we found a room which our user shares with the system notice # user - logger.info("Using room %s", room.room_id) + logger.info( + "Using existing server notices room %s for user %s", + room.room_id, + user_id, + ) return room.room_id # apparently no existing notice room: create a new one @@ -138,14 +145,13 @@ class ServerNoticesManager(object): "avatar_url": self._config.server_notices_mxid_avatar_url, } - requester = create_requester(system_mxid) + requester = create_requester(self.server_notices_mxid) info = yield self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, "name": self._config.server_notices_room_name, "power_level_content_override": {"users_default": -10}, - "invite": (user_id,), }, ratelimit=False, creator_join_profile=join_profile, @@ -159,3 +165,30 @@ class ServerNoticesManager(object): logger.info("Created server notices room %s for %s", room_id, user_id) return room_id + + @defer.inlineCallbacks + def maybe_invite_user_to_room(self, user_id: str, room_id: str): + """Invite the given user to the given server room, unless the user has already + joined or been invited to it. + + Args: + user_id: The ID of the user to invite. + room_id: The ID of the room to invite the user to. + """ + requester = create_requester(self.server_notices_mxid) + + # Check whether the user has already joined or been invited to this room. If + # that's the case, there is no need to re-invite them. + joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + user_id, [Membership.INVITE, Membership.JOIN] + ) + for room in joined_rooms: + if room.room_id == room_id: + return + + yield self._room_member_handler.update_membership( + requester=requester, + target=UserID.from_string(user_id), + room_id=room_id, + action="invite", + ) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 0d27b92a86..93eb053b8c 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -19,6 +19,9 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) @@ -67,7 +70,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): # self.server_notices_mxid_avatar_url = None # self.server_notices_room_name = "Server Notices" - self._rlsn._server_notices_manager.get_notice_room_for_user = Mock( + self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( returnValue="" ) self._rlsn._store.add_tag_to_room = Mock() @@ -215,6 +218,26 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def default_config(self): + c = super().default_config() + c["server_notices"] = { + "system_mxid_localpart": "server", + "system_mxid_display_name": None, + "system_mxid_avatar_url": None, + "room_name": "Test Server Notice Room", + } + c["limit_usage_by_mau"] = True + c["max_mau_value"] = 5 + c["admin_contact"] = "mailto:user@test.com" + return c + def prepare(self, reactor, clock, hs): self.store = self.hs.get_datastore() self.server_notices_sender = self.hs.get_server_notices_sender() @@ -228,18 +251,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): if not isinstance(self._rlsn, ResourceLimitsServerNotices): raise Exception("Failed to find reference to ResourceLimitsServerNotices") - self.hs.config.limit_usage_by_mau = True - self.hs.config.hs_disabled = False - self.hs.config.max_mau_value = 5 - self.hs.config.server_notices_mxid = "@server:test" - self.hs.config.server_notices_mxid_display_name = None - self.hs.config.server_notices_mxid_avatar_url = None - self.hs.config.server_notices_room_name = "Test Server Notice Room" - self.user_id = "@user_id:test" - self.hs.config.admin_contact = "mailto:user@test.com" - def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) @@ -253,7 +266,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): # Now lets get the last load of messages in the service notice room and # check that there is only one server notice room_id = self.get_success( - self.server_notices_manager.get_notice_room_for_user(self.user_id) + self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id) ) token = self.get_success(self.event_source.get_current_token()) @@ -273,3 +286,86 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): count += 1 self.assertEqual(count, 1) + + def test_no_invite_without_notice(self): + """Tests that a user doesn't get invited to a server notices room without a + server notice being sent. + + The scenario for this test is a single user on a server where the MAU limit + hasn't been reached (since it's the only user and the limit is 5), so users + shouldn't receive a server notice. + """ + self.register_user("user", "password") + tok = self.login("user", "password") + + request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + self.render(request) + + invites = channel.json_body["rooms"]["invite"] + self.assertEqual(len(invites), 0, invites) + + def test_invite_with_notice(self): + """Tests that, if the MAU limit is hit, the server notices user invites each user + to a room in which it has sent a notice. + """ + user_id, tok, room_id = self._trigger_notice_and_join() + + # Sync again to retrieve the events in the room, so we can check whether this + # room has a notice in it. + request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + self.render(request) + + # Scan the events in the room to search for a message from the server notices + # user. + events = channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + notice_in_room = False + for event in events: + if ( + event["type"] == EventTypes.Message + and event["sender"] == self.hs.config.server_notices_mxid + ): + notice_in_room = True + + self.assertTrue(notice_in_room, "No server notice in room") + + def _trigger_notice_and_join(self): + """Creates enough active users to hit the MAU limit and trigger a system notice + about it, then joins the system notices room with one of the users created. + + Returns: + user_id (str): The ID of the user that joined the room. + tok (str): The access token of the user that joined the room. + room_id (str): The ID of the room that's been joined. + """ + user_id = None + tok = None + invites = [] + + # Register as many users as the MAU limit allows. + for i in range(self.hs.config.max_mau_value): + localpart = "user%d" % i + user_id = self.register_user(localpart, "password") + tok = self.login(localpart, "password") + + # Sync with the user's token to mark the user as active. + request, channel = self.make_request( + "GET", "/sync?timeout=0", access_token=tok, + ) + self.render(request) + + # Also retrieves the list of invites for this user. We don't care about that + # one except if we're processing the last user, which should have received an + # invite to a room with a server notice about the MAU limit being reached. + # We could also pick another user and sync with it, which would return an + # invite to a system notices room, but it doesn't matter which user we're + # using so we use the last one because it saves us an extra sync. + invites = channel.json_body["rooms"]["invite"] + + # Make sure we have an invite to process. + self.assertEqual(len(invites), 1, invites) + + # Join the room. + room_id = list(invites.keys())[0] + self.helper.join(room=room_id, user=user_id, tok=tok) + + return user_id, tok, room_id -- cgit 1.5.1 From 5016b162fcf0372fe35404c64f80aeaf21461f31 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Apr 2020 09:58:42 +0100 Subject: Move client command handling out of TCP protocol (#7185) The aim here is to move the command handling out of the TCP protocol classes and to also merge the client and server command handling (so that we can reuse them for redis protocol). This PR simply moves the client paths to the new `ReplicationCommandHandler`, a future PR will move the server paths too. --- changelog.d/7185.misc | 1 + synapse/app/admin_cmd.py | 12 -- synapse/app/generic_worker.py | 9 +- synapse/replication/tcp/__init__.py | 30 ++- synapse/replication/tcp/client.py | 179 +++--------------- synapse/replication/tcp/handler.py | 252 +++++++++++++++++++++++++ synapse/replication/tcp/protocol.py | 197 +++---------------- synapse/server.py | 8 +- synapse/server.pyi | 7 +- tests/replication/slave/storage/_base.py | 15 +- tests/replication/tcp/streams/_base.py | 38 ++-- tests/replication/tcp/streams/test_receipts.py | 1 - 12 files changed, 378 insertions(+), 371 deletions(-) create mode 100644 changelog.d/7185.misc create mode 100644 synapse/replication/tcp/handler.py (limited to 'tests') diff --git a/changelog.d/7185.misc b/changelog.d/7185.misc new file mode 100644 index 0000000000..deb9ca7021 --- /dev/null +++ b/changelog.d/7185.misc @@ -0,0 +1 @@ +Move client command handling out of TCP protocol. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 1c7c6ec0c8..a37818fe9a 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -43,7 +43,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -79,17 +78,6 @@ class AdminCmdServer(HomeServer): def start_listening(self, listeners): pass - def build_tcp_replication(self): - return AdminCmdReplicationHandler(self) - - -class AdminCmdReplicationHandler(ReplicationClientHandler): - async def on_rdata(self, stream_name, token, rows): - pass - - def get_streams_to_replicate(self): - return {} - @defer.inlineCallbacks def export_data_command(hs, args): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 174bef360f..dcd0709a02 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -64,7 +64,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, @@ -603,7 +603,7 @@ class GenericWorkerServer(HomeServer): def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - def build_tcp_replication(self): + def build_replication_data_handler(self): return GenericWorkerReplicationHandler(self) def build_presence_handler(self): @@ -613,7 +613,7 @@ class GenericWorkerServer(HomeServer): return GenericWorkerTyping(self) -class GenericWorkerReplicationHandler(ReplicationClientHandler): +class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore()) @@ -644,9 +644,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args - def get_currently_syncing_users(self): - return self.presence_handler.get_currently_syncing_users() - async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py index 81c2ea7ee9..523a1358d4 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py @@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst Structure of the module: - * client.py - the client classes used for workers to connect to master + * handler.py - the classes used to handle sending/receiving commands to + replication * command.py - the definitions of all the valid commands - * protocol.py - contains bot the client and server protocol implementations, - these should not be used directly - * resource.py - the server classes that accepts and handle client connections - * streams.py - the definitons of all the valid streams + * protocol.py - the TCP protocol classes + * resource.py - handles streaming stream updates to replications + * streams/ - the definitons of all the valid streams + +The general interaction of the classes are: + + +---------------------+ + | ReplicationStreamer | + +---------------------+ + | + v + +---------------------------+ +----------------------+ + | ReplicationCommandHandler |---->|ReplicationDataHandler| + +---------------------------+ +----------------------+ + | ^ + v | + +-------------+ + | Protocols | + | (TCP/redis) | + +-------------+ + +Where the ReplicationDataHandler (or subclasses) handles incoming stream +updates. """ diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e86d9805f1..700ae79158 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,26 +16,16 @@ """ import logging -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict -from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.protocol import ( - AbstractReplicationClientHandler, - ClientReplicationStreamProtocol, -) - -from .commands import ( - Command, - FederationAckCommand, - InvalidateCacheCommand, - RemoteServerUpCommand, - RemovePusherCommand, - UserIpCommand, - UserSyncCommand, -) +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.replication.tcp.handler import ReplicationCommandHandler logger = logging.getLogger(__name__) @@ -44,16 +34,20 @@ class ReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. - Accepts a handler that will be called when new data is available or data - is required. + Accepts a handler that is passed to `ClientReplicationStreamProtocol`. """ initialDelay = 0.1 maxDelay = 1 # Try at least once every N seconds - def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): + def __init__( + self, + hs: "HomeServer", + client_name: str, + command_handler: "ReplicationCommandHandler", + ): self.client_name = client_name - self.handler = handler + self.command_handler = command_handler self.server_name = hs.config.server_name self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class @@ -66,7 +60,11 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.hs, self.client_name, self.server_name, self._clock, self.handler, + self.hs, + self.client_name, + self.server_name, + self._clock, + self.command_handler, ) def clientConnectionLost(self, connector, reason): @@ -78,41 +76,17 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(AbstractReplicationClientHandler): - """A base handler that can be passed to the ReplicationClientFactory. +class ReplicationDataHandler: + """Handles incoming stream updates from replication. - By default proxies incoming replication data to the SlaveStore. + This instance notifies the slave data store about updates. Can be subclassed + to handle updates in additional ways. """ def __init__(self, store: BaseSlavedStore): self.store = store - # The current connection. None if we are currently (re)connecting - self.connection = None - - # Any pending commands to be sent once a new connection has been - # established - self.pending_commands = [] # type: List[Command] - - # Map from string -> deferred, to wake up when receiveing a SYNC with - # the given string. - # Used for tests. - self.awaiting_syncs = {} # type: Dict[str, defer.Deferred] - - # The factory used to create connections. - self.factory = None # type: Optional[ReplicationClientFactory] - - def start_replication(self, hs): - """Helper method to start a replication connection to the remote server - using TCP. - """ - client_name = hs.config.worker_name - self.factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self.factory) - - async def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name: str, token: int, rows: list): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -124,30 +98,8 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s -> %s", stream_name, token) self.store.process_replication_rows(stream_name, token, rows) - async def on_position(self, stream_name, token): - """Called when we get new position data. By default this just pokes - the slave store. - - Can be overriden in subclasses to handle more. - """ - self.store.process_replication_rows(stream_name, token, []) - - def on_sync(self, data): - """When we received a SYNC we wake up any deferreds that were waiting - for the sync with the given data. - - Used by tests. - """ - d = self.awaiting_syncs.pop(data, None) - if d: - d.callback(data) - - def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. @@ -163,85 +115,10 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): args["account_data"] = user_account_data elif room_account_data: args["account_data"] = room_account_data - return args - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. (Overriden by the synchrotron's only) - """ - return [] - - def send_command(self, cmd): - """Send a command to master (when we get establish a connection if we - don't have one already.) - """ - if self.connection: - self.connection.send_command(cmd) - else: - logger.warning("Queuing command as not connected: %r", cmd.NAME) - self.pending_commands.append(cmd) - - def send_federation_ack(self, token): - """Ack data for the federation stream. This allows the master to drop - data stored purely in memory. - """ - self.send_command(FederationAckCommand(token)) - - def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """Poke the master that a user has started/stopped syncing. - """ - self.send_command( - UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) - ) - - def send_remove_pusher(self, app_id, push_key, user_id): - """Poke the master to remove a pusher for a user - """ - cmd = RemovePusherCommand(app_id, push_key, user_id) - self.send_command(cmd) - - def send_invalidate_cache(self, cache_func, keys): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - - def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): - """Tell the master that the user made a request. - """ - cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) - self.send_command(cmd) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def await_sync(self, data): - """Returns a deferred that is resolved when we receive a SYNC command - with given data. - - [Not currently] used by tests. - """ - return self.awaiting_syncs.setdefault(data, defer.Deferred()) - - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - self.connection = connection - if connection: - for cmd in self.pending_commands: - connection.send_command(cmd) - self.pending_commands = [] - - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - logger.info("Finished connecting to server") + async def on_position(self, stream_name: str, token: int): + self.store.process_replication_rows(stream_name, token, []) - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. - if self.factory: - self.factory.resetDelay() + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py new file mode 100644 index 0000000000..12a1cfd6d1 --- /dev/null +++ b/synapse/replication/tcp/handler.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Callable, Dict, List, Optional, Set + +from prometheus_client import Counter + +from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.commands import ( + Command, + FederationAckCommand, + InvalidateCacheCommand, + PositionCommand, + RdataCommand, + RemoteServerUpCommand, + RemovePusherCommand, + SyncCommand, + UserIpCommand, + UserSyncCommand, +) +from synapse.replication.tcp.streams import STREAMS_MAP, Stream +from synapse.util.async_helpers import Linearizer + +logger = logging.getLogger(__name__) + + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] +) + + +class ReplicationCommandHandler: + """Handles incoming commands from replication as well as sending commands + back out to connections. + """ + + def __init__(self, hs): + self._replication_data_handler = hs.get_replication_data_handler() + self._presence_handler = hs.get_presence_handler() + + # Set of streams that we've caught up with. + self._streams_connected = set() # type: Set[str] + + self._streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + + self._position_linearizer = Linearizer("replication_position") + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. + self._pending_batches = {} # type: Dict[str, List[Any]] + + # The factory used to create connections. + self._factory = None # type: Optional[ReplicationClientFactory] + + # The current connection. None if we are currently (re)connecting + self._connection = None + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. + """ + client_name = hs.config.worker_name + self._factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) + + async def on_RDATA(self, cmd: RdataCommand): + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception: + logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) + raise + + if cmd.token is None or stream_name not in self._streams_connected: + # I.e. either this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token) or we're currently connecting so we queue up rows. + self._pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + await self.on_rdata(stream_name, cmd.token, rows) + + async def on_rdata(self, stream_name: str, token: int, rows: list): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name: name of the replication stream for this batch of rows + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + """ + logger.debug("Received rdata %s -> %s", stream_name, token) + await self._replication_data_handler.on_rdata(stream_name, token, rows) + + async def on_POSITION(self, cmd: PositionCommand): + stream = self._streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # We protect catching up with a linearizer in case the replication + # connection reconnects under us. + with await self._position_linearizer.queue(cmd.stream_name): + # We're about to go and catch up with the stream, so mark as connecting + # to stop RDATA being handled at the same time by removing stream from + # list of connected streams. We also clear any batched up RDATA from + # before we got the POSITION. + self._streams_connected.discard(cmd.stream_name) + self._pending_batches.clear() + + # Find where we previously streamed up to. + current_token = self._replication_data_handler.get_streams_to_replicate().get( + cmd.stream_name + ) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", + cmd.stream_name, + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + if updates: + await self.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) + + # Handle any RDATA that came in while we were catching up. + rows = self._pending_batches.pop(cmd.stream_name, []) + if rows: + await self._replication_data_handler.on_rdata( + cmd.stream_name, rows[-1].token, rows + ) + + self._streams_connected.add(cmd.stream_name) + + async def on_SYNC(self, cmd: SyncCommand): + pass + + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + """"Called when get a new REMOTE_SERVER_UP command.""" + self._replication_data_handler.on_remote_server_up(cmd.data) + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. + """ + return self._presence_handler.get_currently_syncing_users() + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + self._connection = connection + + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + logger.info("Finished connecting to server") + + # We don't reset the delay any earlier as otherwise if there is a + # problem during start up we'll end up tight looping connecting to the + # server. + if self._factory: + self._factory.resetDelay() + + def send_command(self, cmd: Command): + """Send a command to master (when we get establish a connection if we + don't have one already.) + """ + if self._connection: + self._connection.send_command(cmd) + else: + logger.warning("Dropping command as not connected: %r", cmd.NAME) + + def send_federation_ack(self, token: int): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) + + def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func: Callable, keys: tuple): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func.__name__, keys) + self.send_command(cmd) + + def send_user_ip( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): + """Tell the master that the user made a request. + """ + cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) + self.send_command(cmd) + + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index dae246825f..f2a37f568e 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ -import abc import fcntl import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set +from typing import TYPE_CHECKING, DefaultDict, List from six import iteritems @@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string -MYPY = False -if MYPY: +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.server import HomeServer @@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer.lost_connection(self) -class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): - """ - The interface for the handler that should be passed to - ClientReplicationStreamProtocol - """ - - @abc.abstractmethod - async def on_rdata(self, stream_name, token, rows): - """Called to handle a batch of replication data with a given stream token. - - Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. - """ - raise NotImplementedError() - - @abc.abstractmethod - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - raise NotImplementedError() - - @abc.abstractmethod - def on_sync(self, data): - """Called when get a new SYNC command.""" - raise NotImplementedError() - - @abc.abstractmethod - async def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - raise NotImplementedError() - - @abc.abstractmethod - def get_streams_to_replicate(self): - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users.""" - raise NotImplementedError() - - @abc.abstractmethod - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - raise NotImplementedError() - - @abc.abstractmethod - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - raise NotImplementedError() - - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS @@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): client_name: str, server_name: str, clock: Clock, - handler: AbstractReplicationClientHandler, + command_handler: "ReplicationCommandHandler", ): BaseReplicationStreamProtocol.__init__(self, clock) @@ -558,20 +491,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.client_name = client_name self.server_name = server_name - self.handler = handler - - self.streams = { - stream.NAME: stream(hs) for stream in STREAMS_MAP.values() - } # type: Dict[str, Stream] - - # Set of stream names that have been subscribe to, but haven't yet - # caught up with. This is used to track when the client has been fully - # connected to the remote. - self.streams_connecting = set(STREAMS_MAP) # type: Set[str] - - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. - self.pending_batches = {} # type: Dict[str, List[Any]] + self.handler = command_handler def connectionMade(self): self.send_command(NameCommand(self.client_name)) @@ -589,89 +509,39 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) + self.handler.finished_connecting() - async def on_SERVER(self, cmd): - if cmd.data != self.server_name: - logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) - self.send_error("Wrong remote") - - async def on_RDATA(self, cmd): - stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() - - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception( - "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row - ) - raise - - if cmd.token is None or stream_name in self.streams_connecting: - # I.e. this is part of a batch of updates for this stream. Batch - # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) - rows.append(row) - await self.handler.on_rdata(stream_name, cmd.token, rows) - - async def on_POSITION(self, cmd: PositionCommand): - stream = self.streams.get(cmd.stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) - return - - # Find where we previously streamed up to. - current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name - ) - return - - # Fetch all updates between then and now. - limited = True - while limited: - updates, current_token, limited = await stream.get_updates_since( - current_token, cmd.token - ) - - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return - - if updates: - await self.handler.on_rdata( - cmd.stream_name, - current_token, - [stream.parse_row(update[1]) for update in updates], - ) + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. - # We've now caught up to position sent to us, notify handler. - await self.handler.on_position(cmd.stream_name, cmd.token) + Delegates to `command_handler.on_`, which must return an + awaitable. - self.streams_connecting.discard(cmd.stream_name) - if not self.streams_connecting: - self.handler.finished_connecting() + Args: + cmd: received command + """ + handled = False - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True - # Handle any RDATA that came in while we were catching up. - rows = self.pending_batches.pop(cmd.stream_name, []) - if rows: - await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True - async def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) + if not handled: + logger.warning("Unhandled command: %r", cmd) - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.handler.on_remote_server_up(cmd.data) + async def on_SERVER(self, cmd): + if cmd.data != self.server_name: + logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) + self.send_error("Wrong remote") def replicate(self): """Send the subscription request to the server @@ -768,8 +638,3 @@ tcp_outbound_commands = LaterGauge( for k, count in iteritems(p.outbound_commands_counter) }, ) - -# number of updates received for each RDATA stream -inbound_rdata_count = Counter( - "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] -) diff --git a/synapse/server.py b/synapse/server.py index 9228e1c892..9d273c980c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -87,6 +87,8 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, @@ -206,6 +208,7 @@ class HomeServer(object): "password_policy_handler", "storage", "replication_streamer", + "replication_data_handler", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -468,7 +471,7 @@ class HomeServer(object): return ReadMarkerHandler(self) def build_tcp_replication(self): - raise NotImplementedError() + return ReplicationCommandHandler(self) def build_action_generator(self): return ActionGenerator(self) @@ -562,6 +565,9 @@ class HomeServer(object): def build_replication_streamer(self) -> ReplicationStreamer: return ReplicationStreamer(self) + def build_replication_data_handler(self): + return ReplicationDataHandler(self.get_datastore()) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9d1dfa71e7..9013e9bac9 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -19,6 +19,7 @@ import synapse.handlers.set_password import synapse.http.client import synapse.notifier import synapse.replication.tcp.client +import synapse.replication.tcp.handler import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -106,7 +107,11 @@ class HomeServer(object): pass def get_tcp_replication( self, - ) -> synapse.replication.tcp.client.ReplicationClientHandler: + ) -> synapse.replication.tcp.handler.ReplicationCommandHandler: + pass + def get_replication_data_handler( + self, + ) -> synapse.replication.tcp.client.ReplicationDataHandler: pass def get_federation_registry( self, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 2a1e7c7166..8902a5ab69 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -17,8 +17,9 @@ from mock import Mock, NonCallableMock from synapse.replication.tcp.client import ( ReplicationClientFactory, - ReplicationClientHandler, + ReplicationDataHandler, ) +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.storage.database import make_conn @@ -51,15 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + self.streamer = hs.get_replication_streamer() - handler_factory = Mock() - self.replication_handler = ReplicationClientHandler(self.slaved_store) - self.replication_handler.factory = handler_factory + # We now do some gut wrenching so that we have a client that is based + # off of the slave store rather than the main store. + self.replication_handler = ReplicationCommandHandler(self.hs) + self.replication_handler._replication_data_handler = ReplicationDataHandler( + self.slaved_store + ) client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) + client_factory.handler = self.replication_handler server = server_factory.buildProtocol(None) client = client_factory.buildProtocol(None) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index a755fe2879..32238fe79a 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -15,7 +15,7 @@ from mock import Mock -from synapse.replication.tcp.commands import ReplicateCommand +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -26,15 +26,20 @@ from tests.server import FakeTransport class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def make_homeserver(self, reactor, clock): + self.test_handler = Mock(wraps=TestReplicationDataHandler()) + return self.setup_test_homeserver(replication_data_handler=self.test_handler) + def prepare(self, reactor, clock, hs): # build a replication server - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() self.server = server_factory.buildProtocol(None) - self.test_handler = Mock(wraps=TestReplicationClientHandler()) + repl_handler = ReplicationCommandHandler(hs) + repl_handler.handler = self.test_handler self.client = ClientReplicationStreamProtocol( - hs, "client", "test", clock, self.test_handler, + hs, "client", "test", clock, repl_handler, ) self._client_transport = None @@ -69,13 +74,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self): - """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand()) - -class TestReplicationClientHandler(object): - """Drop-in for ReplicationClientHandler which just collects RDATA rows""" +class TestReplicationDataHandler: + """Drop-in for ReplicationDataHandler which just collects RDATA rows""" def __init__(self): self.streams = set() @@ -88,18 +89,9 @@ class TestReplicationClientHandler(object): positions[stream] = max(token, positions.get(stream, 0)) return positions - def get_currently_syncing_users(self): - return [] - - def update_connection(self, connection): - pass - - def finished_connecting(self): - pass - - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - async def on_rdata(self, stream_name, token, rows): for r in rows: self._received_rdata_rows.append((stream_name, token, r)) + + async def on_position(self, stream_name, token): + pass diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 0ec0825a0e..a0206f7363 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -24,7 +24,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.reconnect() # make the client subscribe to the receipts stream - self.replicate_stream() self.test_handler.streams.add("receipts") # tell the master to send a new receipt -- cgit 1.5.1 From f31e65a749f84f8b3278c91784509d908d4fb342 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Apr 2020 23:06:39 +0100 Subject: bg update to clear out duplicate outbound_device_list_pokes (#7193) We seem to have some duplicates, which could do with being cleared out. --- changelog.d/7193.misc | 1 + synapse/storage/data_stores/main/client_ips.py | 16 ++--- synapse/storage/data_stores/main/devices.py | 73 ++++++++++++++++++- .../delta/58/02remove_dup_outbound_pokes.sql | 22 ++++++ synapse/storage/database.py | 83 +++++++++++++++++++++- tests/storage/test_database.py | 52 ++++++++++++++ 6 files changed, 234 insertions(+), 13 deletions(-) create mode 100644 changelog.d/7193.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql create mode 100644 tests/storage/test_database.py (limited to 'tests') diff --git a/changelog.d/7193.misc b/changelog.d/7193.misc new file mode 100644 index 0000000000..383a738e64 --- /dev/null +++ b/changelog.d/7193.misc @@ -0,0 +1 @@ +Add a background database update job to clear out duplicate `device_lists_outbound_pokes`. diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index e1ccb27142..92bc06919b 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import Database, make_tuple_comparison_clause from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -303,16 +303,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # we'll just end up updating the same device row multiple # times, which is fine. - if self.database_engine.supports_tuple_comparison: - where_clause = "(user_id, device_id) > (?, ?)" - where_args = [last_user_id, last_device_id] - else: - # We explicitly do a `user_id >= ? AND (...)` here to ensure - # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)` - # makes it hard for query optimiser to tell that it can use the - # index on user_id - where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)" - where_args = [last_user_id, last_user_id, last_device_id] + where_clause, where_args = make_tuple_comparison_clause( + self.database_engine, + [("user_id", last_user_id), ("device_id", last_device_id)], + ) sql = """ SELECT diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 4c5bea4a5c..ee3a2ab031 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -32,7 +32,11 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.database import ( + Database, + LoggingTransaction, + make_tuple_comparison_clause, +) from synapse.types import Collection, get_verify_key_from_cross_signing_key from synapse.util.caches.descriptors import ( Cache, @@ -49,6 +53,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( "drop_device_list_streams_non_unique_indexes" ) +BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" + class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): @@ -714,6 +720,11 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): self._drop_device_list_streams_non_unique_indexes, ) + # clear out duplicate device list outbound pokes + self.db.updates.register_background_update_handler( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, + ) + @defer.inlineCallbacks def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): @@ -728,6 +739,66 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ) return 1 + async def _remove_duplicate_outbound_pokes(self, progress, batch_size): + # for some reason, we have accumulated duplicate entries in + # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + # efficient. + # + # For each duplicate, we delete all the existing rows and put one back. + + KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] + last_row = progress.get( + "last_row", + {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, + ) + + def _txn(txn): + clause, args = make_tuple_comparison_clause( + self.db.engine, [(x, last_row[x]) for x in KEY_COLS] + ) + sql = """ + SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts + FROM device_lists_outbound_pokes + WHERE %s + GROUP BY %s + HAVING count(*) > 1 + ORDER BY %s + LIMIT ? + """ % ( + clause, # WHERE + ",".join(KEY_COLS), # GROUP BY + ",".join(KEY_COLS), # ORDER BY + ) + txn.execute(sql, args + [batch_size]) + rows = self.db.cursor_to_dict(txn) + + row = None + for row in rows: + self.db.simple_delete_txn( + txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, + ) + + row["sent"] = False + self.db.simple_insert_txn( + txn, "device_lists_outbound_pokes", row, + ) + + if row: + self.db.updates._background_update_progress_txn( + txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, + ) + + return len(rows) + + rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn) + + if not rows: + await self.db.updates._end_background_update( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES + ) + + return rows + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql new file mode 100644 index 0000000000..fdc39e9ba5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* for some reason, we have accumulated duplicate entries in + * device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + * efficient. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) + VALUES (5800, 'remove_dup_outbound_pokes', '{}'); diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 715c0346dd..a7cd97b0b0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -17,7 +17,17 @@ import logging import time from time import monotonic as monotonic_time -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, +) from six import iteritems, iterkeys, itervalues from six.moves import intern, range @@ -1557,3 +1567,74 @@ def make_in_list_sql_clause( return "%s = ANY(?)" % (column,), [list(iterable)] else: return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) + + +KV = TypeVar("KV") + + +def make_tuple_comparison_clause( + database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]] +) -> Tuple[str, List[KV]]: + """Returns a tuple comparison SQL clause + + Depending what the SQL engine supports, builds a SQL clause that looks like either + "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)". + + Args: + database_engine + keys: A set of (column, value) pairs to be compared. + + Returns: + A tuple of SQL query and the args + """ + if database_engine.supports_tuple_comparison: + return ( + "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)), + [k[1] for k in keys], + ) + + # we want to build a clause + # (a > ?) OR + # (a == ? AND b > ?) OR + # (a == ? AND b == ? AND c > ?) + # ... + # (a == ? AND b == ? AND ... AND z > ?) + # + # or, equivalently: + # + # (a > ? OR (a == ? AND + # (b > ? OR (b == ? AND + # ... + # (y > ? OR (y == ? AND + # z > ? + # )) + # ... + # )) + # )) + # + # which itself is equivalent to (and apparently easier for the query optimiser): + # + # (a >= ? AND (a > ? OR + # (b >= ? AND (b > ? OR + # ... + # (y >= ? AND (y > ? OR + # z > ? + # )) + # ... + # )) + # )) + # + # + + clause = "" + args = [] # type: List[KV] + for k, v in keys[:-1]: + clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k) + args.extend([v, v]) + + (k, v) = keys[-1] + clause += "%s > ?" % (k,) + args.append(v) + + clause += "))" * (len(keys) - 1) + return clause, args diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py new file mode 100644 index 0000000000..5a77c84962 --- /dev/null +++ b/tests/storage/test_database.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.database import make_tuple_comparison_clause +from synapse.storage.engines import BaseDatabaseEngine + +from tests import unittest + + +def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: + # returns a DatabaseEngine, circumventing the abc mechanism + # any kwargs are set as attributes on the class before instantiating it + t = type( + "TestBaseDatabaseEngine", + (BaseDatabaseEngine,), + dict(BaseDatabaseEngine.__dict__), + ) + # defeat the abc mechanism + t.__abstractmethods__ = set() + for k, v in kwargs.items(): + setattr(t, k, v) + return t(None, None) + + +class TupleComparisonClauseTestCase(unittest.TestCase): + def test_native_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=True) + clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)]) + self.assertEqual(clause, "(a,b) > (?,?)") + self.assertEqual(args, [1, 2]) + + def test_emulated_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=False) + clause, args = make_tuple_comparison_clause( + db_engine, [("a", 1), ("b", 2), ("c", 3)] + ) + self.assertEqual( + clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))" + ) + self.assertEqual(args, [1, 1, 2, 2, 3]) -- cgit 1.5.1 From b85d7652ff084fee997e0bb44ecd46c2789abbdd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Apr 2020 13:28:13 -0400 Subject: Do not allow a deactivated user to login via SSO. (#7240) --- changelog.d/7240.bugfix | 1 + synapse/config/sso.py | 7 ++++ synapse/handlers/auth.py | 34 +++++++++++++++--- synapse/handlers/cas_handler.py | 2 +- synapse/handlers/saml_handler.py | 2 +- synapse/module_api/__init__.py | 22 +++++++++++- synapse/res/templates/sso_account_deactivated.html | 10 ++++++ tests/rest/client/v1/test_login.py | 42 ++++++++++++++++++++-- 8 files changed, 110 insertions(+), 10 deletions(-) create mode 100644 changelog.d/7240.bugfix create mode 100644 synapse/res/templates/sso_account_deactivated.html (limited to 'tests') diff --git a/changelog.d/7240.bugfix b/changelog.d/7240.bugfix new file mode 100644 index 0000000000..83b18d3e11 --- /dev/null +++ b/changelog.d/7240.bugfix @@ -0,0 +1 @@ +Do not allow a deactivated user to login via SSO. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index ec3dca9efc..686678a3b7 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Any, Dict import pkg_resources @@ -36,6 +37,12 @@ class SSOConfig(Config): template_dir = pkg_resources.resource_filename("synapse", "res/templates",) self.sso_redirect_confirm_template_dir = template_dir + self.sso_account_deactivated_template = self.read_file( + os.path.join( + self.sso_redirect_confirm_template_dir, "sso_account_deactivated.html" + ), + "sso_account_deactivated_template", + ) self.sso_client_whitelist = sso_config.get("client_whitelist") or [] diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 892adb00b9..fbfbd44a2e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -161,6 +161,9 @@ class AuthHandler(BaseHandler): self._sso_auth_confirm_template = load_jinja2_templates( hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"], )[0] + self._sso_account_deactivated_template = ( + hs.config.sso_account_deactivated_template + ) self._server_name = hs.config.server_name @@ -644,9 +647,6 @@ class AuthHandler(BaseHandler): Returns: defer.Deferred: (unicode) canonical_user_id, or None if zero or multiple matches - - Raises: - UserDeactivatedError if a user is found but is deactivated. """ res = yield self._find_user_id_and_pwd_hash(user_id) if res is not None: @@ -1099,7 +1099,7 @@ class AuthHandler(BaseHandler): request.write(html_bytes) finish_request(request) - def complete_sso_login( + async def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, @@ -1113,6 +1113,32 @@ class AuthHandler(BaseHandler): client_redirect_url: The URL to which to redirect the user at the end of the process. """ + # If the account has been deactivated, do not proceed with the login + # flow. + deactivated = await self.store.get_user_deactivated_status(registered_user_id) + if deactivated: + html = self._sso_account_deactivated_template.encode("utf-8") + + request.setResponseCode(403) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html),)) + request.write(html) + finish_request(request) + return + + self._complete_sso_login(registered_user_id, request, client_redirect_url) + + def _complete_sso_login( + self, + registered_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + ): + """ + The synchronous portion of complete_sso_login. + + This exists purely for backwards compatibility of synapse.module_api.ModuleApi. + """ # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( registered_user_id diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index d977badf35..5cb3f9d133 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -216,6 +216,6 @@ class CasHandler: localpart=localpart, default_display_name=user_display_name ) - self._auth_handler.complete_sso_login( + await self._auth_handler.complete_sso_login( registered_user_id, request, client_redirect_url ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 4741c82f61..7c9454b504 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -154,7 +154,7 @@ class SamlHandler: ) else: - self._auth_handler.complete_sso_login(user_id, request, relay_state) + await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user( self, resp_bytes: str, client_redirect_url: str diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index c7fffd72f2..afc3598e11 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -220,6 +220,26 @@ class ModuleApi(object): want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. + This is deprecated in favor of complete_sso_login_async. + + Args: + registered_user_id: The MXID that has been registered as a previous step of + of this SSO login. + request: The request to respond to. + client_redirect_url: The URL to which to offer to redirect the user (or to + redirect them directly if whitelisted). + """ + self._auth_handler._complete_sso_login( + registered_user_id, request, client_redirect_url, + ) + + async def complete_sso_login_async( + self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str + ): + """Complete a SSO login by redirecting the user to a page to confirm whether they + want their access token sent to `client_redirect_url`, or redirect them to that + URL with a token directly if the URL matches with one of the whitelisted clients. + Args: registered_user_id: The MXID that has been registered as a previous step of of this SSO login. @@ -227,6 +247,6 @@ class ModuleApi(object): client_redirect_url: The URL to which to offer to redirect the user (or to redirect them directly if whitelisted). """ - self._auth_handler.complete_sso_login( + await self._auth_handler.complete_sso_login( registered_user_id, request, client_redirect_url, ) diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html new file mode 100644 index 0000000000..4eb8db9fb4 --- /dev/null +++ b/synapse/res/templates/sso_account_deactivated.html @@ -0,0 +1,10 @@ + + + + + SSO account deactivated + + +

This account has been deactivated.

+ + diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index aed8853d6e..1856c7ffd5 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -257,7 +257,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.code, 200, channel.result) -class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): +class CASTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, @@ -274,6 +274,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): "service_url": "https://matrix.goodserver.com:8448", } + cas_user_id = "username" + self.user_id = "@%s:test" % cas_user_id + async def get_raw(uri, args): """Return an example response payload from a call to the `/proxyValidate` endpoint of a CAS server, copied from @@ -282,10 +285,11 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): This needs to be returned by an async function (as opposed to set as the mock's return value) because the corresponding Synapse code awaits on it. """ - return """ + return ( + """ - username + %s PGTIOU-84678-8a9d... https://proxy2/pgtUrl @@ -294,6 +298,8 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): """ + % cas_user_id + ) mocked_http_client = Mock(spec=["get_raw"]) mocked_http_client.get_raw.side_effect = get_raw @@ -304,6 +310,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): return self.hs + def prepare(self, reactor, clock, hs): + self.deactivate_account_handler = hs.get_deactivate_account_handler() + def test_cas_redirect_confirm(self): """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. @@ -370,3 +379,30 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 302) location_headers = channel.headers.getRawHeaders("Location") self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) + + @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) + def test_deactivated_user(self): + """Logging in as a deactivated account should error.""" + redirect_url = "https://legit-site.com/" + + # First login (to create the user). + self._test_redirect(redirect_url) + + # Deactivate the account. + self.get_success( + self.deactivate_account_handler.deactivate_account(self.user_id, False) + ) + + # Request the CAS ticket. + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Because the user is deactivated they are served an error template. + self.assertEqual(channel.code, 403) + self.assertIn(b"SSO account deactivated", channel.result["body"]) -- cgit 1.5.1 From ac978ab3da48ca1d0c6f40d9857d769e72e960dd Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 9 Apr 2020 18:45:38 +0100 Subject: Default PL100 to enable encryption in a room (#7230) --- changelog.d/7230.feature | 1 + synapse/handlers/room.py | 1 + tests/rest/client/test_power_levels.py | 205 +++++++++++++++++++++++++++++++++ tests/rest/client/v1/utils.py | 96 ++++++++++++++- 4 files changed, 299 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7230.feature create mode 100644 tests/rest/client/test_power_levels.py (limited to 'tests') diff --git a/changelog.d/7230.feature b/changelog.d/7230.feature new file mode 100644 index 0000000000..aab777648f --- /dev/null +++ b/changelog.d/7230.feature @@ -0,0 +1 @@ +Require admin privileges to enable room encryption by default. This does not affect existing rooms. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f580ab2e9f..df3e0cff67 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -806,6 +806,7 @@ class RoomCreationHandler(BaseHandler): EventTypes.RoomAvatar: 50, EventTypes.Tombstone: 100, EventTypes.ServerACL: 100, + EventTypes.RoomEncryption: 100, }, "events_default": 0, "state_default": 50, diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py new file mode 100644 index 0000000000..913ea3c98e --- /dev/null +++ b/tests/rest/client/test_power_levels.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync + +from tests.unittest import HomeserverTestCase + + +class PowerLevelsTestCase(HomeserverTestCase): + """Tests that power levels are enforced in various situations""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + # register a room admin, moderator and regular user + self.admin_user_id = self.register_user("admin", "pass") + self.admin_access_token = self.login("admin", "pass") + self.mod_user_id = self.register_user("mod", "pass") + self.mod_access_token = self.login("mod", "pass") + self.user_user_id = self.register_user("user", "pass") + self.user_access_token = self.login("user", "pass") + + # Create a room + self.room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # Invite the other users + self.helper.invite( + room=self.room_id, + src=self.admin_user_id, + tok=self.admin_access_token, + targ=self.mod_user_id, + ) + self.helper.invite( + room=self.room_id, + src=self.admin_user_id, + tok=self.admin_access_token, + targ=self.user_user_id, + ) + + # Make the other users join the room + self.helper.join( + room=self.room_id, user=self.mod_user_id, tok=self.mod_access_token + ) + self.helper.join( + room=self.room_id, user=self.user_user_id, tok=self.user_access_token + ) + + # Mod the mod + room_power_levels = self.helper.get_state( + self.room_id, "m.room.power_levels", tok=self.admin_access_token, + ) + + # Update existing power levels with mod at PL50 + room_power_levels["users"].update({self.mod_user_id: 50}) + + self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + ) + + def test_non_admins_cannot_enable_room_encryption(self): + # have the mod try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_non_admins_cannot_send_server_acl(self): + # have the mod try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_non_admins_cannot_tombstone_room(self): + # Create another room that will serve as our "upgraded room" + self.upgraded_room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # have the mod try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_admins_can_enable_room_encryption(self): + # have the admin try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.admin_access_token, + expect_code=200, # expect success + ) + + def test_admins_can_send_server_acl(self): + # have the admin try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.admin_access_token, + expect_code=200, # expect success + ) + + def test_admins_can_tombstone_room(self): + # Create another room that will serve as our "upgraded room" + self.upgraded_room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # have the admin try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.admin_access_token, + expect_code=200, # expect success + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 873d5ef99c..371637618d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -18,6 +18,7 @@ import json import time +from typing import Any, Dict, Optional import attr @@ -142,7 +143,34 @@ class RestHelper(object): return channel.json_body - def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""): + def _read_write_state( + self, + room_id: str, + event_type: str, + body: Optional[Dict[str, Any]], + tok: str, + expect_code: int = 200, + state_key: str = "", + method: str = "GET", + ) -> Dict: + """Read or write some state from a given room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + If None, the request to the server will have an empty body + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + method: "GET" or "PUT" for reading or writing state, respectively + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( room_id, event_type, @@ -151,9 +179,13 @@ class RestHelper(object): if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8") - ) + # Set request body if provided + content = b"" + if body is not None: + content = json.dumps(body).encode("utf8") + + request, channel = make_request(self.hs.get_reactor(), method, path, content) + render(request, self.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( @@ -163,6 +195,62 @@ class RestHelper(object): return channel.json_body + def get_state( + self, + room_id: str, + event_type: str, + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Gets some state from a room + + Args: + room_id: + event_type: The type of state event + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, None, tok, expect_code, state_key, method="GET" + ) + + def send_state( + self, + room_id: str, + event_type: str, + body: Dict[str, Any], + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Set some state in a room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, body, tok, expect_code, state_key, method="PUT" + ) + def upload_media( self, resource: Resource, -- cgit 1.5.1 From f1097e772042bf9bdfc3d5f71d1c37a4e13084da Mon Sep 17 00:00:00 2001 From: Zay11Zay Date: Tue, 14 Apr 2020 15:37:28 -0400 Subject: Fix the parameters of a test fixture (#7243) --- changelog.d/7243.misc | 1 + tests/rest/client/v1/test_events.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7243.misc (limited to 'tests') diff --git a/changelog.d/7243.misc b/changelog.d/7243.misc new file mode 100644 index 0000000000..a39c257a54 --- /dev/null +++ b/changelog.d/7243.misc @@ -0,0 +1 @@ +Correct the parameters of a test fixture. Contributed by Isaiah Singletary. diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index ffb2de1505..b54b06482b 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -50,7 +50,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): return hs - def prepare(self, hs, reactor, clock): + def prepare(self, reactor, clock, hs): # register an account self.user_id = self.register_user("sid1", "pass") -- cgit 1.5.1 From a48138784ea20dd8d8a68ce5c3563da6f3fbde43 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 15 Apr 2020 13:35:29 +0100 Subject: Allow specifying the value of Accept-Language header for URL previews (#7265) --- changelog.d/7265.feature | 1 + docs/sample_config.yaml | 25 ++++++++++++ synapse/config/repository.py | 29 ++++++++++++++ synapse/rest/media/v1/preview_url_resource.py | 8 +++- tests/rest/media/v1/test_url_preview.py | 55 +++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7265.feature (limited to 'tests') diff --git a/changelog.d/7265.feature b/changelog.d/7265.feature new file mode 100644 index 0000000000..345b63e0b7 --- /dev/null +++ b/changelog.d/7265.feature @@ -0,0 +1 @@ +Add a config option for specifying the value of the Accept-Language HTTP header when generating URL previews. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 3417813750..81dccbd997 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -859,6 +859,31 @@ media_store_path: "DATADIR/media_store" # #max_spider_size: 10M +# A list of values for the Accept-Language HTTP header used when +# downloading webpages during URL preview generation. This allows +# Synapse to specify the preferred languages that URL previews should +# be in when communicating with remote servers. +# +# Each value is a IETF language tag; a 2-3 letter identifier for a +# language, optionally followed by subtags separated by '-', specifying +# a country or region variant. +# +# Multiple values can be provided, and a weight can be added to each by +# using quality value syntax (;q=). '*' translates to any language. +# +# Defaults to "en". +# +# Example: +# +# url_preview_accept_language: +# - en-UK +# - en-US;q=0.9 +# - fr;q=0.8 +# - *;q=0.7 +# +url_preview_accept_language: +# - en + ## Captcha ## # See docs/CAPTCHA_SETUP for full details of configuring this. diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 7d2dd27fd0..7193ea1114 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -192,6 +192,10 @@ class ContentRepositoryConfig(Config): self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ()) + self.url_preview_accept_language = config.get( + "url_preview_accept_language" + ) or ["en"] + def generate_config_section(self, data_dir_path, **kwargs): media_store = os.path.join(data_dir_path, "media_store") uploads_path = os.path.join(data_dir_path, "uploads") @@ -329,6 +333,31 @@ class ContentRepositoryConfig(Config): # The largest allowed URL preview spidering size in bytes # #max_spider_size: 10M + + # A list of values for the Accept-Language HTTP header used when + # downloading webpages during URL preview generation. This allows + # Synapse to specify the preferred languages that URL previews should + # be in when communicating with remote servers. + # + # Each value is a IETF language tag; a 2-3 letter identifier for a + # language, optionally followed by subtags separated by '-', specifying + # a country or region variant. + # + # Multiple values can be provided, and a weight can be added to each by + # using quality value syntax (;q=). '*' translates to any language. + # + # Defaults to "en". + # + # Example: + # + # url_preview_accept_language: + # - en-UK + # - en-US;q=0.9 + # - fr;q=0.8 + # - *;q=0.7 + # + url_preview_accept_language: + # - en """ % locals() ) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index c46676f8fc..f68e18ea8a 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -86,6 +86,7 @@ class PreviewUrlResource(DirectServeResource): self.media_storage = media_storage self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist + self.url_preview_accept_language = hs.config.url_preview_accept_language # memory cache mapping urls to an ObservableDeferred returning # JSON-encoded OG metadata @@ -315,9 +316,12 @@ class PreviewUrlResource(DirectServeResource): with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: - logger.debug("Trying to get url '%s'", url) + logger.debug("Trying to get preview for url '%s'", url) length, headers, uri, code = await self.client.get_file( - url, output_stream=f, max_size=self.max_spider_size + url, + output_stream=f, + max_size=self.max_spider_size, + headers={"Accept-Language": self.url_preview_accept_language}, ) except SynapseError: # Pass SynapseErrors through directly, so that the servlet diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 852b8ab11c..2826211f32 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -74,6 +74,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) config["url_preview_ip_range_whitelist"] = ("1.1.1.1",) config["url_preview_url_blacklist"] = [] + config["url_preview_accept_language"] = [ + "en-UK", + "en-US;q=0.9", + "fr;q=0.8", + "*;q=0.7", + ] self.storage_path = self.mktemp() self.media_store_path = self.mktemp() @@ -507,3 +513,52 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {}) + + def test_accept_language_config_option(self): + """ + Accept-Language header is sent to the remote server + """ + self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + + # Build and make a request to the server + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # Extract Synapse's tcp client + client = self.reactor.tcpClients[0][2].buildProtocol(None) + + # Build a fake remote server to reply with + server = AccumulatingProtocol() + + # Connect the two together + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + + # Tell Synapse that it has received some data from the remote server + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) + + # Move the reactor along until we get a response on our original channel + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) + + # Check that the server received the Accept-Language header as part + # of the request from Synapse + self.assertIn( + ( + b"Accept-Language: en-UK\r\n" + b"Accept-Language: en-US;q=0.9\r\n" + b"Accept-Language: fr;q=0.8\r\n" + b"Accept-Language: *;q=0.7" + ), + server.data, + ) -- cgit 1.5.1 From eed7c5b89eee6951ac17861b1695817470bace36 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Apr 2020 12:40:18 -0400 Subject: Convert auth handler to async/await (#7261) --- changelog.d/7261.misc | 1 + synapse/handlers/auth.py | 173 ++++++++++++++++++--------------------- synapse/handlers/device.py | 12 ++- synapse/handlers/register.py | 28 +++++-- synapse/handlers/set_password.py | 13 ++- synapse/module_api/__init__.py | 6 +- tests/api/test_auth.py | 64 +++++++++------ tests/handlers/test_auth.py | 80 +++++++++++------- tests/handlers/test_register.py | 4 +- tests/utils.py | 13 ++- 10 files changed, 224 insertions(+), 170 deletions(-) create mode 100644 changelog.d/7261.misc (limited to 'tests') diff --git a/changelog.d/7261.misc b/changelog.d/7261.misc new file mode 100644 index 0000000000..88165f0105 --- /dev/null +++ b/changelog.d/7261.misc @@ -0,0 +1 @@ +Convert auth handler to async/await. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fbfbd44a2e..0aae929ecc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,14 +18,12 @@ import logging import time import unicodedata import urllib.parse -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import attr import bcrypt # type: ignore[import] import pymacaroons -from twisted.internet import defer - import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import ( @@ -170,15 +168,14 @@ class AuthHandler(BaseHandler): # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) - @defer.inlineCallbacks - def validate_user_via_ui_auth( + async def validate_user_via_ui_auth( self, requester: Requester, request: SynapseRequest, request_body: Dict[str, Any], clientip: str, description: str, - ): + ) -> dict: """ Checks that the user is who they claim to be, via a UI auth. @@ -199,7 +196,7 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - defer.Deferred[dict]: the parameters for this request (which may + The parameters for this request (which may have been given only in a previous call). Raises: @@ -229,7 +226,7 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = yield self.check_auth( + result, params, _ = await self.check_auth( flows, request, request_body, clientip, description ) except LoginError: @@ -268,15 +265,14 @@ class AuthHandler(BaseHandler): """ return self.checkers.keys() - @defer.inlineCallbacks - def check_auth( + async def check_auth( self, flows: List[List[str]], request: SynapseRequest, clientdict: Dict[str, Any], clientip: str, description: str, - ): + ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. @@ -306,8 +302,7 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - defer.Deferred[dict, dict, str]: a deferred tuple of - (creds, params, session_id). + A tuple of (creds, params, session_id). 'creds' contains the authenticated credentials of each stage. @@ -380,7 +375,7 @@ class AuthHandler(BaseHandler): if "type" in authdict: login_type = authdict["type"] # type: str try: - result = yield self._check_auth_dict(authdict, clientip) + result = await self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) @@ -419,8 +414,9 @@ class AuthHandler(BaseHandler): ret.update(errordict) raise InteractiveAuthIncompleteError(ret) - @defer.inlineCallbacks - def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str): + async def add_oob_auth( + self, stagetype: str, authdict: Dict[str, Any], clientip: str + ) -> bool: """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. @@ -435,7 +431,7 @@ class AuthHandler(BaseHandler): sess["creds"] = {} creds = sess["creds"] - result = yield self.checkers[stagetype].check_auth(authdict, clientip) + result = await self.checkers[stagetype].check_auth(authdict, clientip) if result: creds[stagetype] = result self._save_session(sess) @@ -489,8 +485,9 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault("serverdict", {}).get(key, default) - @defer.inlineCallbacks - def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str): + async def _check_auth_dict( + self, authdict: Dict[str, Any], clientip: str + ) -> Union[Dict[str, Any], str]: """Attempt to validate the auth dict provided by a client Args: @@ -498,7 +495,7 @@ class AuthHandler(BaseHandler): clientip: IP address of the client Returns: - Deferred: result of the stage verification. + Result of the stage verification. Raises: StoreError if there was a problem accessing the database @@ -508,7 +505,7 @@ class AuthHandler(BaseHandler): login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: - res = yield checker.check_auth(authdict, clientip=clientip) + res = await checker.check_auth(authdict, clientip=clientip) return res # build a v1-login-style dict out of the authdict and fall back to the @@ -518,7 +515,7 @@ class AuthHandler(BaseHandler): if user_id is None: raise SynapseError(400, "", Codes.MISSING_PARAM) - (canonical_id, callback) = yield self.validate_login(user_id, authdict) + (canonical_id, callback) = await self.validate_login(user_id, authdict) return canonical_id def _get_params_recaptcha(self) -> dict: @@ -584,8 +581,7 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def get_access_token_for_user_id( + async def get_access_token_for_user_id( self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] ): """ @@ -615,10 +611,10 @@ class AuthHandler(BaseHandler): ) logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id, access_token, device_id, valid_until_ms ) @@ -628,15 +624,14 @@ class AuthHandler(BaseHandler): # device, so we double-check it here. if device_id is not None: try: - yield self.store.get_device(user_id, device_id) + await self.store.get_device(user_id, device_id) except StoreError: - yield self.store.delete_access_token(access_token) + await self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") return access_token - @defer.inlineCallbacks - def check_user_exists(self, user_id: str): + async def check_user_exists(self, user_id: str) -> Optional[str]: """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. @@ -645,25 +640,25 @@ class AuthHandler(BaseHandler): user_id: complete @user:id Returns: - defer.Deferred: (unicode) canonical_user_id, or None if zero or - multiple matches + The canonical_user_id, or None if zero or multiple matches """ - res = yield self._find_user_id_and_pwd_hash(user_id) + res = await self._find_user_id_and_pwd_hash(user_id) if res is not None: return res[0] return None - @defer.inlineCallbacks - def _find_user_id_and_pwd_hash(self, user_id: str): + async def _find_user_id_and_pwd_hash( + self, user_id: str + ) -> Optional[Tuple[str, str]]: """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. Returns: - tuple: A 2-tuple of `(canonical_user_id, password_hash)` - None: if there is not exactly one match + A 2-tuple of `(canonical_user_id, password_hash)` or `None` + if there is not exactly one match """ - user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + user_infos = await self.store.get_users_by_id_case_insensitive(user_id) result = None if not user_infos: @@ -696,8 +691,9 @@ class AuthHandler(BaseHandler): """ return self._supported_login_types - @defer.inlineCallbacks - def validate_login(self, username: str, login_submission: Dict[str, Any]): + async def validate_login( + self, username: str, login_submission: Dict[str, Any] + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate @@ -708,7 +704,7 @@ class AuthHandler(BaseHandler): login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: - Deferred[str, func]: canonical user id, and optional callback + A tuple of the canonical user id, and optional callback to be called once the access token and device id are issued Raises: StoreError if there was a problem accessing the database @@ -737,7 +733,7 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password(qualified_user_id, password) + is_valid = await provider.check_password(qualified_user_id, password) if is_valid: return qualified_user_id, None @@ -769,7 +765,7 @@ class AuthHandler(BaseHandler): % (login_type, missing_fields), ) - result = yield provider.check_auth(username, login_type, login_dict) + result = await provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -778,8 +774,8 @@ class AuthHandler(BaseHandler): if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True - canonical_user_id = yield self._check_local_password( - qualified_user_id, password + canonical_user_id = await self._check_local_password( + qualified_user_id, password # type: ignore ) if canonical_user_id: @@ -792,8 +788,9 @@ class AuthHandler(BaseHandler): # login, it turns all LoginErrors into a 401 anyway. raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) - @defer.inlineCallbacks - def check_password_provider_3pid(self, medium: str, address: str, password: str): + async def check_password_provider_3pid( + self, medium: str, address: str, password: str + ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: """Check if a password provider is able to validate a thirdparty login Args: @@ -802,9 +799,8 @@ class AuthHandler(BaseHandler): password: The password of the user. Returns: - Deferred[(str|None, func|None)]: A tuple of `(user_id, - callback)`. If authentication is successful, `user_id` is a `str` - containing the authenticated, canonical user ID. `callback` is + A tuple of `(user_id, callback)`. If authentication is successful, + `user_id`is the authenticated, canonical user ID. `callback` is then either a function to be later run after the server has completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. @@ -816,7 +812,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth(medium, address, password) + result = await provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -826,8 +822,7 @@ class AuthHandler(BaseHandler): return None, None - @defer.inlineCallbacks - def _check_local_password(self, user_id: str, password: str): + async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are @@ -837,28 +832,26 @@ class AuthHandler(BaseHandler): user_id: complete @user:id password: the provided password Returns: - Deferred[unicode] the canonical_user_id, or Deferred[None] if - unknown user/bad password + The canonical_user_id, or None if unknown user/bad password """ - lookupres = yield self._find_user_id_and_pwd_hash(user_id) + lookupres = await self._find_user_id_and_pwd_hash(user_id) if not lookupres: return None (user_id, password_hash) = lookupres # If the password hash is None, the account has likely been deactivated if not password_hash: - deactivated = yield self.store.get_user_deactivated_status(user_id) + deactivated = await self.store.get_user_deactivated_status(user_id) if deactivated: raise UserDeactivatedError("This account has been deactivated") - result = yield self.validate_hash(password, password_hash) + result = await self.validate_hash(password, password_hash) if not result: logger.warning("Failed password login for user %s", user_id) return None return user_id - @defer.inlineCallbacks - def validate_short_term_login_token_and_get_user_id(self, login_token: str): + async def validate_short_term_login_token_and_get_user_id(self, login_token: str): auth_api = self.hs.get_auth() user_id = None try: @@ -868,26 +861,23 @@ class AuthHandler(BaseHandler): except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) return user_id - @defer.inlineCallbacks - def delete_access_token(self, access_token: str): + async def delete_access_token(self, access_token: str): """Invalidate a single access token Args: access_token: access token to be deleted - Returns: - Deferred """ - user_info = yield self.auth.get_user_by_access_token(access_token) - yield self.store.delete_access_token(access_token) + user_info = await self.auth.get_user_by_access_token(access_token) + await self.store.delete_access_token(access_token) # see if any of our auth providers want to know about this for provider in self.password_providers: if hasattr(provider, "on_logged_out"): - yield provider.on_logged_out( + await provider.on_logged_out( user_id=str(user_info["user"]), device_id=user_info["device_id"], access_token=access_token, @@ -895,12 +885,11 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( str(user_info["user"]), (user_info["token_id"],) ) - @defer.inlineCallbacks - def delete_access_tokens_for_user( + async def delete_access_tokens_for_user( self, user_id: str, except_token_id: Optional[str] = None, @@ -914,10 +903,8 @@ class AuthHandler(BaseHandler): device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - Returns: - Deferred """ - tokens_and_devices = yield self.store.user_delete_access_tokens( + tokens_and_devices = await self.store.user_delete_access_tokens( user_id, except_token_id=except_token_id, device_id=device_id ) @@ -925,17 +912,18 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: - yield provider.on_logged_out( + await provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( user_id, (token_id for _, token_id, _ in tokens_and_devices) ) - @defer.inlineCallbacks - def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int): + async def add_threepid( + self, user_id: str, medium: str, address: str, validated_at: int + ): # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -956,14 +944,13 @@ class AuthHandler(BaseHandler): if medium == "email": address = address.lower() - yield self.store.user_add_threepid( + await self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) - @defer.inlineCallbacks - def delete_threepid( + async def delete_threepid( self, user_id: str, medium: str, address: str, id_server: Optional[str] = None - ): + ) -> bool: """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. @@ -976,7 +963,7 @@ class AuthHandler(BaseHandler): identity server specified when binding (if known). Returns: - Deferred[bool]: Returns True if successfully unbound the 3pid on + Returns True if successfully unbound the 3pid on the identity server, False if identity server doesn't support the unbind API. """ @@ -986,11 +973,11 @@ class AuthHandler(BaseHandler): address = address.lower() identity_handler = self.hs.get_handlers().identity_handler - result = yield identity_handler.try_unbind_threepid( + result = await identity_handler.try_unbind_threepid( user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid(user_id, medium, address) + await self.store.user_delete_threepid(user_id, medium, address) return result def _save_session(self, session: Dict[str, Any]) -> None: @@ -1000,14 +987,14 @@ class AuthHandler(BaseHandler): session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - def hash(self, password: str): + async def hash(self, password: str) -> str: """Computes a secure hash of password. Args: password: Password to hash. Returns: - Deferred(unicode): Hashed password. + Hashed password. """ def _do_hash(): @@ -1019,9 +1006,11 @@ class AuthHandler(BaseHandler): bcrypt.gensalt(self.bcrypt_rounds), ).decode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_hash) - def validate_hash(self, password: str, stored_hash: bytes): + async def validate_hash( + self, password: str, stored_hash: Union[bytes, str] + ) -> bool: """Validates that self.hash(password) == stored_hash. Args: @@ -1029,7 +1018,7 @@ class AuthHandler(BaseHandler): stored_hash: Expected hash value. Returns: - Deferred(bool): Whether self.hash(password) == stored_hash. + Whether self.hash(password) == stored_hash. """ def _do_validate_hash(): @@ -1045,9 +1034,9 @@ class AuthHandler(BaseHandler): if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_validate_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: - return defer.succeed(False) + return False def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: """ diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 993499f446..9bd941b5a0 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -338,8 +338,10 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) @@ -391,8 +393,10 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7ffc194f0c..3a65b46ecd 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -166,7 +166,9 @@ class RegistrationHandler(BaseHandler): yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: - password_hash = yield self._auth_handler.hash(password) + password_hash = yield defer.ensureDeferred( + self._auth_handler.hash(password) + ) if localpart is not None: yield self.check_username(localpart, guest_access_token=guest_access_token) @@ -540,8 +542,10 @@ class RegistrationHandler(BaseHandler): user_id, ["guest = true"] ) else: - access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms + access_token = yield defer.ensureDeferred( + self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, valid_until_ms=valid_until_ms + ) ) return (device_id, access_token) @@ -617,8 +621,13 @@ class RegistrationHandler(BaseHandler): logger.info("Can't add incomplete 3pid") return - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) # And we add an email pusher for them by default, but only @@ -670,6 +679,11 @@ class RegistrationHandler(BaseHandler): return None raise - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 7d1263caf2..63d8f9aa0d 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,8 +15,6 @@ import logging from typing import Optional -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester @@ -34,8 +32,7 @@ class SetPasswordHandler(BaseHandler): self._device_handler = hs.get_device_handler() self._password_policy_handler = hs.get_password_policy_handler() - @defer.inlineCallbacks - def set_password( + async def set_password( self, user_id: str, new_password: str, @@ -46,10 +43,10 @@ class SetPasswordHandler(BaseHandler): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) self._password_policy_handler.validate_password(new_password) - password_hash = yield self._auth_handler.hash(new_password) + password_hash = await self._auth_handler.hash(new_password) try: - yield self.store.user_set_password_hash(user_id, password_hash) + await self.store.user_set_password_hash(user_id, password_hash) except StoreError as e: if e.code == 404: raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) @@ -61,12 +58,12 @@ class SetPasswordHandler(BaseHandler): except_access_token_id = requester.access_token_id if requester else None # First delete all of their other devices. - yield self._device_handler.delete_all_devices_for_user( + await self._device_handler.delete_all_devices_for_user( user_id, except_device_id=except_device_id ) # and now delete any access tokens which weren't associated with # devices (or were associated with this device). - yield self._auth_handler.delete_access_tokens_for_user( + await self._auth_handler.delete_access_tokens_for_user( user_id, except_token_id=except_access_token_id ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index afc3598e11..d678c0eb9b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,7 +86,7 @@ class ModuleApi(object): Deferred[str|None]: Canonical (case-corrected) user_id, or None if the user is not registered. """ - return self._auth_handler.check_user_exists(user_id) + return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks def register(self, localpart, displayname=None, emails=[]): @@ -196,7 +196,9 @@ class ModuleApi(object): yield self._hs.get_device_handler().delete_device(user_id, device_id) else: # no associated device. Just delete the access token. - yield self._auth_handler.delete_access_token(access_token) + yield defer.ensureDeferred( + self._auth_handler.delete_access_token(access_token) + ) def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 6121efcfa9..cc0b10e7f6 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -68,7 +68,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -105,7 +105,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) @defer.inlineCallbacks @@ -125,7 +125,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "192.168.10.10" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): @@ -188,7 +188,7 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -225,7 +225,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(macaroon.serialize()) + ) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) @@ -250,7 +252,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield self.auth.get_user_by_access_token(serialized) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(serialized) + ) user = user_info["user"] is_guest = user_info["is_guest"] self.assertEqual(UserID.from_string(user_id), user) @@ -260,10 +264,13 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cannot_use_regular_token_as_guest(self): USER_ID = "@percy:matrix.org" - self.store.add_access_token_to_user = Mock() + self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None)) + self.store.get_device = Mock(return_value=defer.succeed(None)) - token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id( - USER_ID, "DEVICE", valid_until_ms=None + token = yield defer.ensureDeferred( + self.hs.handlers.auth_handler.get_access_token_for_user_id( + USER_ID, "DEVICE", valid_until_ms=None + ) ) self.store.add_access_token_to_user.assert_called_with( USER_ID, token, "DEVICE", None @@ -286,7 +293,9 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertFalse(requester.is_guest) @@ -301,7 +310,9 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(InvalidClientCredentialsError) as cm: - yield self.auth.get_user_by_req(request, allow_guest=True) + yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(401, cm.exception.code) self.assertEqual("Guest access token used for regular user", cm.exception.msg) @@ -316,7 +327,7 @@ class AuthTestCase(unittest.TestCase): small_number_of_users = 1 # Ensure no error thrown - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.hs.config.limit_usage_by_mau = True @@ -325,7 +336,7 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -334,7 +345,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(small_number_of_users) ) - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_blocking_mau__depending_on_user_type(self): @@ -343,15 +354,19 @@ class AuthTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Support users allowed - yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Bots not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(user_type=UserTypes.BOT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.BOT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Real users not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_reserved_threepid(self): @@ -362,21 +377,22 @@ class AuthTestCase(unittest.TestCase): unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.hs.config.mau_limits_reserved_threepids = [threepid] - yield self.store.register_user(user_id="user1", password_hash=None) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(threepid=unknown_threepid) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(threepid=unknown_threepid) + ) - yield self.auth.check_auth_blocking(threepid=threepid) + yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid)) @defer.inlineCallbacks def test_hs_disabled(self): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -393,7 +409,7 @@ class AuthTestCase(unittest.TestCase): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -404,4 +420,4 @@ class AuthTestCase(unittest.TestCase): user = "@user:server" self.hs.config.server_notices_mxid = user self.hs.config.hs_disabled_message = "Reason for being disabled" - yield self.auth.check_auth_blocking(user) + yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index b03103d96f..52c4ac8b11 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -82,16 +82,16 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected self.hs.clock.now = 6000 with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @defer.inlineCallbacks @@ -99,8 +99,10 @@ class AuthTestCase(unittest.TestCase): token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) macaroon = pymacaroons.Macaroon.deserialize(token) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) self.assertEqual("a_user", user_id) @@ -109,20 +111,26 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = b_user") with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks @@ -133,16 +141,20 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks @@ -154,16 +166,20 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( @@ -172,8 +188,10 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) @@ -181,8 +199,10 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks @@ -193,15 +213,19 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) def _get_macaroon(self): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e7b638dbfe..f1dc51d6c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -294,7 +294,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): create_profile_with_displayname=user.localpart, ) else: - yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + yield defer.ensureDeferred( + self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + ) yield self.store.add_access_token_to_user( user_id=user_id, token=token, device_id=None, valid_until_ms=None diff --git a/tests/utils.py b/tests/utils.py index 968d109f77..2079e0143d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -332,10 +332,15 @@ def setup_test_homeserver( # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest() - hs.get_auth_handler().validate_hash = ( - lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h - ) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash fed = kargs.get("resource_for_federation", None) if fed: -- cgit 1.5.1 From 01294e6b3a5c6ba909a0e70dad97a9ce7fdb64b9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Apr 2020 10:52:55 -0400 Subject: Do not treat display names as globs for push rules. (#7271) --- changelog.d/7271.bugfix | 1 + synapse/push/push_rule_evaluator.py | 69 +++++++++++++++++++--------------- tests/push/test_push_rule_evaluator.py | 65 ++++++++++++++++++++++++++++++++ tox.ini | 1 + 4 files changed, 106 insertions(+), 30 deletions(-) create mode 100644 changelog.d/7271.bugfix create mode 100644 tests/push/test_push_rule_evaluator.py (limited to 'tests') diff --git a/changelog.d/7271.bugfix b/changelog.d/7271.bugfix new file mode 100644 index 0000000000..e8315e4ce4 --- /dev/null +++ b/changelog.d/7271.bugfix @@ -0,0 +1 @@ +Do not treat display names as globs in push rules. diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index b1587183a8..4cd702b5fa 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,9 +16,11 @@ import logging import re +from typing import Pattern from six import string_types +from synapse.events import EventBase from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -56,18 +58,18 @@ def _test_ineq_condition(condition, number): rhs = m.group(2) if not rhs.isdigit(): return False - rhs = int(rhs) + rhs_int = int(rhs) if ineq == "" or ineq == "==": - return number == rhs + return number == rhs_int elif ineq == "<": - return number < rhs + return number < rhs_int elif ineq == ">": - return number > rhs + return number > rhs_int elif ineq == ">=": - return number >= rhs + return number >= rhs_int elif ineq == "<=": - return number <= rhs + return number <= rhs_int else: return False @@ -83,7 +85,13 @@ def tweaks_for_actions(actions): class PushRuleEvaluatorForEvent(object): - def __init__(self, event, room_member_count, sender_power_level, power_levels): + def __init__( + self, + event: EventBase, + room_member_count: int, + sender_power_level: int, + power_levels: dict, + ): self._event = event self._room_member_count = room_member_count self._sender_power_level = sender_power_level @@ -92,7 +100,7 @@ class PushRuleEvaluatorForEvent(object): # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition, user_id, display_name): + def matches(self, condition: dict, user_id: str, display_name: str) -> bool: if condition["kind"] == "event_match": return self._event_match(condition, user_id) elif condition["kind"] == "contains_display_name": @@ -106,7 +114,7 @@ class PushRuleEvaluatorForEvent(object): else: return True - def _event_match(self, condition, user_id): + def _event_match(self, condition: dict, user_id: str) -> bool: pattern = condition.get("pattern", None) if not pattern: @@ -134,7 +142,7 @@ class PushRuleEvaluatorForEvent(object): return _glob_matches(pattern, haystack) - def _contains_display_name(self, display_name): + def _contains_display_name(self, display_name: str) -> bool: if not display_name: return False @@ -142,51 +150,52 @@ class PushRuleEvaluatorForEvent(object): if not body: return False - return _glob_matches(display_name, body, word_boundary=True) + # Similar to _glob_matches, but do not treat display_name as a glob. + r = regex_cache.get((display_name, False, True), None) + if not r: + r = re.escape(display_name) + r = _re_word_boundary(r) + r = re.compile(r, flags=re.IGNORECASE) + regex_cache[(display_name, False, True)] = r + + return r.search(body) - def _get_value(self, dotted_key): + def _get_value(self, dotted_key: str) -> str: return self._value_cache.get(dotted_key, None) -# Caches (glob, word_boundary) -> regex for push. See _glob_matches +# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) register_cache("cache", "regex_push_cache", regex_cache) -def _glob_matches(glob, value, word_boundary=False): +def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: """Tests if value matches glob. Args: - glob (string) - value (string): String to test against glob. - word_boundary (bool): Whether to match against word boundaries or entire + glob + value: String to test against glob. + word_boundary: Whether to match against word boundaries or entire string. Defaults to False. - - Returns: - bool """ try: - r = regex_cache.get((glob, word_boundary), None) + r = regex_cache.get((glob, True, word_boundary), None) if not r: r = _glob_to_re(glob, word_boundary) - regex_cache[(glob, word_boundary)] = r + regex_cache[(glob, True, word_boundary)] = r return r.search(value) except re.error: logger.warning("Failed to parse glob to regex: %r", glob) return False -def _glob_to_re(glob, word_boundary): +def _glob_to_re(glob: str, word_boundary: bool) -> Pattern: """Generates regex for a given glob. Args: - glob (string) - word_boundary (bool): Whether to match against word boundaries or entire - string. Defaults to False. - - Returns: - regex object + glob + word_boundary: Whether to match against word boundaries or entire string. """ if IS_GLOB.search(glob): r = re.escape(glob) @@ -219,7 +228,7 @@ def _glob_to_re(glob, word_boundary): return re.compile(r, flags=re.IGNORECASE) -def _re_word_boundary(r): +def _re_word_boundary(r: str) -> str: """ Adds word boundary characters to the start and end of an expression to require that the match occur as a whole word, diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py new file mode 100644 index 0000000000..9ae6a87d7b --- /dev/null +++ b/tests/push/test_push_rule_evaluator.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEvent +from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent + +from tests import unittest + + +class PushRuleEvaluatorTestCase(unittest.TestCase): + def setUp(self): + event = FrozenEvent( + { + "event_id": "$event_id", + "type": "m.room.history_visibility", + "sender": "@user:test", + "state_key": "", + "room_id": "@room:test", + "content": {"body": "foo bar baz"}, + }, + RoomVersions.V1, + ) + room_member_count = 0 + sender_power_level = 0 + power_levels = {} + self.evaluator = PushRuleEvaluatorForEvent( + event, room_member_count, sender_power_level, power_levels + ) + + def test_display_name(self): + """Check for a matching display name in the body of the event.""" + condition = { + "kind": "contains_display_name", + } + + # Blank names are skipped. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "")) + + # Check a display name that doesn't match. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found")) + + # Check a display name which matches. + self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo")) + + # A display name that matches, but not a full word does not result in a match. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba")) + + # A display name should not be interpreted as a regular expression. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]")) + + # A display name with spaces should work fine. + self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar")) diff --git a/tox.ini b/tox.ini index 763c8463d9..42b2d74891 100644 --- a/tox.ini +++ b/tox.ini @@ -196,6 +196,7 @@ commands = mypy \ synapse/metrics \ synapse/module_api \ synapse/push/pusherpool.py \ + synapse/push/push_rule_evaluator.py \ synapse/replication \ synapse/rest \ synapse/spam_checker_api \ -- cgit 1.5.1 From 51f7eaf908a84fcaf231899e2bf1beae14ae72c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 22 Apr 2020 13:07:41 +0100 Subject: Add ability to run replication protocol over redis. (#7040) This is configured via the `redis` config options. --- changelog.d/7040.feature | 1 + stubs/txredisapi.pyi | 40 +++++++ synapse/app/homeserver.py | 6 + synapse/config/homeserver.py | 2 + synapse/config/redis.py | 35 ++++++ synapse/python_dependencies.py | 1 + synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/commands.py | 18 +++ synapse/replication/tcp/handler.py | 50 +++++++-- synapse/replication/tcp/protocol.py | 38 ++----- synapse/replication/tcp/redis.py | 181 +++++++++++++++++++++++++++++++ tests/replication/slave/storage/_base.py | 4 +- 12 files changed, 342 insertions(+), 36 deletions(-) create mode 100644 changelog.d/7040.feature create mode 100644 stubs/txredisapi.pyi create mode 100644 synapse/config/redis.py create mode 100644 synapse/replication/tcp/redis.py (limited to 'tests') diff --git a/changelog.d/7040.feature b/changelog.d/7040.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7040.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi new file mode 100644 index 0000000000..763d3fb404 --- /dev/null +++ b/stubs/txredisapi.pyi @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains *incomplete* type hints for txredisapi. +""" + +from typing import List, Optional, Union + +class RedisProtocol: + def publish(self, channel: str, message: bytes): ... + +class SubscriberProtocol: + def subscribe(self, channels: Union[str, List[str]]): ... + +def lazyConnection( + host: str = ..., + port: int = ..., + dbid: Optional[int] = ..., + reconnect: bool = ..., + charset: str = ..., + password: Optional[str] = ..., + connectTimeout: Optional[int] = ..., + replyTimeout: Optional[int] = ..., + convertNumbers: bool = ..., +) -> RedisProtocol: ... + +class SubscriberFactory: + def buildProtocol(self, addr): ... diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 49df63acd0..cbd1ea475a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -273,6 +273,12 @@ class SynapseHomeServer(HomeServer): def start_listening(self, listeners): config = self.get_config() + if config.redis_enabled: + # If redis is enabled we connect via the replication command handler + # in the same way as the workers (since we're effectively a client + # rather than a server). + self.get_tcp_replication().start_replication(self) + for listener in listeners: if listener["type"] == "http": self._listening_services.extend(self._listener_http(config, listener)) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b4bca08b20..be6c6afa74 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -31,6 +31,7 @@ from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig +from .redis import RedisConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig from .room_directory import RoomDirectoryConfig @@ -82,4 +83,5 @@ class HomeServerConfig(RootConfig): RoomDirectoryConfig, ThirdPartyRulesConfig, TracerConfig, + RedisConfig, ] diff --git a/synapse/config/redis.py b/synapse/config/redis.py new file mode 100644 index 0000000000..81a27619ec --- /dev/null +++ b/synapse/config/redis.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config +from synapse.python_dependencies import check_requirements + + +class RedisConfig(Config): + section = "redis" + + def read_config(self, config, **kwargs): + redis_config = config.get("redis", {}) + self.redis_enabled = redis_config.get("enabled", False) + + if not self.redis_enabled: + return + + check_requirements("redis") + + self.redis_host = redis_config.get("host", "localhost") + self.redis_port = redis_config.get("port", 6379) + self.redis_dbid = redis_config.get("dbid") + self.redis_password = redis_config.get("password") diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8de8cb2c12..733c51b758 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -98,6 +98,7 @@ CONDITIONAL_REQUIREMENTS = { "sentry": ["sentry-sdk>=0.7.2"], "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], + "redis": ["txredisapi>=1.4.7"], } ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 700ae79158..2d07b8b2d0 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReplicationClientFactory(ReconnectingClientFactory): +class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5ec89d0fb8..5002efe6a0 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -454,3 +454,21 @@ VALID_CLIENT_COMMANDS = ( ErrorCommand.NAME, RemoteServerUpCommand.NAME, ) + + +def parse_command_from_line(line: str) -> Command: + """Parses a command from a received line. + + Line should already be stripped of whitespace and be checked if blank. + """ + + idx = line.index(" ") + if idx >= 0: + cmd_name = line[:idx] + rest_of_line = line[idx + 1 :] + else: + cmd_name = line + rest_of_line = "" + + cmd_cls = COMMAND_MAP[cmd_name] + return cmd_cls.from_line(rest_of_line) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index e32e68e8c4..5b5ee2c13e 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -30,8 +30,10 @@ from typing import ( from prometheus_client import Counter +from twisted.internet.protocol import ReconnectingClientFactory + from synapse.metrics import LaterGauge -from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, @@ -92,7 +94,7 @@ class ReplicationCommandHandler: self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. - self._factory = None # type: Optional[ReplicationClientFactory] + self._factory = None # type: Optional[ReconnectingClientFactory] # The currently connected connections. self._connections = [] # type: List[AbstractConnection] @@ -119,11 +121,45 @@ class ReplicationCommandHandler: """Helper method to start a replication connection to the remote server using TCP. """ - client_name = hs.config.worker_name - self._factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self._factory) + if hs.config.redis.redis_enabled: + from synapse.replication.tcp.redis import ( + RedisDirectTcpReplicationClientFactory, + ) + import txredisapi + + logger.info( + "Connecting to redis (host=%r port=%r DBID=%r)", + hs.config.redis_host, + hs.config.redis_port, + hs.config.redis_dbid, + ) + + # We need two connections to redis, one for the subscription stream and + # one to send commands to (as you can't send further redis commands to a + # connection after SUBSCRIBE is called). + + # First create the connection for sending commands. + outbound_redis_connection = txredisapi.lazyConnection( + host=hs.config.redis_host, + port=hs.config.redis_port, + dbid=hs.config.redis_dbid, + password=hs.config.redis.redis_password, + reconnect=True, + ) + + # Now create the factory/connection for the subscription stream. + self._factory = RedisDirectTcpReplicationClientFactory( + hs, outbound_redis_connection + ) + hs.get_reactor().connectTCP( + hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, + ) + else: + client_name = hs.config.worker_name + self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) async def on_REPLICATE(self, cmd: ReplicateCommand): # We only want to announce positions by the writer of the streams. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 9276ed2965..7240acb0a2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -63,7 +63,6 @@ from twisted.python.failure import Failure from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( - COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, Command, @@ -72,6 +71,7 @@ from synapse.replication.tcp.commands import ( PingCommand, ReplicateCommand, ServerCommand, + parse_command_from_line, ) from synapse.types import Collection from synapse.util import Clock @@ -210,38 +210,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): linestr = line.decode("utf-8") - # split at the first " ", handling one-word commands - idx = linestr.index(" ") - if idx >= 0: - cmd_name = linestr[:idx] - rest_of_line = linestr[idx + 1 :] - else: - cmd_name = linestr - rest_of_line = "" + try: + cmd = parse_command_from_line(linestr) + except Exception as e: + logger.exception("[%s] failed to parse line: %r", self.id(), linestr) + self.send_error("failed to parse line: %r (%r):" % (e, linestr)) + return - if cmd_name not in self.VALID_INBOUND_COMMANDS: - logger.error("[%s] invalid command %s", self.id(), cmd_name) - self.send_error("invalid command: %s", cmd_name) + if cmd.NAME not in self.VALID_INBOUND_COMMANDS: + logger.error("[%s] invalid command %s", self.id(), cmd.NAME) + self.send_error("invalid command: %s", cmd.NAME) return self.last_received_command = self.clock.time_msec() - self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1 + self.inbound_commands_counter[cmd.NAME] = ( + self.inbound_commands_counter[cmd.NAME] + 1 ) - cmd_cls = COMMAND_MAP[cmd_name] - try: - cmd = cmd_cls.from_line(rest_of_line) - except Exception as e: - logger.exception( - "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line - ) - self.send_error( - "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) - ) - return - # Now lets try and call on_ function run_as_background_process( "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py new file mode 100644 index 0000000000..4c08425735 --- /dev/null +++ b/synapse/replication/tcp/redis.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +import txredisapi + +from synapse.logging.context import PreserveLoggingContext +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.commands import ( + Command, + ReplicateCommand, + parse_command_from_line, +) +from synapse.replication.tcp.protocol import AbstractConnection + +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): + """Connection to redis subscribed to replication stream. + + Parses incoming messages from redis into replication commands, and passes + them to `ReplicationCommandHandler` + + Due to the vagaries of `txredisapi` we don't want to have a custom + constructor, so instead we expect the defined attributes below to be set + immediately after initialisation. + + Attributes: + handler: The command handler to handle incoming commands. + stream_name: The *redis* stream name to subscribe to (not anything to + do with Synapse replication streams). + outbound_redis_connection: The connection to redis to use to send + commands. + """ + + handler = None # type: ReplicationCommandHandler + stream_name = None # type: str + outbound_redis_connection = None # type: txredisapi.RedisProtocol + + def connectionMade(self): + logger.info("Connected to redis instance") + self.subscribe(self.stream_name) + self.send_command(ReplicateCommand()) + + self.handler.new_connection(self) + + def messageReceived(self, pattern: str, channel: str, message: str): + """Received a message from redis. + """ + + if message.strip() == "": + # Ignore blank lines + return + + try: + cmd = parse_command_from_line(message) + except Exception: + logger.exception( + "[%s] failed to parse line: %r", message, + ) + return + + # Now lets try and call on_ function + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd + ) + + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. + + By default delegates to on_, which should return an awaitable. + + Args: + cmd: received command + """ + handled = False + + # First call any command handlers on this instance. These are for redis + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) + + def connectionLost(self, reason): + logger.info("Lost connection to redis instance") + self.handler.lost_connection(self) + + def send_command(self, cmd: Command): + """Send a command if connection has been established. + + Args: + cmd (Command) + """ + string = "%s %s" % (cmd.NAME, cmd.to_line()) + if "\n" in string: + raise Exception("Unexpected newline in command: %r", string) + + encoded_string = string.encode("utf-8") + + async def _send(): + with PreserveLoggingContext(): + # Note that we use the other connection as we can't send + # commands using the subscription connection. + await self.outbound_redis_connection.publish( + self.stream_name, encoded_string + ) + + run_as_background_process("send-cmd", _send) + + +class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): + """This is a reconnecting factory that connects to redis and immediately + subscribes to a stream. + + Args: + hs + outbound_redis_connection: A connection to redis that will be used to + send outbound commands (this is seperate to the redis connection + used to subscribe). + """ + + maxDelay = 5 + continueTrying = True + protocol = RedisSubscriber + + def __init__( + self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol + ): + + super().__init__() + + # This sets the password on the RedisFactory base class (as + # SubscriberFactory constructor doesn't pass it through). + self.password = hs.config.redis.redis_password + + self.handler = hs.get_tcp_replication() + self.stream_name = hs.hostname + + self.outbound_redis_connection = outbound_redis_connection + + def buildProtocol(self, addr): + p = super().buildProtocol(addr) # type: RedisSubscriber + + # We do this here rather than add to the constructor of `RedisSubcriber` + # as to do so would involve overriding `buildProtocol` entirely, however + # the base method does some other things than just instantiating the + # protocol. + p.handler = self.handler + p.outbound_redis_connection = self.outbound_redis_connection + p.stream_name = self.stream_name + + return p diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 8902a5ab69..395c7d0306 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -16,7 +16,7 @@ from mock import Mock, NonCallableMock from synapse.replication.tcp.client import ( - ReplicationClientFactory, + DirectTcpReplicationClientFactory, ReplicationDataHandler, ) from synapse.replication.tcp.handler import ReplicationCommandHandler @@ -61,7 +61,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.slaved_store ) - client_factory = ReplicationClientFactory( + client_factory = DirectTcpReplicationClientFactory( self.hs, "client_name", self.replication_handler ) client_factory.handler = self.replication_handler -- cgit 1.5.1 From 6b6685db9f8cf8a55dfe0edc8f2131be8001c360 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 22 Apr 2020 14:38:41 +0200 Subject: Extend room admin api with additional attributes (#7225) --- changelog.d/7225.misc | 1 + docs/admin_api/rooms.md | 107 ++++- synapse/rest/admin/rooms.py | 15 +- synapse/storage/data_stores/main/room.py | 78 +++- tests/rest/admin/test_admin.py | 592 +-------------------------- tests/rest/admin/test_room.py | 680 ++++++++++++++++++++++++++++++- 6 files changed, 869 insertions(+), 604 deletions(-) create mode 100644 changelog.d/7225.misc (limited to 'tests') diff --git a/changelog.d/7225.misc b/changelog.d/7225.misc new file mode 100644 index 0000000000..375e2a475f --- /dev/null +++ b/changelog.d/7225.misc @@ -0,0 +1 @@ +Extend room admin api (`GET /_synapse/admin/v1/rooms`) with additional attributes. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 2db457c1b6..26fe8b8679 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -11,8 +11,21 @@ The following query parameters are available: * `from` - Offset in the returned list. Defaults to `0`. * `limit` - Maximum amount of rooms to return. Defaults to `100`. * `order_by` - The method in which to sort the returned list of rooms. Valid values are: - - `alphabetical` - Rooms are ordered alphabetically by room name. This is the default. - - `size` - Rooms are ordered by the number of members. Largest to smallest. + - `alphabetical` - Same as `name`. This is deprecated. + - `size` - Same as `joined_members`. This is deprecated. + - `name` - Rooms are ordered alphabetically by room name. This is the default. + - `canonical_alias` - Rooms are ordered alphabetically by main alias address of the room. + - `joined_members` - Rooms are ordered by the number of members. Largest to smallest. + - `joined_local_members` - Rooms are ordered by the number of local members. Largest to smallest. + - `version` - Rooms are ordered by room version. Largest to smallest. + - `creator` - Rooms are ordered alphabetically by creator of the room. + - `encryption` - Rooms are ordered alphabetically by the end-to-end encryption algorithm. + - `federatable` - Rooms are ordered by whether the room is federatable. + - `public` - Rooms are ordered by visibility in room list. + - `join_rules` - Rooms are ordered alphabetically by join rules of the room. + - `guest_access` - Rooms are ordered alphabetically by guest access option of the room. + - `history_visibility` - Rooms are ordered alphabetically by visibility of history of the room. + - `state_events` - Rooms are ordered by number of state events. Largest to smallest. * `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting this value to `b` will reverse the above sort order. Defaults to `f`. * `search_term` - Filter rooms by their room name. Search term can be contained in any @@ -26,6 +39,16 @@ The following fields are possible in the JSON response body: - `name` - The name of the room. - `canonical_alias` - The canonical (main) alias address of the room. - `joined_members` - How many users are currently in the room. + - `joined_local_members` - How many local users are currently in the room. + - `version` - The version of the room as a string. + - `creator` - The `user_id` of the room creator. + - `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active. + - `federatable` - Whether users on other servers can join this room. + - `public` - Whether the room is visible in room directory. + - `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"]. + - `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"]. + - `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"]. + - `state_events` - Total number of state_events of a room. Complexity of the room. * `offset` - The current pagination offset in rooms. This parameter should be used instead of `next_token` for room offset as `next_token` is not intended to be parsed. @@ -60,14 +83,34 @@ Response: "room_id": "!OGEhHVWSdvArJzumhm:matrix.org", "name": "Matrix HQ", "canonical_alias": "#matrix:matrix.org", - "joined_members": 8326 + "joined_members": 8326, + "joined_local_members": 2, + "version": "1", + "creator": "@foo:matrix.org", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 93534 }, ... (8 hidden items) ... { "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", "name": "This Week In Matrix (TWIM)", "canonical_alias": "#twim:matrix.org", - "joined_members": 314 + "joined_members": 314, + "joined_local_members": 20, + "version": "4", + "creator": "@foo:matrix.org", + "encryption": "m.megolm.v1.aes-sha2", + "federatable": true, + "public": false, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 8345 } ], "offset": 0, @@ -92,7 +135,17 @@ Response: "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", "name": "This Week In Matrix (TWIM)", "canonical_alias": "#twim:matrix.org", - "joined_members": 314 + "joined_members": 314, + "joined_local_members": 20, + "version": "4", + "creator": "@foo:matrix.org", + "encryption": "m.megolm.v1.aes-sha2", + "federatable": true, + "public": false, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 8 } ], "offset": 0, @@ -117,14 +170,34 @@ Response: "room_id": "!OGEhHVWSdvArJzumhm:matrix.org", "name": "Matrix HQ", "canonical_alias": "#matrix:matrix.org", - "joined_members": 8326 + "joined_members": 8326, + "joined_local_members": 2, + "version": "1", + "creator": "@foo:matrix.org", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 93534 }, ... (98 hidden items) ... { "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", "name": "This Week In Matrix (TWIM)", "canonical_alias": "#twim:matrix.org", - "joined_members": 314 + "joined_members": 314, + "joined_local_members": 20, + "version": "4", + "creator": "@foo:matrix.org", + "encryption": "m.megolm.v1.aes-sha2", + "federatable": true, + "public": false, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 8345 } ], "offset": 0, @@ -154,6 +227,16 @@ Response: "name": "Music Theory", "canonical_alias": "#musictheory:matrix.org", "joined_members": 127 + "joined_local_members": 2, + "version": "1", + "creator": "@foo:matrix.org", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 93534 }, ... (48 hidden items) ... { @@ -161,6 +244,16 @@ Response: "name": "weechat-matrix", "canonical_alias": "#weechat-matrix:termina.org.uk", "joined_members": 137 + "joined_local_members": 20, + "version": "4", + "creator": "@foo:termina.org.uk", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 8345 } ], "offset": 100, diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 659b8a10ee..d1bdb64111 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -183,10 +183,23 @@ class ListRoomRestServlet(RestServlet): # Extract query parameters start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - order_by = parse_string(request, "order_by", default="alphabetical") + order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) if order_by not in ( RoomSortOrder.ALPHABETICAL.value, RoomSortOrder.SIZE.value, + RoomSortOrder.NAME.value, + RoomSortOrder.CANONICAL_ALIAS.value, + RoomSortOrder.JOINED_MEMBERS.value, + RoomSortOrder.JOINED_LOCAL_MEMBERS.value, + RoomSortOrder.VERSION.value, + RoomSortOrder.CREATOR.value, + RoomSortOrder.ENCRYPTION.value, + RoomSortOrder.FEDERATABLE.value, + RoomSortOrder.PUBLIC.value, + RoomSortOrder.JOIN_RULES.value, + RoomSortOrder.GUEST_ACCESS.value, + RoomSortOrder.HISTORY_VISIBILITY.value, + RoomSortOrder.STATE_EVENTS.value, ): raise SynapseError( 400, diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index aaebe427d3..147eba1df7 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -52,12 +52,28 @@ class RoomSortOrder(Enum): """ Enum to define the sorting method used when returning rooms with get_rooms_paginate - ALPHABETICAL = sort rooms alphabetically by name - SIZE = sort rooms by membership size, highest to lowest + NAME = sort rooms alphabetically by name + JOINED_MEMBERS = sort rooms by membership size, highest to lowest """ + # ALPHABETICAL and SIZE are deprecated. + # ALPHABETICAL is the same as NAME. ALPHABETICAL = "alphabetical" + # SIZE is the same as JOINED_MEMBERS. SIZE = "size" + NAME = "name" + CANONICAL_ALIAS = "canonical_alias" + JOINED_MEMBERS = "joined_members" + JOINED_LOCAL_MEMBERS = "joined_local_members" + VERSION = "version" + CREATOR = "creator" + ENCRYPTION = "encryption" + FEDERATABLE = "federatable" + PUBLIC = "public" + JOIN_RULES = "join_rules" + GUEST_ACCESS = "guest_access" + HISTORY_VISIBILITY = "history_visibility" + STATE_EVENTS = "state_events" class RoomWorkerStore(SQLBaseStore): @@ -329,12 +345,52 @@ class RoomWorkerStore(SQLBaseStore): # Set ordering if RoomSortOrder(order_by) == RoomSortOrder.SIZE: + # Deprecated in favour of RoomSortOrder.JOINED_MEMBERS order_by_column = "curr.joined_members" order_by_asc = False elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL: - # Sort alphabetically + # Deprecated in favour of RoomSortOrder.NAME order_by_column = "state.name" order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.NAME: + order_by_column = "state.name" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS: + order_by_column = "state.canonical_alias" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS: + order_by_column = "curr.joined_members" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS: + order_by_column = "curr.local_users_in_room" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.VERSION: + order_by_column = "rooms.room_version" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR: + order_by_column = "rooms.creator" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION: + order_by_column = "state.encryption" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE: + order_by_column = "state.is_federatable" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC: + order_by_column = "rooms.is_public" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES: + order_by_column = "state.join_rules" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS: + order_by_column = "state.guest_access" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY: + order_by_column = "state.history_visibility" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS: + order_by_column = "curr.current_state_events" + order_by_asc = False else: raise StoreError( 500, "Incorrect value for order_by provided: %s" % order_by @@ -349,9 +405,13 @@ class RoomWorkerStore(SQLBaseStore): # for, and another query for getting the total number of events that could be # returned. Thus allowing us to see if there are more events to paginate through info_sql = """ - SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members + SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members, + curr.local_users_in_room, rooms.room_version, rooms.creator, + state.encryption, state.is_federatable, rooms.is_public, state.join_rules, + state.guest_access, state.history_visibility, curr.current_state_events FROM room_stats_state state INNER JOIN room_stats_current curr USING (room_id) + INNER JOIN rooms USING (room_id) %s ORDER BY %s %s LIMIT ? @@ -389,6 +449,16 @@ class RoomWorkerStore(SQLBaseStore): "name": room[1], "canonical_alias": room[2], "joined_members": room[3], + "joined_local_members": room[4], + "version": room[5], + "creator": room[6], + "encryption": room[7], + "federatable": room[8], + "public": room[9], + "join_rules": room[10], + "guest_access": room[11], + "history_visibility": room[12], + "state_events": room[13], } ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 0342aed416..977615ebef 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -17,7 +17,6 @@ import json import os import urllib.parse from binascii import unhexlify -from typing import List, Optional from mock import Mock @@ -27,7 +26,7 @@ import synapse.rest.admin from synapse.http.server import JsonResource from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet -from synapse.rest.client.v1 import directory, events, login, room +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import groups from tests import unittest @@ -51,129 +50,6 @@ class VersionTestCase(unittest.HomeserverTestCase): ) -class ShutdownRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - """ - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - room_id, body="foo", tok=self.other_user_token, expect_code=403 - ) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert there is now no longer anyone in the room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. - """ - - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_token, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert we can no longer peek into the room - self._assert_peek(room_id, expect_code=403) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ - - url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.render(request) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - class DeleteGroupTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -273,86 +149,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): return channel.json_body["groups"] -class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_purge_room(self): - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # All users have to have left the room. - self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) - - url = "/_synapse/admin/v1/purge_room" - request, channel = self.make_request( - "POST", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the following tables have been purged of all rows related to the room. - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - # "state_groups", # Current impl leaves orphaned state groups around. - "state_groups_state", - ): - count = self.get_success( - self.store.db.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) - - class QuarantineMediaTestCase(unittest.HomeserverTestCase): """Test /quarantine_media admin API. """ @@ -691,389 +487,3 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): % server_and_media_id_2 ), ) - - -class RoomTestCase(unittest.HomeserverTestCase): - """Test /room admin API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - directory.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - # Create user - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_list_rooms(self): - """Test that we can list rooms""" - # Create 3 test rooms - total_rooms = 3 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - - # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) - - # Check that response json body contains a "rooms" key - self.assertTrue( - "rooms" in channel.json_body, - msg="Response body does not " "contain a 'rooms' key", - ) - - # Check that 3 rooms were returned - self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) - - # Check their room_ids match - returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] - self.assertEqual(room_ids, returned_room_ids) - - # Check that all fields are available - for r in channel.json_body["rooms"]: - self.assertIn("name", r) - self.assertIn("canonical_alias", r) - self.assertIn("joined_members", r) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # Should be 0 as we aren't paginating - self.assertEqual(channel.json_body["offset"], 0) - - # Check that the prev_batch parameter is not present - self.assertNotIn("prev_batch", channel.json_body) - - # We shouldn't receive a next token here as there's no further rooms to show - self.assertNotIn("next_batch", channel.json_body) - - def test_list_rooms_pagination(self): - """Test that we can get a full list of rooms through pagination""" - # Create 5 test rooms - total_rooms = 5 - room_ids = [] - for x in range(total_rooms): - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok - ) - room_ids.append(room_id) - - # Set the name of the rooms so we get a consistent returned ordering - for idx, room_id in enumerate(room_ids): - self.helper.send_state( - room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - returned_room_ids = [] - start = 0 - limit = 2 - - run_count = 0 - should_repeat = True - while should_repeat: - run_count += 1 - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( - start, - limit, - "alphabetical", - ) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) - - self.assertTrue("rooms" in channel.json_body) - for r in channel.json_body["rooms"]: - returned_room_ids.append(r["room_id"]) - - # Check that the correct number of total rooms was returned - self.assertEqual(channel.json_body["total_rooms"], total_rooms) - - # Check that the offset is correct - # We're only getting 2 rooms each page, so should be 2 * last run_count - self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) - - if run_count > 1: - # Check the value of prev_batch is correct - self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) - - if "next_batch" not in channel.json_body: - # We have reached the end of the list - should_repeat = False - else: - # Make another query with an updated start value - start = channel.json_body["next_batch"] - - # We should've queried the endpoint 3 times - self.assertEqual( - run_count, - 3, - msg="Should've queried 3 times for 5 rooms with limit 2 per query", - ) - - # Check that we received all of the room ids - self.assertEqual(room_ids, returned_room_ids) - - url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - def test_correct_room_attributes(self): - """Test the correct attributes for a room are returned""" - # Create a test room - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - test_alias = "#test:test" - test_room_name = "something" - - # Have another user join the room - user_2 = self.register_user("user4", "pass") - user_tok_2 = self.login("user4", "pass") - self.helper.join(room_id, user_2, tok=user_tok_2) - - # Create a new alias to this room - url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Set this new alias as the canonical alias for this room - self.helper.send_state( - room_id, - "m.room.aliases", - {"aliases": [test_alias]}, - tok=self.admin_user_tok, - state_key="test", - ) - self.helper.send_state( - room_id, - "m.room.canonical_alias", - {"alias": test_alias}, - tok=self.admin_user_tok, - ) - - # Set a name for the room - self.helper.send_state( - room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, - ) - - # Request the list of rooms - url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that only one room was returned - self.assertEqual(len(rooms), 1) - - # And that the value of the total_rooms key was correct - self.assertEqual(channel.json_body["total_rooms"], 1) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that all provided attributes are set - r = rooms[0] - self.assertEqual(room_id, r["room_id"]) - self.assertEqual(test_room_name, r["name"]) - self.assertEqual(test_alias, r["canonical_alias"]) - - def test_room_list_sort_order(self): - """Test room list sort ordering. alphabetical versus number of members, - reversing the order, etc. - """ - # Create 3 test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C - self.helper.send_state( - room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, - ) - - # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 - user_1 = self.register_user("bob1", "pass") - user_1_tok = self.login("bob1", "pass") - self.helper.join(room_id_2, user_1, tok=user_1_tok) - - user_2 = self.register_user("bob2", "pass") - user_2_tok = self.login("bob2", "pass") - self.helper.join(room_id_3, user_2, tok=user_2_tok) - - user_3 = self.register_user("bob3", "pass") - user_3_tok = self.login("bob3", "pass") - self.helper.join(room_id_3, user_3, tok=user_3_tok) - - def _order_test( - order_type: str, expected_room_list: List[str], reverse: bool = False, - ): - """Request the list of rooms in a certain order. Assert that order is what - we expect - - Args: - order_type: The type of ordering to give the server - expected_room_list: The list of room_ids in the order we expect to get - back from the server - """ - # Request the list of rooms in the given order - url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) - if reverse: - url += "&dir=b" - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(200, channel.code, msg=channel.json_body) - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check for the correct total_rooms value - self.assertEqual(channel.json_body["total_rooms"], 3) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - # Check that rooms were returned in alphabetical order - returned_order = [r["room_id"] for r in rooms] - self.assertListEqual(expected_room_list, returned_order) # order is checked - - # Test different sort orders, with forward and reverse directions - _order_test("alphabetical", [room_id_1, room_id_2, room_id_3]) - _order_test("alphabetical", [room_id_3, room_id_2, room_id_1], reverse=True) - - _order_test("size", [room_id_3, room_id_2, room_id_1]) - _order_test("size", [room_id_1, room_id_2, room_id_3], reverse=True) - - def test_search_term(self): - """Test that searching for a room works correctly""" - # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - room_name_1 = "something" - room_name_2 = "else" - - # Set the name for each room - self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, - ) - self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, - ) - - def _search_test( - expected_room_id: Optional[str], - search_term: str, - expected_http_code: int = 200, - ): - """Search for a room and check that the returned room's id is a match - - Args: - expected_room_id: The room_id expected to be returned by the API. Set - to None to expect zero results for the search - search_term: The term to search for room names with - expected_http_code: The expected http code for the request - """ - url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) - request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, - ) - self.render(request) - self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - - if expected_http_code != 200: - return - - # Check that rooms were returned - self.assertTrue("rooms" in channel.json_body) - rooms = channel.json_body["rooms"] - - # Check that the expected number of rooms were returned - expected_room_count = 1 if expected_room_id else 0 - self.assertEqual(len(rooms), expected_room_count) - self.assertEqual(channel.json_body["total_rooms"], expected_room_count) - - # Check that the offset is correct - # We're not paginating, so should be 0 - self.assertEqual(channel.json_body["offset"], 0) - - # Check that there is no `prev_batch` - self.assertNotIn("prev_batch", channel.json_body) - - # Check that there is no `next_batch` - self.assertNotIn("next_batch", channel.json_body) - - if expected_room_id: - # Check that the first returned room id is correct - r = rooms[0] - self.assertEqual(expected_room_id, r["room_id"]) - - # Perform search tests - _search_test(room_id_1, "something") - _search_test(room_id_1, "thing") - - _search_test(room_id_2, "else") - _search_test(room_id_2, "se") - - _search_test(None, "foo") - _search_test(None, "bar") - _search_test(None, "", expected_http_code=400) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 672cc3eac5..249c93722f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -14,16 +14,694 @@ # limitations under the License. import json +import urllib.parse +from typing import List, Optional + +from mock import Mock import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, room +from synapse.rest.client.v1 import directory, events, login, room from tests import unittest """Tests admin REST events for /rooms paths.""" +class ShutdownRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + """ + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + room_id, body="foo", tok=self.other_user_token, expect_code=403 + ) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert there is now no longer anyone in the room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + """ + + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_token, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert we can no longer peek into the room + self._assert_peek(room_id, expect_code=403) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + +class PurgeRoomTestCase(unittest.HomeserverTestCase): + """Test /purge_room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_purge_room(self): + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # All users have to have left the room. + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/purge_room" + request, channel = self.make_request( + "POST", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the following tables have been purged of all rows related to the room. + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", + ): + count = self.get_success( + self.store.db.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + +class RoomTestCase(unittest.HomeserverTestCase): + """Test /room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_list_rooms(self): + """Test that we can list rooms""" + # Create 3 test rooms + total_rooms = 3 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + + # Check request completed successfully + self.assertEqual(200, int(channel.code), msg=channel.json_body) + + # Check that response json body contains a "rooms" key + self.assertTrue( + "rooms" in channel.json_body, + msg="Response body does not " "contain a 'rooms' key", + ) + + # Check that 3 rooms were returned + self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) + + # Check their room_ids match + returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] + self.assertEqual(room_ids, returned_room_ids) + + # Check that all fields are available + for r in channel.json_body["rooms"]: + self.assertIn("name", r) + self.assertIn("canonical_alias", r) + self.assertIn("joined_members", r) + self.assertIn("joined_local_members", r) + self.assertIn("version", r) + self.assertIn("creator", r) + self.assertIn("encryption", r) + self.assertIn("federatable", r) + self.assertIn("public", r) + self.assertIn("join_rules", r) + self.assertIn("guest_access", r) + self.assertIn("history_visibility", r) + self.assertIn("state_events", r) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # Should be 0 as we aren't paginating + self.assertEqual(channel.json_body["offset"], 0) + + # Check that the prev_batch parameter is not present + self.assertNotIn("prev_batch", channel.json_body) + + # We shouldn't receive a next token here as there's no further rooms to show + self.assertNotIn("next_batch", channel.json_body) + + def test_list_rooms_pagination(self): + """Test that we can get a full list of rooms through pagination""" + # Create 5 test rooms + total_rooms = 5 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Set the name of the rooms so we get a consistent returned ordering + for idx, room_id in enumerate(room_ids): + self.helper.send_state( + room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + returned_room_ids = [] + start = 0 + limit = 2 + + run_count = 0 + should_repeat = True + while should_repeat: + run_count += 1 + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( + start, + limit, + "name", + ) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + self.assertTrue("rooms" in channel.json_body) + for r in channel.json_body["rooms"]: + returned_room_ids.append(r["room_id"]) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # We're only getting 2 rooms each page, so should be 2 * last run_count + self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) + + if run_count > 1: + # Check the value of prev_batch is correct + self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) + + if "next_batch" not in channel.json_body: + # We have reached the end of the list + should_repeat = False + else: + # Make another query with an updated start value + start = channel.json_body["next_batch"] + + # We should've queried the endpoint 3 times + self.assertEqual( + run_count, + 3, + msg="Should've queried 3 times for 5 rooms with limit 2 per query", + ) + + # Check that we received all of the room ids + self.assertEqual(room_ids, returned_room_ids) + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_correct_room_attributes(self): + """Test the correct attributes for a room are returned""" + # Create a test room + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + test_alias = "#test:test" + test_room_name = "something" + + # Have another user join the room + user_2 = self.register_user("user4", "pass") + user_tok_2 = self.login("user4", "pass") + self.helper.join(room_id, user_2, tok=user_tok_2) + + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=self.admin_user_tok, + ) + + # Set a name for the room + self.helper.send_state( + room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that only one room was returned + self.assertEqual(len(rooms), 1) + + # And that the value of the total_rooms key was correct + self.assertEqual(channel.json_body["total_rooms"], 1) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that all provided attributes are set + r = rooms[0] + self.assertEqual(room_id, r["room_id"]) + self.assertEqual(test_room_name, r["name"]) + self.assertEqual(test_alias, r["canonical_alias"]) + + def test_room_list_sort_order(self): + """Test room list sort ordering. alphabetical name versus number of members, + reversing the order, etc. + """ + + def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % ( + urllib.parse.quote(test_alias), + ) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=admin_user_tok, + ) + + def _order_test( + order_type: str, expected_room_list: List[str], reverse: bool = False, + ): + """Request the list of rooms in a certain order. Assert that order is what + we expect + + Args: + order_type: The type of ordering to give the server + expected_room_list: The list of room_ids in the order we expect to get + back from the server + """ + # Request the list of rooms in the given order + url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) + if reverse: + url += "&dir=b" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check for the correct total_rooms value + self.assertEqual(channel.json_body["total_rooms"], 3) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that rooms were returned in alphabetical order + returned_order = [r["room_id"] for r in rooms] + self.assertListEqual(expected_room_list, returned_order) # order is checked + + # Create 3 test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C + self.helper.send_state( + room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + ) + + # Set room canonical room aliases + _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok) + + # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 + user_1 = self.register_user("bob1", "pass") + user_1_tok = self.login("bob1", "pass") + self.helper.join(room_id_2, user_1, tok=user_1_tok) + + user_2 = self.register_user("bob2", "pass") + user_2_tok = self.login("bob2", "pass") + self.helper.join(room_id_3, user_2, tok=user_2_tok) + + user_3 = self.register_user("bob3", "pass") + user_3_tok = self.login("bob3", "pass") + self.helper.join(room_id_3, user_3, tok=user_3_tok) + + # Test different sort orders, with forward and reverse directions + _order_test("name", [room_id_1, room_id_2, room_id_3]) + _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3]) + _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("joined_members", [room_id_3, room_id_2, room_id_1]) + _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1]) + _order_test( + "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("version", [room_id_1, room_id_2, room_id_3]) + _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("creator", [room_id_1, room_id_2, room_id_3]) + _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("encryption", [room_id_1, room_id_2, room_id_3]) + _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("federatable", [room_id_1, room_id_2, room_id_3]) + _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("public", [room_id_1, room_id_2, room_id_3]) + # Different sort order of SQlite and PostreSQL + # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("join_rules", [room_id_1, room_id_2, room_id_3]) + _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("guest_access", [room_id_1, room_id_2, room_id_3]) + _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("history_visibility", [room_id_1, room_id_2, room_id_3]) + _order_test( + "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("state_events", [room_id_3, room_id_2, room_id_1]) + _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) + + def test_search_term(self): + """Test that searching for a room works correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + def _search_test( + expected_room_id: Optional[str], + search_term: str, + expected_http_code: int = 200, + ): + """Search for a room and check that the returned room's id is a match + + Args: + expected_room_id: The room_id expected to be returned by the API. Set + to None to expect zero results for the search + search_term: The term to search for room names with + expected_http_code: The expected http code for the request + """ + url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that the expected number of rooms were returned + expected_room_count = 1 if expected_room_id else 0 + self.assertEqual(len(rooms), expected_room_count) + self.assertEqual(channel.json_body["total_rooms"], expected_room_count) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + if expected_room_id: + # Check that the first returned room id is correct + r = rooms[0] + self.assertEqual(expected_room_id, r["room_id"]) + + # Perform search tests + _search_test(room_id_1, "something") + _search_test(room_id_1, "thing") + + _search_test(room_id_2, "else") + _search_test(room_id_2, "se") + + _search_test(None, "foo") + _search_test(None, "bar") + _search_test(None, "", expected_http_code=400) + + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From 13683a3a223a3f08b295973e7b8aa6e9299a2fa5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Apr 2020 13:45:40 +0100 Subject: Extend StreamChangeCache to support multiple entities per stream ID (#7303) First some background: StreamChangeCache is used to keep track of what "entities" have changed since a given stream ID. So for example, we might use it to keep track of when the last to-device message for a given user was received [1], and hence whether we need to pull any to-device messages from the database on a sync [2]. Now, it turns out that StreamChangeCache didn't support more than one thing being changed at a given stream_id (this was part of the problem with #7206). However, it's entirely valid to send to-device messages to more than one user at a time. As it turns out, this did in fact work, because *some* methods of StreamChangeCache coped ok with having multiple things changing on the same stream ID, and it seems we never actually use the methods which don't work on the stream change caches where we allow multiple changes at the same stream ID. But that feels horribly fragile, hence: let's update StreamChangeCache to properly support this, and add some typing and some more tests while we're at it. [1]: https://github.com/matrix-org/synapse/blob/release-v1.12.3/synapse/storage/data_stores/main/deviceinbox.py#L301 [2]: https://github.com/matrix-org/synapse/blob/release-v1.12.3/synapse/storage/data_stores/main/deviceinbox.py#L47-L51 --- changelog.d/7303.misc | 1 + stubs/sortedcontainers/__init__.pyi | 13 +++ stubs/sortedcontainers/sorteddict.pyi | 124 +++++++++++++++++++++++++++++ synapse/util/caches/stream_change_cache.py | 117 ++++++++++++++++----------- tests/util/test_stream_change_cache.py | 69 +++++++++++++--- tox.ini | 4 +- 6 files changed, 272 insertions(+), 56 deletions(-) create mode 100644 changelog.d/7303.misc create mode 100644 stubs/sortedcontainers/__init__.pyi create mode 100644 stubs/sortedcontainers/sorteddict.pyi (limited to 'tests') diff --git a/changelog.d/7303.misc b/changelog.d/7303.misc new file mode 100644 index 0000000000..aa89c2b254 --- /dev/null +++ b/changelog.d/7303.misc @@ -0,0 +1 @@ +Fix StreamChangeCache to work with multiple entities changing on the same stream id. diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi new file mode 100644 index 0000000000..073b806d3c --- /dev/null +++ b/stubs/sortedcontainers/__init__.pyi @@ -0,0 +1,13 @@ +from .sorteddict import ( + SortedDict, + SortedKeysView, + SortedItemsView, + SortedValuesView, +) + +__all__ = [ + "SortedDict", + "SortedKeysView", + "SortedItemsView", + "SortedValuesView", +] diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi new file mode 100644 index 0000000000..68779f968e --- /dev/null +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -0,0 +1,124 @@ +# stub for SortedDict. This is a lightly edited copy of +# https://github.com/grantjenks/python-sortedcontainers/blob/eea42df1f7bad2792e8da77335ff888f04b9e5ae/sortedcontainers/sorteddict.pyi +# (from https://github.com/grantjenks/python-sortedcontainers/pull/107) + +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterator, + Iterable, + ItemsView, + KeysView, + List, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Tuple, + Union, + ValuesView, + overload, +) + +_T = TypeVar("_T") +_S = TypeVar("_S") +_T_h = TypeVar("_T_h", bound=Hashable) +_KT = TypeVar("_KT", bound=Hashable) # Key type. +_VT = TypeVar("_VT") # Value type. +_KT_co = TypeVar("_KT_co", covariant=True, bound=Hashable) +_VT_co = TypeVar("_VT_co", covariant=True) +_SD = TypeVar("_SD", bound=SortedDict) +_Key = Callable[[_T], Any] + +class SortedDict(Dict[_KT, _VT]): + @overload + def __init__(self, **kwargs: _VT) -> None: ... + @overload + def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + @overload + def __init__(self, __key: _Key[_KT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __key: _Key[_KT], __map: Mapping[_KT, _VT], **kwargs: _VT + ) -> None: ... + @overload + def __init__( + self, __key: _Key[_KT], __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + @property + def key(self) -> Optional[_Key[_KT]]: ... + @property + def iloc(self) -> SortedKeysView[_KT]: ... + def clear(self) -> None: ... + def __delitem__(self, key: _KT) -> None: ... + def __iter__(self) -> Iterator[_KT]: ... + def __reversed__(self) -> Iterator[_KT]: ... + def __setitem__(self, key: _KT, value: _VT) -> None: ... + def _setitem(self, key: _KT, value: _VT) -> None: ... + def copy(self: _SD) -> _SD: ... + def __copy__(self: _SD) -> _SD: ... + @classmethod + @overload + def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ... + @classmethod + @overload + def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ... + def keys(self) -> SortedKeysView[_KT]: ... + def items(self) -> SortedItemsView[_KT, _VT]: ... + def values(self) -> SortedValuesView[_VT]: ... + @overload + def pop(self, key: _KT) -> _VT: ... + @overload + def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ... + def popitem(self, index: int = ...) -> Tuple[_KT, _VT]: ... + def peekitem(self, index: int = ...) -> Tuple[_KT, _VT]: ... + def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ... + @overload + def update(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def update(self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT) -> None: ... + @overload + def update(self, **kwargs: _VT) -> None: ... + def __reduce__( + self, + ) -> Tuple[ + Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]], + ]: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + def islice( + self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + ) -> Iterator[_KT]: ... + def bisect_left(self, value: _KT) -> int: ... + def bisect_right(self, value: _KT) -> int: ... + +class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]): + @overload + def __getitem__(self, index: int) -> _KT_co: ... + @overload + def __getitem__(self, index: slice) -> List[_KT_co]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + +class SortedItemsView( # type: ignore + ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]] +): + def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ... + @overload + def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ... + @overload + def __getitem__(self, index: slice) -> List[Tuple[_KT_co, _VT_co]]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + +class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]): + @overload + def __getitem__(self, index: int) -> _VT_co: ... + @overload + def __getitem__(self, index: slice) -> List[_VT_co]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index c61d36a82e..38dc3f501e 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, Iterable, List, Mapping, Optional, Set from six import integer_types @@ -23,8 +24,11 @@ from synapse.util import caches logger = logging.getLogger(__name__) +# for now, assume all entities in the cache are strings +EntityType = str -class StreamChangeCache(object): + +class StreamChangeCache: """Keeps track of the stream positions of the latest change in a set of entities. Typically the entity will be a room or user id. @@ -34,10 +38,23 @@ class StreamChangeCache(object): old then the cache will simply return all given entities. """ - def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None): + def __init__( + self, + name: str, + current_stream_pos: int, + max_size=10000, + prefilled_cache: Optional[Mapping[EntityType, int]] = None, + ): self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) - self._entity_to_key = {} - self._cache = SortedDict() + self._entity_to_key = {} # type: Dict[EntityType, int] + + # map from stream id to the a set of entities which changed at that stream id. + self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]] + + # the earliest stream_pos for which we can reliably answer + # get_all_entities_changed. In other words, one less than the earliest + # stream_pos for which we know _cache is valid. + # self._earliest_known_stream_pos = current_stream_pos self.name = name self.metrics = caches.register_cache("cache", self.name, self._cache) @@ -46,7 +63,7 @@ class StreamChangeCache(object): for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) - def has_entity_changed(self, entity, stream_pos): + def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ assert type(stream_pos) in integer_types @@ -67,22 +84,17 @@ class StreamChangeCache(object): self.metrics.inc_hits() return False - def get_entities_changed(self, entities, stream_pos): + def get_entities_changed( + self, entities: Iterable[EntityType], stream_pos: int + ) -> Set[EntityType]: """ Returns subset of entities that have had new things since the given position. Entities unknown to the cache will be returned. If the position is too old it will just return the given list. """ - assert type(stream_pos) is int - - if stream_pos >= self._earliest_known_stream_pos: - changed_entities = { - self._cache[k] - for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) - } - - result = changed_entities.intersection(entities) - + changed_entities = self.get_all_entities_changed(stream_pos) + if changed_entities is not None: + result = set(changed_entities).intersection(entities) self.metrics.inc_hits() else: result = set(entities) @@ -90,13 +102,13 @@ class StreamChangeCache(object): return result - def has_any_entity_changed(self, stream_pos): + def has_any_entity_changed(self, stream_pos: int) -> bool: """Returns if any entity has changed """ assert type(stream_pos) is int if not self._cache: - # If we have no cache, nothing can have changed. + # If the cache is empty, nothing can have changed. return False if stream_pos >= self._earliest_known_stream_pos: @@ -106,45 +118,58 @@ class StreamChangeCache(object): self.metrics.inc_misses() return True - def get_all_entities_changed(self, stream_pos): - """Returns all entites that have had new things since the given + def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: + """Returns all entities that have had new things since the given position. If the position is too old it will return None. + + Returns the entities in the order that they were changed. """ assert type(stream_pos) is int - if stream_pos >= self._earliest_known_stream_pos: - return [ - self._cache[k] - for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) - ] - else: + if stream_pos < self._earliest_known_stream_pos: return None - def entity_has_changed(self, entity, stream_pos): + changed_entities = [] # type: List[EntityType] + + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): + changed_entities.extend(self._cache[k]) + return changed_entities + + def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: """Informs the cache that the entity has been changed at the given position. """ assert type(stream_pos) is int - # FIXME: add a sanity check here that we are not overwriting existing - # data in self._cache - - if stream_pos > self._earliest_known_stream_pos: - old_pos = self._entity_to_key.get(entity, None) - if old_pos is not None: - stream_pos = max(stream_pos, old_pos) - self._cache.pop(old_pos, None) - self._cache[stream_pos] = entity - self._entity_to_key[entity] = stream_pos - - while len(self._cache) > self._max_size: - k, r = self._cache.popitem(0) - self._earliest_known_stream_pos = max( - k, self._earliest_known_stream_pos - ) - self._entity_to_key.pop(r, None) - - def get_max_pos_of_last_change(self, entity): + if stream_pos <= self._earliest_known_stream_pos: + return + + old_pos = self._entity_to_key.get(entity, None) + if old_pos is not None: + if old_pos >= stream_pos: + # nothing to do + return + e = self._cache[old_pos] + e.remove(entity) + if not e: + # cache at this point is now empty + del self._cache[old_pos] + + e1 = self._cache.get(stream_pos) + if e1 is None: + e1 = self._cache[stream_pos] = set() + e1.add(entity) + self._entity_to_key[entity] = stream_pos + + # if the cache is too big, remove entries + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + del self._entity_to_key[entity] + + def get_max_pos_of_last_change(self, entity: EntityType) -> int: + """Returns an upper bound of the stream id of the last change to an entity. """ diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 72a9de5370..6857933540 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -28,18 +28,26 @@ class StreamChangeCacheTests(unittest.TestCase): cache.entity_has_changed("user@foo.com", 6) cache.entity_has_changed("bar@baz.net", 7) + # also test multiple things changing on the same stream ID + cache.entity_has_changed("user2@foo.com", 8) + cache.entity_has_changed("bar2@baz.net", 8) + # If it's been changed after that stream position, return True self.assertTrue(cache.has_entity_changed("user@foo.com", 4)) self.assertTrue(cache.has_entity_changed("bar@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("bar2@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("user2@foo.com", 4)) # If it's been changed at that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 6)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 8)) # If there's no changes after that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 7)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 9)) # If the entity does not exist, return False. - self.assertFalse(cache.has_entity_changed("not@here.website", 7)) + self.assertFalse(cache.has_entity_changed("not@here.website", 9)) # If we request before the stream cache's earliest known position, # return True, whether it's a known entity or not. @@ -47,7 +55,7 @@ class StreamChangeCacheTests(unittest.TestCase): self.assertTrue(cache.has_entity_changed("not@here.website", 0)) @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) - def test_has_entity_changed_pops_off_start(self): + def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and purge the oldest items upon reaching that max size. @@ -64,11 +72,20 @@ class StreamChangeCacheTests(unittest.TestCase): # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) + self.assertEqual( + cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) + # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) self.assertEqual( {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) + self.assertEqual( + cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) def test_get_all_entities_changed(self): """ @@ -80,18 +97,52 @@ class StreamChangeCacheTests(unittest.TestCase): cache.entity_has_changed("user@foo.com", 2) cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("anotheruser@foo.com", 3) cache.entity_has_changed("user@elsewhere.org", 4) - self.assertEqual( - cache.get_all_entities_changed(1), - ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], - ) - self.assertEqual( - cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"] - ) + r = cache.get_all_entities_changed(1) + + # either of these are valid + ok1 = [ + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + "user@elsewhere.org", + ] + ok2 = [ + "user@foo.com", + "anotheruser@foo.com", + "bar@baz.net", + "user@elsewhere.org", + ] + self.assertTrue(r == ok1 or r == ok2) + + r = cache.get_all_entities_changed(2) + self.assertTrue(r == ok1[1:] or r == ok2[1:]) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) self.assertEqual(cache.get_all_entities_changed(0), None) + # ... later, things gest more updates + cache.entity_has_changed("user@foo.com", 5) + cache.entity_has_changed("bar@baz.net", 5) + cache.entity_has_changed("anotheruser@foo.com", 6) + + ok1 = [ + "user@elsewhere.org", + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + ] + ok2 = [ + "user@elsewhere.org", + "bar@baz.net", + "user@foo.com", + "anotheruser@foo.com", + ] + r = cache.get_all_entities_changed(3) + self.assertTrue(r == ok1 or r == ok2) + def test_has_any_entity_changed(self): """ StreamChangeCache.has_any_entity_changed will return True if any diff --git a/tox.ini b/tox.ini index 42b2d74891..31011d7436 100644 --- a/tox.ini +++ b/tox.ini @@ -202,7 +202,9 @@ commands = mypy \ synapse/spam_checker_api \ synapse/storage/engines \ synapse/storage/database.py \ - synapse/streams + synapse/streams \ + synapse/util/caches/stream_change_cache.py \ + tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: # -- cgit 1.5.1 From 82d8b1dd1f9712e6fe29da671b079c646a663a48 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Apr 2020 14:34:31 +0100 Subject: Another go at fixing one-word commands (#7326) I messed this up last time I tried (#7239 / e13c6c7). --- changelog.d/7326.misc | 1 + synapse/replication/tcp/commands.py | 2 +- tests/replication/tcp/test_commands.py | 42 ++++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7326.misc create mode 100644 tests/replication/tcp/test_commands.py (limited to 'tests') diff --git a/changelog.d/7326.misc b/changelog.d/7326.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7326.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5002efe6a0..f26aee83cb 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -462,7 +462,7 @@ def parse_command_from_line(line: str) -> Command: Line should already be stripped of whitespace and be checked if blank. """ - idx = line.index(" ") + idx = line.find(" ") if idx >= 0: cmd_name = line[:idx] rest_of_line = line[idx + 1 :] diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py new file mode 100644 index 0000000000..3cbcb513cc --- /dev/null +++ b/tests/replication/tcp/test_commands.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.replication.tcp.commands import ( + RdataCommand, + ReplicateCommand, + parse_command_from_line, +) + +from tests.unittest import TestCase + + +class ParseCommandTestCase(TestCase): + def test_parse_one_word_command(self): + line = "REPLICATE" + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, ReplicateCommand) + + def test_parse_rdata(self): + line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, RdataCommand) + self.assertEqual(cmd.stream_name, "events") + self.assertEqual(cmd.token, 6287863) + + def test_parse_rdata_batch(self): + line = 'RDATA presence batch ["@foo:example.com", "online"]' + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, RdataCommand) + self.assertEqual(cmd.stream_name, "presence") + self.assertIsNone(cmd.token) -- cgit 1.5.1 From 2e3b9a0fcb81b539e155004ded8017ee9923eecc Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 23 Apr 2020 11:23:53 +0200 Subject: Revert "Revert "Merge pull request #7315 from matrix-org/babolivier/request_token"" This reverts commit 1adf6a55870aa08de272591ff49db9dc49738076. --- changelog.d/7315.feature | 1 + docs/sample_config.yaml | 10 ++++++ synapse/config/server.py | 21 +++++++++++++ synapse/rest/client/v2_alpha/account.py | 17 ++++++++++- synapse/rest/client/v2_alpha/register.py | 12 +++++++- tests/rest/client/v2_alpha/test_account.py | 16 ++++++++++ tests/rest/client/v2_alpha/test_register.py | 47 ++++++++++++++++++++++++++++- 7 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 changelog.d/7315.feature (limited to 'tests') diff --git a/changelog.d/7315.feature b/changelog.d/7315.feature new file mode 100644 index 0000000000..ebcb4741b7 --- /dev/null +++ b/changelog.d/7315.feature @@ -0,0 +1 @@ +Allow `/requestToken` endpoints to hide the existence (or lack thereof) of 3PID associations on the homeserver. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ca8accbc6e..6d5f4f316d 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -414,6 +414,16 @@ retention: # longest_max_lifetime: 1y # interval: 1d +# Inhibits the /requestToken endpoints from returning an error that might leak +# information about whether an e-mail address is in use or not on this +# homeserver. +# Note that for some endpoints the error situation is the e-mail already being +# used, and for others the error is entering the e-mail being unused. +# If this option is enabled, instead of returning an error, these endpoints will +# act as if no error happened and return a fake session ID ('sid') to clients. +# +#request_token_inhibit_3pid_errors: true + ## TLS ## diff --git a/synapse/config/server.py b/synapse/config/server.py index 28e2a031fb..c6d58effd4 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -507,6 +507,17 @@ class ServerConfig(Config): self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) + # Inhibits the /requestToken endpoints from returning an error that might leak + # information about whether an e-mail address is in use or not on this + # homeserver, and instead return a 200 with a fake sid if this kind of error is + # met, without sending anything. + # This is a compromise between sending an email, which could be a spam vector, + # and letting the client know which email address is bound to an account and + # which one isn't. + self.request_token_inhibit_3pid_errors = config.get( + "request_token_inhibit_3pid_errors", False, + ) + def has_tls_listener(self) -> bool: return any(l["tls"] for l in self.listeners) @@ -972,6 +983,16 @@ class ServerConfig(Config): # - shortest_max_lifetime: 3d # longest_max_lifetime: 1y # interval: 1d + + # Inhibits the /requestToken endpoints from returning an error that might leak + # information about whether an e-mail address is in use or not on this + # homeserver. + # Note that for some endpoints the error situation is the e-mail already being + # used, and for others the error is entering the e-mail being unused. + # If this option is enabled, instead of returning an error, these endpoints will + # act as if no error happened and return a fake session ID ('sid') to clients. + # + #request_token_inhibit_3pid_errors: true """ % locals() ) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 31435b1e1c..1bd0234779 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -30,7 +30,7 @@ from synapse.http.servlet import ( ) from synapse.push.mailer import Mailer, load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn -from synapse.util.stringutils import assert_valid_client_secret +from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -100,6 +100,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) if existing_user_id is None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: @@ -390,6 +395,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: @@ -453,6 +463,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) if not self.hs.config.account_threepid_delegate_msisdn: diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 431ecf4f84..d1b5c49989 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -49,7 +49,7 @@ from synapse.http.servlet import ( from synapse.push.mailer import load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter -from synapse.util.stringutils import assert_valid_client_secret +from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -135,6 +135,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: @@ -202,6 +207,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError( 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 45a9d445f8..0d6936fd36 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -179,6 +179,22 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the new password self.attempt_wrong_password_login("kermit", new_password) + @unittest.override_config({"request_token_inhibit_3pid_errors": True}) + def test_password_reset_bad_email_inhibit_error(self): + """Test that triggering a password reset with an email address that isn't bound + to an account doesn't leak the lack of binding for that address if configured + that way. + """ + self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertIsNotNone(session_id) + def _request_token(self, email, client_secret): request, channel = self.make_request( "POST", diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b6ed06e02d..a68a96f618 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -33,7 +33,11 @@ from tests import unittest class RegisterRestServletTestCase(unittest.HomeserverTestCase): - servlets = [register.register_servlets] + servlets = [ + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets, + ] url = b"/_matrix/client/r0/register" def default_config(self): @@ -260,6 +264,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): [["m.login.email.identity"]], (f["stages"] for f in flows) ) + @unittest.override_config( + { + "request_token_inhibit_3pid_errors": True, + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_request_token_existing_email_inhibit_error(self): + """Test that requesting a token via this endpoint doesn't leak existing + associations if configured that way. + """ + user_id = self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.hs.get_datastore().user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.assertIsNotNone(channel.json_body.get("sid")) + class AccountValidityTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From fb8ff79efd0897b0b7bf52b0c4bb4061a4ef4018 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 28 Apr 2020 14:21:48 +0100 Subject: Fix collation for postgres for unit tests (#7359) When running the UTs against a postgres deatbase, we need to set the collation correctly. --- changelog.d/7359.misc | 1 + tests/utils.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7359.misc (limited to 'tests') diff --git a/changelog.d/7359.misc b/changelog.d/7359.misc new file mode 100644 index 0000000000..b99f257d9a --- /dev/null +++ b/changelog.d/7359.misc @@ -0,0 +1 @@ +Fix collation for postgres for unit tests. diff --git a/tests/utils.py b/tests/utils.py index 2079e0143d..037cb134f0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -74,7 +74,10 @@ def setupdb(): db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) - cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,)) + cur.execute( + "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' " + "template=template0;" % (POSTGRES_BASE_DB,) + ) cur.close() db_conn.close() -- cgit 1.5.1 From fce663889b150e55b19097dd9b7fed7aca8abccc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 28 Apr 2020 17:42:03 +0100 Subject: Add some replication tests (#7278) Specifically some tests for the typing stream, which means we test streams that fetch missing updates via HTTP (rather than via the DB). We also shuffle things around a bit so that we create two separate `HomeServer` objects, rather than trying to insert a slaved store into places. Note: `test_typing.py` is heavily inspired by `test_receipts.py` --- changelog.d/7278.misc | 1 + tests/replication/tcp/streams/_base.py | 229 +++++++++++++++++++++++++-- tests/replication/tcp/streams/test_typing.py | 80 ++++++++++ 3 files changed, 299 insertions(+), 11 deletions(-) create mode 100644 changelog.d/7278.misc create mode 100644 tests/replication/tcp/streams/test_typing.py (limited to 'tests') diff --git a/changelog.d/7278.misc b/changelog.d/7278.misc new file mode 100644 index 0000000000..8a4c4328f4 --- /dev/null +++ b/changelog.d/7278.misc @@ -0,0 +1 @@ +Add some unit tests for replication. diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 32238fe79a..82f15c64e0 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,34 +12,67 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from typing import Optional from mock import Mock +import attr + +from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime +from twisted.internet.task import LoopingCall +from twisted.web.http import HTTPChannel + +from synapse.app.generic_worker import GenericWorkerServer +from synapse.http.site import SynapseRequest +from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.util import Clock from tests import unittest from tests.server import FakeTransport +logger = logging.getLogger(__name__) + class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" - def make_homeserver(self, reactor, clock): - self.test_handler = Mock(wraps=TestReplicationDataHandler()) - return self.setup_test_homeserver(replication_data_handler=self.test_handler) - def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() self.server = server_factory.buildProtocol(None) - repl_handler = ReplicationCommandHandler(hs) - repl_handler.handler = self.test_handler + # Make a new HomeServer object for the worker + config = self.default_config() + config["worker_app"] = "synapse.app.generic_worker" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + + self.reactor.lookups["testserv"] = "1.2.3.4" + + self.worker_hs = self.setup_test_homeserver( + http_client=None, + homeserverToUse=GenericWorkerServer, + config=config, + reactor=self.reactor, + ) + + # Since we use sqlite in memory databases we need to make sure the + # databases objects are the same. + self.worker_hs.get_datastore().db = hs.get_datastore().db + + self.test_handler = Mock( + wraps=TestReplicationDataHandler(self.worker_hs.get_datastore()) + ) + self.worker_hs.replication_data_handler = self.test_handler + + repl_handler = ReplicationCommandHandler(self.worker_hs) self.client = ClientReplicationStreamProtocol( - hs, "client", "test", clock, repl_handler, + self.worker_hs, "client", "test", clock, repl_handler, ) self._client_transport = None @@ -74,11 +107,75 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) + def handle_http_replication_attempt(self) -> SynapseRequest: + """Asserts that a connection attempt was made to the master HS on the + HTTP replication port, then proxies it to the master HS object to be + handled. + + Returns: + The request object received by master HS. + """ + + # We should have an outbound connection attempt. + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8765) + + # Set up client side protocol + client_protocol = client_factory.buildProtocol(None) + + request_factory = OneShotRequestFactory() -class TestReplicationDataHandler: + # Set up the server side protocol + channel = _PushHTTPChannel(self.reactor) + channel.requestFactory = request_factory + channel.site = self.site + + # Connect client to server and vice versa. + client_to_server_transport = FakeTransport( + channel, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, channel + ) + channel.makeConnection(server_to_client_transport) + + # The request will now be processed by `self.site` and the response + # streamed back. + self.reactor.advance(0) + + # We tear down the connection so it doesn't get reused without our + # knowledge. + server_to_client_transport.loseConnection() + client_to_server_transport.loseConnection() + + return request_factory.request + + def assert_request_is_get_repl_stream_updates( + self, request: SynapseRequest, stream_name: str + ): + """Asserts that the given request is a HTTP replication request for + fetching updates for given stream. + """ + + self.assertRegex( + request.path, + br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" + % (stream_name.encode("ascii"),), + ) + + self.assertEqual(request.method, b"GET") + + +class TestReplicationDataHandler(ReplicationDataHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self): + def __init__(self, hs): + super().__init__(hs) self.streams = set() self._received_rdata_rows = [] @@ -90,8 +187,118 @@ class TestReplicationDataHandler: return positions async def on_rdata(self, stream_name, token, rows): + await super().on_rdata(stream_name, token, rows) for r in rows: self._received_rdata_rows.append((stream_name, token, r)) - async def on_position(self, stream_name, token): - pass + +@attr.s() +class OneShotRequestFactory: + """A simple request factory that generates a single `SynapseRequest` and + stores it for future use. Can only be used once. + """ + + request = attr.ib(default=None) + + def __call__(self, *args, **kwargs): + assert self.request is None + + self.request = SynapseRequest(*args, **kwargs) + return self.request + + +class _PushHTTPChannel(HTTPChannel): + """A HTTPChannel that wraps pull producers to push producers. + + This is a hack to get around the fact that HTTPChannel transparently wraps a + pull producer (which is what Synapse uses to reply to requests) with + `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush` + uses the standard reactor rather than letting us use our test reactor, which + makes it very hard to test. + """ + + def __init__(self, reactor: IReactorTime): + super().__init__() + self.reactor = reactor + + self._pull_to_push_producer = None + + def registerProducer(self, producer, streaming): + # Convert pull producers to push producer. + if not streaming: + self._pull_to_push_producer = _PullToPushProducer( + self.reactor, producer, self + ) + producer = self._pull_to_push_producer + + super().registerProducer(producer, True) + + def unregisterProducer(self): + if self._pull_to_push_producer: + # We need to manually stop the _PullToPushProducer. + self._pull_to_push_producer.stop() + + +class _PullToPushProducer: + """A push producer that wraps a pull producer. + """ + + def __init__( + self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer + ): + self._clock = Clock(reactor) + self._producer = producer + self._consumer = consumer + + # While running we use a looping call with a zero delay to call + # resumeProducing on given producer. + self._looping_call = None # type: Optional[LoopingCall] + + # We start writing next reactor tick. + self._start_loop() + + def _start_loop(self): + """Start the looping call to + """ + + if not self._looping_call: + # Start a looping call which runs every tick. + self._looping_call = self._clock.looping_call(self._run_once, 0) + + def stop(self): + """Stops calling resumeProducing. + """ + if self._looping_call: + self._looping_call.stop() + self._looping_call = None + + def pauseProducing(self): + """Implements IPushProducer + """ + self.stop() + + def resumeProducing(self): + """Implements IPushProducer + """ + self._start_loop() + + def stopProducing(self): + """Implements IPushProducer + """ + self.stop() + self._producer.stopProducing() + + def _run_once(self): + """Calls resumeProducing on producer once. + """ + + try: + self._producer.resumeProducing() + except Exception: + logger.exception("Failed to call resumeProducing") + try: + self._consumer.unregisterProducer() + except Exception: + pass + + self.stopProducing() diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py new file mode 100644 index 0000000000..f0ad6402ae --- /dev/null +++ b/tests/replication/tcp/streams/test_typing.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.handlers.typing import RoomMember +from synapse.replication.http import streams +from synapse.replication.tcp.streams import TypingStream + +from tests.replication.tcp.streams._base import BaseStreamTestCase + +USER_ID = "@feeling:blue" + + +class TypingStreamTestCase(BaseStreamTestCase): + servlets = [ + streams.register_servlets, + ] + + def test_typing(self): + typing = self.hs.get_typing_handler() + + room_id = "!bar:blue" + + self.reconnect() + + # make the client subscribe to the receipts stream + self.test_handler.streams.add("typing") + + typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) + + self.reactor.advance(0) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] # type: TypingStream.TypingStreamRow + self.assertEqual(room_id, row.room_id) + self.assertEqual([USER_ID], row.user_ids) + + # Now let's disconnect and insert some data. + self.disconnect() + + self.test_handler.on_rdata.reset_mock() + + typing._push_update(member=RoomMember(room_id, USER_ID), typing=False) + + self.test_handler.on_rdata.assert_not_called() + + self.reconnect() + self.pump(0.1) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + # The from token should be the token from the last RDATA we got. + self.assertEqual(int(request.args[b"from_token"][0]), token) + + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] # type: TypingStream.TypingStreamRow + self.assertEqual(room_id, row.room_id) + self.assertEqual([], row.user_ids) -- cgit 1.5.1 From 04dd7d182d0601289e0e047243b50803f526ef69 Mon Sep 17 00:00:00 2001 From: Manuel Stahl <37705355+awesome-manuel@users.noreply.github.com> Date: Tue, 28 Apr 2020 19:19:36 +0200 Subject: Return total number of users and profile attributes in admin users endpoint (#6881) Signed-off-by: Manuel Stahl --- changelog.d/6881.misc | 1 + docs/admin_api/user_admin_api.rst | 11 +++-- synapse/rest/admin/users.py | 8 ++-- synapse/storage/data_stores/main/__init__.py | 68 ++++++++++++++++------------ tests/rest/admin/test_user.py | 2 + tests/storage/test_main.py | 46 +++++++++++++++++++ 6 files changed, 100 insertions(+), 36 deletions(-) create mode 100644 changelog.d/6881.misc create mode 100644 tests/storage/test_main.py (limited to 'tests') diff --git a/changelog.d/6881.misc b/changelog.d/6881.misc new file mode 100644 index 0000000000..03b89ccd3d --- /dev/null +++ b/changelog.d/6881.misc @@ -0,0 +1 @@ +Return total number of users and profile attributes in admin users endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 9ce10119ff..927ed65f77 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -72,17 +72,22 @@ It returns a JSON body like the following: "is_guest": 0, "admin": 0, "user_type": null, - "deactivated": 0 + "deactivated": 0, + "displayname": , + "avatar_url": null }, { "name": "", "password_hash": "", "is_guest": 0, "admin": 1, "user_type": null, - "deactivated": 0 + "deactivated": 0, + "displayname": , + "avatar_url": "" } ], - "next_token": "100" + "next_token": "100", + "total": 200 } diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 8551ac19b8..593ce011e8 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -94,10 +94,10 @@ class UsersRestServletV2(RestServlet): guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) - users = await self.store.get_users_paginate( + users, total = await self.store.get_users_paginate( start, limit, user_id, guests, deactivated ) - ret = {"users": users} + ret = {"users": users, "total": total} if len(users) >= limit: ret["next_token"] = str(start + len(users)) @@ -199,7 +199,7 @@ class UserRestServletV2(RestServlet): user_id, threepid["medium"], threepid["address"], current_time ) - if "avatar_url" in body: + if "avatar_url" in body and type(body["avatar_url"]) == str: await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True ) @@ -276,7 +276,7 @@ class UserRestServletV2(RestServlet): user_id, threepid["medium"], threepid["address"], current_time ) - if "avatar_url" in body: + if "avatar_url" in body and type(body["avatar_url"]) == str: await self.profile_handler.set_avatar_url( user_id, requester, body["avatar_url"], True ) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 649e835303..bd7c3a00ea 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -503,7 +503,8 @@ class DataStore( self, start, limit, name=None, guests=True, deactivated=False ): """Function to retrieve a paginated list of users from - users list. This will return a json list of users. + users list. This will return a json list of users and the + total number of users matching the filter criteria. Args: start (int): start number to begin the query from @@ -512,35 +513,44 @@ class DataStore( guests (bool): whether to in include guest users deactivated (bool): whether to include deactivated users Returns: - defer.Deferred: resolves to list[dict[str, Any]] + defer.Deferred: resolves to list[dict[str, Any]], int """ - name_filter = {} - if name: - name_filter["name"] = "%" + name + "%" - - attr_filter = {} - if not guests: - attr_filter["is_guest"] = 0 - if not deactivated: - attr_filter["deactivated"] = 0 - - return self.db.simple_select_list_paginate( - desc="get_users_paginate", - table="users", - orderby="name", - start=start, - limit=limit, - filters=name_filter, - keyvalues=attr_filter, - retcols=[ - "name", - "password_hash", - "is_guest", - "admin", - "user_type", - "deactivated", - ], - ) + + def get_users_paginate_txn(txn): + filters = [] + args = [] + + if name: + filters.append("name LIKE ?") + args.append("%" + name + "%") + + if not guests: + filters.append("is_guest = 0") + + if not deactivated: + filters.append("deactivated = 0") + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause) + txn.execute(sql, args) + count = txn.fetchone()[0] + + args = [self.hs.config.server_name] + args + [limit, start] + sql = """ + SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url + FROM users as u + LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? + {} + ORDER BY u.name LIMIT ? OFFSET ? + """.format( + where_clause + ) + txn.execute(sql, args) + users = self.db.cursor_to_dict(txn) + return users, count + + return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn) def search_users(self, term): """Function to search users list for one or more users with diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 6416fb5d2a..6c88ab06e2 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -360,6 +360,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(3, len(channel.json_body["users"])) + self.assertEqual(3, channel.json_body["total"]) class UserRestTestCase(unittest.HomeserverTestCase): @@ -434,6 +435,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": True, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": None, } ) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py new file mode 100644 index 0000000000..ab0df5ea93 --- /dev/null +++ b/tests/storage/test_main.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Awesome Technologies Innovationslabor GmbH +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.types import UserID + +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DataStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield setup_test_homeserver(self.addCleanup) + + self.store = hs.get_datastore() + + self.user = UserID.from_string("@abcde:test") + self.displayname = "Frank" + + @defer.inlineCallbacks + def test_get_users_paginate(self): + yield self.store.register_user(self.user.to_string(), "pass") + yield self.store.create_profile(self.user.localpart) + yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + + users, total = yield self.store.get_users_paginate( + 0, 10, name="bc", guests=False + ) + + self.assertEquals(1, total) + self.assertEquals(self.displayname, users.pop()["displayname"]) -- cgit 1.5.1 From c2e1a2110fbe9ead26b4ecbb1afd504ed035a04d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 29 Apr 2020 12:30:36 +0100 Subject: Fix limit logic for EventsStream (#7358) * Factor out functions for injecting events into database I want to add some more flexibility to the tools for injecting events into the database, and I don't want to clutter up HomeserverTestCase with them, so let's factor them out to a new file. * Rework TestReplicationDataHandler This wasn't very easy to work with: the mock wrapping was largely superfluous, and it's useful to be able to inspect the received rows, and clear out the received list. * Fix AssertionErrors being thrown by EventsStream Part of the problem was that there was an off-by-one error in the assertion, but also the limit logic was too simple. Fix it all up and add some tests. --- changelog.d/7358.bugfix | 1 + synapse/replication/tcp/handler.py | 4 +- synapse/replication/tcp/streams/events.py | 22 +- synapse/server.pyi | 5 + synapse/storage/data_stores/main/events_worker.py | 64 +++- tests/replication/tcp/streams/_base.py | 41 ++- tests/replication/tcp/streams/test_events.py | 417 ++++++++++++++++++++++ tests/replication/tcp/streams/test_receipts.py | 10 +- tests/replication/tcp/streams/test_typing.py | 11 +- tests/rest/client/v1/utils.py | 2 +- tests/test_utils/__init__.py | 20 ++ tests/test_utils/event_injection.py | 96 +++++ tests/unittest.py | 30 +- tox.ini | 2 + 14 files changed, 658 insertions(+), 67 deletions(-) create mode 100644 changelog.d/7358.bugfix create mode 100644 tests/replication/tcp/streams/test_events.py create mode 100644 tests/test_utils/event_injection.py (limited to 'tests') diff --git a/changelog.d/7358.bugfix b/changelog.d/7358.bugfix new file mode 100644 index 0000000000..f49c600173 --- /dev/null +++ b/changelog.d/7358.bugfix @@ -0,0 +1 @@ +Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0db5a3a24d..3a8c7c7e2d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -87,7 +87,9 @@ class ReplicationCommandHandler: stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] - self._position_linearizer = Linearizer("replication_position") + self._position_linearizer = Linearizer( + "replication_position", clock=self._clock + ) # Map of stream to batched updates. See RdataCommand for info on how # batching works. diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index aa50492569..52df81b1bd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -170,22 +170,16 @@ class EventsStream(Stream): limited = False upper_limit = current_token - # next up is the state delta table - - state_rows = await self._store.get_all_updated_current_state_deltas( + # next up is the state delta table. + ( + state_rows, + upper_limit, + state_rows_limited, + ) = await self._store.get_all_updated_current_state_deltas( from_token, upper_limit, target_row_count - ) # type: List[Tuple] - - # again, if we've hit the limit there, we'll need to limit the other sources - assert len(state_rows) < target_row_count - if len(state_rows) == target_row_count: - assert state_rows[-1][0] <= upper_limit - upper_limit = state_rows[-1][0] - limited = True + ) - # FIXME: is it a given that there is only one row per stream_id in the - # state_deltas table (so that we can be sure that we have got all of the - # rows for upper_limit)? + limited = limited or state_rows_limited # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. diff --git a/synapse/server.pyi b/synapse/server.pyi index f1a5717028..fc5886f762 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage +from synapse.events.builder import EventBuilderFactory class HomeServer(object): @property @@ -121,3 +122,7 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_event_builder_factory(self) -> EventBuilderFactory: + pass + def get_storage(self) -> synapse.storage.Storage: + pass diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ce8be72bfe..73df6b33ba 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,7 +19,7 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional +from typing import List, Optional, Tuple from canonicaljson import json from constantly import NamedConstant, Names @@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore): "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + async def get_all_updated_current_state_deltas( + self, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple], int, bool]: + """Fetch updates from current_state_delta_stream + + Args: + from_token: The previous stream token. Updates from this stream id will + be excluded. + + to_token: The current stream token (ie the upper limit). Updates up to this + stream id will be included (modulo the 'limit' param) + + target_row_count: The number of rows to try to return. If more rows are + available, we will set 'limited' in the result. In the event of a large + batch, we may return more rows than this. + Returns: + A triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of database tuples. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. + """ + def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id @@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (from_token, to_token, limit)) + txn.execute(sql, (from_token, to_token, target_row_count)) return txn.fetchall() - return self.db.runInteraction( + def get_deltas_for_stream_id_txn(txn, stream_id): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE stream_id = ? + """ + txn.execute(sql, [stream_id]) + return txn.fetchall() + + # we need to make sure that, for every stream id in the results, we get *all* + # the rows with that stream id. + + rows = await self.db.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, + ) # type: List[Tuple] + + # if we've got fewer rows than the limit, we're good + if len(rows) < target_row_count: + return rows, to_token, False + + # we hit the limit, so reduce the upper limit so that we exclude the stream id + # of the last row in the result. + assert rows[-1][0] <= to_token + to_token = rows[-1][0] - 1 + + # search backwards through the list for the point to truncate + for idx in range(len(rows) - 1, 0, -1): + if rows[idx - 1][0] <= to_token: + return rows[:idx], to_token, True + + # bother. We didn't get a full set of changes for even a single + # stream id. let's run the query again, without a row limit, but for + # just one stream id. + to_token += 1 + rows = await self.db.runInteraction( + "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token ) + return rows, to_token, True diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 82f15c64e0..83e16cfe3d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,10 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -from typing import Optional -from mock import Mock +import logging +from typing import Any, Dict, List, Optional, Tuple import attr @@ -25,6 +24,7 @@ from twisted.web.http import HTTPChannel from synapse.app.generic_worker import GenericWorkerServer from synapse.http.site import SynapseRequest +from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol @@ -65,9 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # databases objects are the same. self.worker_hs.get_datastore().db = hs.get_datastore().db - self.test_handler = Mock( - wraps=TestReplicationDataHandler(self.worker_hs.get_datastore()) - ) + self.test_handler = self._build_replication_data_handler() self.worker_hs.replication_data_handler = self.test_handler repl_handler = ReplicationCommandHandler(self.worker_hs) @@ -78,6 +76,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def _build_replication_data_handler(self): + return TestReplicationDataHandler(self.worker_hs.get_datastore()) + def reconnect(self): if self._client_transport: self.client.close() @@ -174,22 +175,28 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class TestReplicationDataHandler(ReplicationDataHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, hs): - super().__init__(hs) - self.streams = set() - self._received_rdata_rows = [] + def __init__(self, store: BaseSlavedStore): + super().__init__(store) + + # streams to subscribe to: map from stream id to position + self.stream_positions = {} # type: Dict[str, int] + + # list of received (stream_name, token, row) tuples + self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] def get_streams_to_replicate(self): - positions = {s: 0 for s in self.streams} - for stream, token, _ in self._received_rdata_rows: - if stream in self.streams: - positions[stream] = max(token, positions.get(stream, 0)) - return positions + return self.stream_positions async def on_rdata(self, stream_name, token, rows): await super().on_rdata(stream_name, token, rows) for r in rows: - self._received_rdata_rows.append((stream_name, token, r)) + self.received_rdata_rows.append((stream_name, token, r)) + + if ( + stream_name in self.stream_positions + and token > self.stream_positions[stream_name] + ): + self.stream_positions[stream_name] = token @attr.s() @@ -221,7 +228,7 @@ class _PushHTTPChannel(HTTPChannel): super().__init__() self.reactor = reactor - self._pull_to_push_producer = None + self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] def registerProducer(self, producer, streaming): # Convert pull producers to push producer. diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py new file mode 100644 index 0000000000..1fa28084f9 --- /dev/null +++ b/tests/replication/tcp/streams/test_events.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, + EventsStreamRow, +) +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication.tcp.streams._base import BaseStreamTestCase +from tests.test_utils.event_injection import inject_event, inject_member_event + + +class EventsStreamTestCase(BaseStreamTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + super().prepare(reactor, clock, hs) + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + + self.reconnect() + self.test_handler.stream_positions["events"] = 0 + + self.room_id = self.helper.create_room_as(tok=self.user_tok) + self.test_handler.received_rdata_rows.clear() + + def test_update_function_event_row_limit(self): + """Test replication with many non-state events + + Checks that all events are correctly replicated when there are lots of + event rows to be replicated. + """ + # disconnect, so that we can stack up some changes + self.disconnect() + + # generate lots of non-state events. We inject them using inject_event + # so that they are not send out over replication until we call self.replicate(). + events = [ + self._inject_test_event() + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1) + ] + + # also one state event + state_event = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + for event in events: + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + self.assertEqual([], received_rows) + + def test_update_function_huge_state_change(self): + """Test replication with many state events + + Ensures that all events are correctly replicated when there are lots of + state change rows to be replicated. + """ + + # we want to generate lots of state changes at a single stream ID. + # + # We do this by having two branches in the DAG. On one, we have a moderator + # which that generates lots of state; on the other, we de-op the moderator, + # thus invalidating all the state. + + OTHER_USER = "@other_user:localhost" + + # have the user join + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"][OTHER_USER] = 50 + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [ + self._inject_state_event(sender=OTHER_USER) + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) + ] + + self.replicate() + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # a state event which doesn't get rolled back, to check that the state + # before the huge update comes through ok + state1 = self._inject_state_event() + + # roll back all the state by de-modding the user + prev_events = fork_point + pls["users"][OTHER_USER] = 0 + pl_event = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + + # one more bit of state that doesn't get rolled back + state2 = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # now we should have received all the expected rows in the right order. + # + # we expect: + # + # - two rows for state1 + # - the PL event row, plus state rows for the PL event and each + # of the states that got reverted. + # - two rows for state2 + + received_rows = self.test_handler.received_rdata_rows + + # first check the first two rows, which should be state1 + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state1.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state1.event_id) + + # now the last two rows, which should be state2 + stream_name, token, row = received_rows.pop(-2) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state2.event_id) + + stream_name, token, row = received_rows.pop(-1) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state2.event_id) + + # that should leave us with the rows for the PL event + self.assertEqual(len(received_rows), len(events) + 2) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for stream_name, token, row in received_rows: + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_event.event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + def test_update_function_state_row_limit(self): + """Test replication with many state events over several stream ids. + """ + + # we want to generate lots of state changes, but for this test, we want to + # spread out the state changes over a few stream IDs. + # + # We do this by having two branches in the DAG. On one, we have four moderators, + # each of which that generates lots of state; on the other, we de-op the users, + # thus invalidating all the state. + + NUM_USERS = 4 + STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1 + + user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)] + + # have the users join + for u in user_ids: + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"].update({u: 50 for u in user_ids}) + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [] # type: List[EventBase] + for user in user_ids: + events.extend( + self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) + ) + + self.replicate() + + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # now roll back all that state by de-modding the users + prev_events = fork_point + pl_events = [] + for u in user_ids: + pls["users"][u] = 0 + e = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + prev_events = [e.event_id] + pl_events.append(e) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + + received_rows = self.test_handler.received_rdata_rows + self.assertGreaterEqual(len(received_rows), len(events)) + for i in range(NUM_USERS): + # for each user, we expect the PL event row, followed by state rows for + # the PL event and each of the states that got reverted. + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_events[i].event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for j in range(STATES_PER_USER + 1): + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_events[i].event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + self.assertEqual([], received_rows) + + event_count = 0 + + def _inject_test_event( + self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs + ) -> EventBase: + if sender is None: + sender = self.user_id + + if body is None: + body = "event %i" % (self.event_count,) + self.event_count += 1 + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) + + def _inject_state_event( + self, + body: Optional[str] = None, + state_key: Optional[str] = None, + sender: Optional[str] = None, + ) -> EventBase: + if sender is None: + sender = self.user_id + + if state_key is None: + state_key = "state_%i" % (self.event_count,) + self.event_count += 1 + + if body is None: + body = "state event %s" % (state_key,) + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index a0206f7363..c122b8589c 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,6 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# type: ignore + +from mock import Mock + from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -20,11 +25,14 @@ USER_ID = "@feeling:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_receipt(self): self.reconnect() # make the client subscribe to the receipts stream - self.test_handler.streams.add("receipts") + self.test_handler.stream_positions.update({"receipts": 0}) # tell the master to send a new receipt self.get_success( diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index f0ad6402ae..4d354a9db8 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from synapse.handlers.typing import RoomMember from synapse.replication.http import streams from synapse.replication.tcp.streams import TypingStream @@ -26,6 +28,9 @@ class TypingStreamTestCase(BaseStreamTestCase): streams.register_servlets, ] + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_typing(self): typing = self.hs.get_typing_handler() @@ -33,8 +38,8 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.streams.add("typing") + # make the client subscribe to the typing stream + self.test_handler.stream_positions.update({"typing": 0}) typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) @@ -75,6 +80,6 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row = rdata_rows[0] self.assertEqual(room_id, row.room_id) self.assertEqual([], row.user_ids) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 371637618d..22d734e763 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -39,7 +39,7 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator, is_public=True, tok=None): + def create_room_as(self, room_creator=None, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index a7310cf12a..7b345b03bb 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,3 +17,22 @@ """ Utilities for running the unit tests """ +from typing import Awaitable, TypeVar + +TV = TypeVar("TV") + + +def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: + """Get the result from an Awaitable which should have completed + + Asserts that the given awaitable has a result ready, and returns its value + """ + i = awaitable.__await__() + try: + next(i) + except StopIteration as e: + # awaitable returned a result + return e.value + + # if next didn't raise, the awaitable hasn't completed. + raise Exception("awaitable has not yet completed") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py new file mode 100644 index 0000000000..8f6872761a --- /dev/null +++ b/tests/test_utils/event_injection.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import synapse.server +from synapse.api.constants import EventTypes +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.types import Collection + +from tests.test_utils import get_awaitable_result + + +""" +Utility functions for poking events into the storage of the server under test. +""" + + +def inject_member_event( + hs: synapse.server.HomeServer, + room_id: str, + sender: str, + membership: str, + target: Optional[str] = None, + extra_content: Optional[dict] = None, + **kwargs +) -> EventBase: + """Inject a membership event into a room.""" + if target is None: + target = sender + + content = {"membership": membership} + if extra_content: + content.update(extra_content) + + return inject_event( + hs, + room_id=room_id, + type=EventTypes.Member, + sender=sender, + state_key=target, + content=content, + **kwargs + ) + + +def inject_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> EventBase: + """Inject a generic event into a room + + Args: + hs: the homeserver under test + room_version: the version of the room we're inserting into. + if not specified, will be looked up + prev_event_ids: prev_events for the event. If not specified, will be looked up + kwargs: fields for the event to be created + """ + test_reactor = hs.get_reactor() + + if room_version is None: + d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) + test_reactor.advance(0) + room_version = get_awaitable_result(d) + + builder = hs.get_event_builder_factory().for_room_version( + KNOWN_ROOM_VERSIONS[room_version], kwargs + ) + d = hs.get_event_creation_handler().create_new_client_event( + builder, prev_event_ids=prev_event_ids + ) + test_reactor.advance(0) + event, context = get_awaitable_result(d) + + d = hs.get_storage().persistence.persist_event(event, context) + test_reactor.advance(0) + get_awaitable_result(d) + + return event diff --git a/tests/unittest.py b/tests/unittest.py index 27af5228fe..6b6f224e9c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server @@ -55,6 +54,7 @@ from tests.server import ( render, setup_test_homeserver, ) +from tests.test_utils import event_injection from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase): """ Inject a membership event into a room. + Deprecated: use event_injection.inject_room_member directly + Args: room: Room ID to inject the event into. user: MXID of the user to inject the membership for. membership: The membership type. """ - event_builder_factory = self.hs.get_event_builder_factory() - event_creation_handler = self.hs.get_event_creation_handler() - - room_version = self.get_success( - self.hs.get_datastore().get_room_version_id(room) - ) - - builder = event_builder_factory.for_room_version( - KNOWN_ROOM_VERSIONS[room_version], - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - event_creation_handler.create_new_client_event(builder) - ) - - self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) - ) + event_injection.inject_member_event(self.hs, room, user, membership) class FederatingHomeserverTestCase(HomeserverTestCase): diff --git a/tox.ini b/tox.ini index 31011d7436..2630857436 100644 --- a/tox.ini +++ b/tox.ini @@ -204,6 +204,8 @@ commands = mypy \ synapse/storage/database.py \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ + tests/replication/tcp/streams \ + tests/test_utils \ tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: -- cgit 1.5.1 From 3eab76ad43e1de4074550c468d11318d497ff632 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 Apr 2020 14:10:59 +0100 Subject: Don't relay REMOTE_SERVER_UP cmds to same conn. (#7352) For direct TCP connections we need the master to relay REMOTE_SERVER_UP commands to the other connections so that all instances get notified about it. The old implementation just relayed to all connections, assuming that sending back to the original sender of the command was safe. This is not true for redis, where commands sent get echoed back to the sender, which was causing master to effectively infinite loop sending and then re-receiving REMOTE_SERVER_UP commands that it sent. The fix is to ensure that we only relay to *other* connections and not to the connection we received the notification from. Fixes #7334. --- changelog.d/7352.feature | 1 + synapse/notifier.py | 9 ---- synapse/replication/tcp/handler.py | 63 ++++++++++++++++++++------ synapse/replication/tcp/protocol.py | 2 +- synapse/replication/tcp/redis.py | 2 +- tests/replication/tcp/test_remote_server_up.py | 62 +++++++++++++++++++++++++ 6 files changed, 114 insertions(+), 25 deletions(-) create mode 100644 changelog.d/7352.feature create mode 100644 tests/replication/tcp/test_remote_server_up.py (limited to 'tests') diff --git a/changelog.d/7352.feature b/changelog.d/7352.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7352.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/synapse/notifier.py b/synapse/notifier.py index 6132727cbd..88a5a97caf 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -220,12 +220,6 @@ class Notifier(object): """ self.replication_callbacks.append(cb) - def add_remote_server_up_callback(self, cb: Callable[[str], None]): - """Add a callback that will be called when synapse detects a server - has been - """ - self.remote_server_up_callbacks.append(cb) - def on_new_room_event( self, event, room_stream_id, max_room_stream_id, extra_users=[] ): @@ -544,6 +538,3 @@ class Notifier(object): # circular dependencies. if self.federation_sender: self.federation_sender.wake_destination(server) - - for cb in self.remote_server_up_callbacks: - cb(server) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 3a8c7c7e2d..b8f49a8d0f 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -117,7 +117,6 @@ class ReplicationCommandHandler: self._server_notices_sender = None if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() - self._notifier.add_remote_server_up_callback(self.send_remote_server_up) def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -163,7 +162,7 @@ class ReplicationCommandHandler: port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self._factory) - async def on_REPLICATE(self, cmd: ReplicateCommand): + async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): # We only want to announce positions by the writer of the streams. # Currently this is just the master process. if not self._is_master: @@ -173,7 +172,7 @@ class ReplicationCommandHandler: current_token = stream.current_token() self.send_command(PositionCommand(stream_name, current_token)) - async def on_USER_SYNC(self, cmd: UserSyncCommand): + async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): user_sync_counter.inc() if self._is_master: @@ -181,17 +180,23 @@ class ReplicationCommandHandler: cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + async def on_CLEAR_USER_SYNC( + self, conn: AbstractConnection, cmd: ClearUserSyncsCommand + ): if self._is_master: await self._presence_handler.update_external_syncs_clear(cmd.instance_id) - async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + async def on_FEDERATION_ACK( + self, conn: AbstractConnection, cmd: FederationAckCommand + ): federation_ack_counter.inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.token) - async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + async def on_REMOVE_PUSHER( + self, conn: AbstractConnection, cmd: RemovePusherCommand + ): remove_pusher_counter.inc() if self._is_master: @@ -201,7 +206,9 @@ class ReplicationCommandHandler: self._notifier.on_new_replication_data() - async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + async def on_INVALIDATE_CACHE( + self, conn: AbstractConnection, cmd: InvalidateCacheCommand + ): invalidate_cache_counter.inc() if self._is_master: @@ -211,7 +218,7 @@ class ReplicationCommandHandler: cmd.cache_func, tuple(cmd.keys) ) - async def on_USER_IP(self, cmd: UserIpCommand): + async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): user_ip_cache_counter.inc() if self._is_master: @@ -227,7 +234,7 @@ class ReplicationCommandHandler: if self._server_notices_sender: await self._server_notices_sender.on_user_ip(cmd.user_id) - async def on_RDATA(self, cmd: RdataCommand): + async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -278,7 +285,7 @@ class ReplicationCommandHandler: logger.debug("Received rdata %s -> %s", stream_name, token) await self._replication_data_handler.on_rdata(stream_name, token, rows) - async def on_POSITION(self, cmd: PositionCommand): + async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): stream = self._streams.get(cmd.stream_name) if not stream: logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) @@ -332,12 +339,30 @@ class ReplicationCommandHandler: self._streams_connected.add(cmd.stream_name) - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + async def on_REMOTE_SERVER_UP( + self, conn: AbstractConnection, cmd: RemoteServerUpCommand + ): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) - if self._is_master: - self._notifier.notify_remote_server_up(cmd.data) + self._notifier.notify_remote_server_up(cmd.data) + + # We relay to all other connections to ensure every instance gets the + # notification. + # + # When configured to use redis we'll always only have one connection and + # so this is a no-op (all instances will have already received the same + # REMOTE_SERVER_UP command). + # + # For direct TCP connections this will relay to all other connections + # connected to us. When on master this will correctly fan out to all + # other direct TCP clients and on workers there'll only be the one + # connection to master. + # + # (The logic here should also be sound if we have a mix of Redis and + # direct TCP connections so long as there is only one traffic route + # between two instances, but that is not currently supported). + self.send_command(cmd, ignore_conn=conn) def new_connection(self, connection: AbstractConnection): """Called when we have a new connection. @@ -382,11 +407,21 @@ class ReplicationCommandHandler: """ return bool(self._connections) - def send_command(self, cmd: Command): + def send_command( + self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None + ): """Send a command to all connected connections. + + Args: + cmd + ignore_conn: If set don't send command to the given connection. + Used when relaying commands from one connection to all others. """ if self._connections: for connection in self._connections: + if connection == ignore_conn: + continue + try: connection.send_command(cmd) except Exception: diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index e3f64eba8f..4198eece71 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Then call out to the handler. cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + await cmd_func(self, cmd) handled = True if not handled: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 49b3ed0c5e..617e860f95 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # Then call out to the handler. cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + await cmd_func(self, cmd) handled = True if not handled: diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py new file mode 100644 index 0000000000..d1c15caeb0 --- /dev/null +++ b/tests/replication/tcp/test_remote_server_up.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +from twisted.internet.interfaces import IProtocol +from twisted.test.proto_helpers import StringTransport + +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests.unittest import HomeserverTestCase + + +class RemoteServerUpTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.factory = ReplicationStreamProtocolFactory(hs) + + def _make_client(self) -> Tuple[IProtocol, StringTransport]: + """Create a new direct TCP replication connection + """ + + proto = self.factory.buildProtocol(("127.0.0.1", 0)) + transport = StringTransport() + proto.makeConnection(transport) + + # We can safely ignore the commands received during connection. + self.pump() + transport.clear() + + return proto, transport + + def test_relay(self): + """Test that Synapse will relay REMOTE_SERVER_UP commands to all + other connections, but not the one that sent it. + """ + + proto1, transport1 = self._make_client() + + # We shouldn't receive an echo. + proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + self.pump() + self.assertEqual(transport1.value(), b"") + + # But we should see an echo if we connect another client + proto2, transport2 = self._make_client() + proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + + self.pump() + self.assertEqual(transport1.value(), b"") + self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n") -- cgit 1.5.1 From 37f6823f5b91f27b9dd8de8fc0e52d5ea889647c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 Apr 2020 16:23:08 +0100 Subject: Add instance name to RDATA/POSITION commands (#7364) This is primarily for allowing us to send those commands from workers, but for now simply allows us to ignore echoed RDATA/POSITION commands that we sent (we get echoes of sent commands when using redis). Currently we log a WARNING on the master process every time we receive an echoed RDATA. --- changelog.d/7364.misc | 1 + docs/tcp_replication.md | 41 +++++++++++++++++++------------- synapse/app/_base.py | 4 ++-- synapse/logging/opentracing.py | 23 ++++++++---------- synapse/replication/tcp/commands.py | 37 +++++++++++++++++++--------- synapse/replication/tcp/handler.py | 17 ++++++++++--- synapse/server.py | 13 ++++++++-- synapse/server.pyi | 2 ++ tests/replication/slave/storage/_base.py | 1 + tests/replication/tcp/test_commands.py | 6 +++-- 10 files changed, 95 insertions(+), 50 deletions(-) create mode 100644 changelog.d/7364.misc (limited to 'tests') diff --git a/changelog.d/7364.misc b/changelog.d/7364.misc new file mode 100644 index 0000000000..bb5d727cf4 --- /dev/null +++ b/changelog.d/7364.misc @@ -0,0 +1 @@ +Add an `instance_name` to `RDATA` and `POSITION` replication commands. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index b922d9cf7e..ab2fffbfe4 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -15,15 +15,17 @@ example flow would be (where '>' indicates master to worker and > SERVER example.com < REPLICATE - > POSITION events 53 - > RDATA events 54 ["$foo1:bar.com", ...] - > RDATA events 55 ["$foo4:bar.com", ...] + > POSITION events master 53 + > RDATA events master 54 ["$foo1:bar.com", ...] + > RDATA events master 55 ["$foo4:bar.com", ...] The example shows the server accepting a new connection and sending its identity with the `SERVER` command, followed by the client server to respond with the position of all streams. The server then periodically sends `RDATA` commands -which have the format `RDATA `, where the format of -`` is defined by the individual streams. +which have the format `RDATA `, where +the format of `` is defined by the individual streams. The +`` is the name of the Synapse process that generated the data +(usually "master"). Error reporting happens by either the client or server sending an ERROR command, and usually the connection will be closed. @@ -52,7 +54,7 @@ The basic structure of the protocol is line based, where the initial word of each line specifies the command. The rest of the line is parsed based on the command. For example, the RDATA command is defined as: - RDATA + RDATA (Note that may contains spaces, but cannot contain newlines.) @@ -136,11 +138,11 @@ the wire: < NAME synapse.app.appservice < PING 1490197665618 < REPLICATE - > POSITION events 1 - > POSITION backfill 1 - > POSITION caches 1 - > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] - > RDATA events 14 ["$149019767112vOHxz:localhost:8823", + > POSITION events master 1 + > POSITION backfill master 1 + > POSITION caches master 1 + > RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] + > RDATA events master 14 ["$149019767112vOHxz:localhost:8823", "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] < PING 1490197675618 > ERROR server stopping @@ -151,10 +153,10 @@ position without needing to send data with the `RDATA` command. An example of a batched set of `RDATA` is: - > RDATA caches batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] - > RDATA caches 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] + > RDATA caches master 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] In this case the client shouldn't advance their caches token until it sees the the last `RDATA`. @@ -178,6 +180,11 @@ client (C): updates, and if so then fetch them out of band. Sent in response to a REPLICATE command (but can happen at any time). + The POSITION command includes the source of the stream. Currently all streams + are written by a single process (usually "master"). If fetching missing + updates via HTTP API, rather than via the DB, then processes should make the + request to the appropriate process. + #### ERROR (S, C) There was an error @@ -234,12 +241,12 @@ Each individual cache invalidation results in a row being sent down replication, which includes the cache name (the name of the function) and they key to invalidate. For example: - > RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] + > RDATA caches master 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] Alternatively, an entire cache can be invalidated by sending down a `null` instead of the key. For example: - > RDATA caches 550953772 ["get_user_by_id", null, 1550574873252] + > RDATA caches master 550953772 ["get_user_by_id", null, 1550574873252] However, there are times when a number of caches need to be invalidated at the same time with the same key. To reduce traffic we batch those diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 4d84f4595a..628292b890 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -270,7 +270,7 @@ def start(hs, listeners=None): # Start the tracer synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa - hs.config + hs ) # It is now safe to start your Synapse. @@ -316,7 +316,7 @@ def setup_sentry(hs): scope.set_tag("matrix_server_name", hs.config.server_name) app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver" - name = hs.config.worker_name if hs.config.worker_name else "master" + name = hs.get_instance_name() scope.set_tag("worker_app", app) scope.set_tag("worker_name", name) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 0638cec429..5dddf57008 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -171,7 +171,7 @@ import logging import re import types from functools import wraps -from typing import Dict +from typing import TYPE_CHECKING, Dict from canonicaljson import json @@ -179,6 +179,9 @@ from twisted.internet import defer from synapse.config import ConfigError +if TYPE_CHECKING: + from synapse.server import HomeServer + # Helper class @@ -297,14 +300,11 @@ def _noop_context_manager(*args, **kwargs): # Setup -def init_tracer(config): +def init_tracer(hs: "HomeServer"): """Set the whitelists and initialise the JaegerClient tracer - - Args: - config (HomeserverConfig): The config used by the homeserver """ global opentracing - if not config.opentracer_enabled: + if not hs.config.opentracer_enabled: # We don't have a tracer opentracing = None return @@ -315,18 +315,15 @@ def init_tracer(config): "installed." ) - # Include the worker name - name = config.worker_name if config.worker_name else "master" - # Pull out the jaeger config if it was given. Otherwise set it to something sensible. # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py - set_homeserver_whitelist(config.opentracer_whitelist) + set_homeserver_whitelist(hs.config.opentracer_whitelist) JaegerConfig( - config=config.jaeger_config, - service_name="{} {}".format(config.server_name, name), - scope_manager=LogContextScopeManager(config), + config=hs.config.jaeger_config, + service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), + scope_manager=LogContextScopeManager(hs.config), ).initialize_tracer() diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index c7880d4b63..f58e384d17 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -95,7 +95,7 @@ class RdataCommand(Command): Format:: - RDATA + RDATA The `` may either be a numeric stream id OR "batch". The latter case is used to support sending multiple updates with the same stream ID. This @@ -105,33 +105,40 @@ class RdataCommand(Command): The client should batch all incoming RDATA with a token of "batch" (per stream_name) until it sees an RDATA with a numeric stream ID. + The `` is the source of the new data (usually "master"). + `` of "batch" maps to the instance variable `token` being None. An example of a batched series of RDATA:: - RDATA presence batch ["@foo:example.com", "online", ...] - RDATA presence batch ["@bar:example.com", "online", ...] - RDATA presence 59 ["@baz:example.com", "online", ...] + RDATA presence master batch ["@foo:example.com", "online", ...] + RDATA presence master batch ["@bar:example.com", "online", ...] + RDATA presence master 59 ["@baz:example.com", "online", ...] """ NAME = "RDATA" - def __init__(self, stream_name, token, row): + def __init__(self, stream_name, instance_name, token, row): self.stream_name = stream_name + self.instance_name = instance_name self.token = token self.row = row @classmethod def from_line(cls, line): - stream_name, token, row_json = line.split(" ", 2) + stream_name, instance_name, token, row_json = line.split(" ", 3) return cls( - stream_name, None if token == "batch" else int(token), json.loads(row_json) + stream_name, + instance_name, + None if token == "batch" else int(token), + json.loads(row_json), ) def to_line(self): return " ".join( ( self.stream_name, + self.instance_name, str(self.token) if self.token is not None else "batch", _json_encoder.encode(self.row), ) @@ -145,23 +152,31 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. + Format:: + + POSITION + On receipt of a POSITION command clients should check if they have missed any updates, and if so then fetch them out of band. + + The `` is the process that sent the command and is the source + of the stream. """ NAME = "POSITION" - def __init__(self, stream_name, token): + def __init__(self, stream_name, instance_name, token): self.stream_name = stream_name + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - return cls(stream_name, int(token)) + stream_name, instance_name, token = line.split(" ", 2) + return cls(stream_name, instance_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token))) + return " ".join((self.stream_name, self.instance_name, str(self.token))) class ErrorCommand(_SimpleCommand): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b8f49a8d0f..6f7054d5af 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -79,6 +79,7 @@ class ReplicationCommandHandler: self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() + self._instance_name = hs.get_instance_name() # Set of streams that we've caught up with. self._streams_connected = set() # type: Set[str] @@ -156,7 +157,7 @@ class ReplicationCommandHandler: hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, ) else: - client_name = hs.config.worker_name + client_name = hs.get_instance_name() self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port @@ -170,7 +171,9 @@ class ReplicationCommandHandler: for stream_name, stream in self._streams.items(): current_token = stream.current_token() - self.send_command(PositionCommand(stream_name, current_token)) + self.send_command( + PositionCommand(stream_name, self._instance_name, current_token) + ) async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): user_sync_counter.inc() @@ -235,6 +238,10 @@ class ReplicationCommandHandler: await self._server_notices_sender.on_user_ip(cmd.user_id) async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + if cmd.instance_name == self._instance_name: + # Ignore RDATA that are just our own echoes + return + stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -286,6 +293,10 @@ class ReplicationCommandHandler: await self._replication_data_handler.on_rdata(stream_name, token, rows) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + if cmd.instance_name == self._instance_name: + # Ignore POSITION that are just our own echoes + return + stream = self._streams.get(cmd.stream_name) if not stream: logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) @@ -485,7 +496,7 @@ class ReplicationCommandHandler: We need to check if the client is interested in the stream or not """ - self.send_command(RdataCommand(stream_name, token, data)) + self.send_command(RdataCommand(stream_name, self._instance_name, token, data)) UpdateToken = TypeVar("UpdateToken") diff --git a/synapse/server.py b/synapse/server.py index 9d273c980c..bf97a16c09 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -234,7 +234,8 @@ class HomeServer(object): self._listening_services = [] self.start_time = None - self.instance_id = random_string(5) + self._instance_id = random_string(5) + self._instance_name = config.worker_name or "master" self.clock = Clock(reactor) self.distributor = Distributor() @@ -254,7 +255,15 @@ class HomeServer(object): This is used to distinguish running instances in worker-based deployments. """ - return self.instance_id + return self._instance_id + + def get_instance_name(self) -> str: + """A unique name for this synapse process. + + Used to identify the process over replication and in config. Does not + change over restarts. + """ + return self._instance_name def setup(self): logger.info("Setting up.") diff --git a/synapse/server.pyi b/synapse/server.pyi index fc5886f762..18043a2593 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -122,6 +122,8 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_instance_name(self) -> str: + pass def get_event_builder_factory(self) -> EventBuilderFactory: pass def get_storage(self) -> synapse.storage.Storage: diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 395c7d0306..1615dfab5e 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -57,6 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): # We now do some gut wrenching so that we have a client that is based # off of the slave store rather than the main store. self.replication_handler = ReplicationCommandHandler(self.hs) + self.replication_handler._instance_name = "worker" self.replication_handler._replication_data_handler = ReplicationDataHandler( self.slaved_store ) diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py index 3cbcb513cc..7ddfd0a733 100644 --- a/tests/replication/tcp/test_commands.py +++ b/tests/replication/tcp/test_commands.py @@ -28,15 +28,17 @@ class ParseCommandTestCase(TestCase): self.assertIsInstance(cmd, ReplicateCommand) def test_parse_rdata(self): - line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' + line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' cmd = parse_command_from_line(line) self.assertIsInstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "events") + self.assertEqual(cmd.instance_name, "master") self.assertEqual(cmd.token, 6287863) def test_parse_rdata_batch(self): - line = 'RDATA presence batch ["@foo:example.com", "online"]' + line = 'RDATA presence master batch ["@foo:example.com", "online"]' cmd = parse_command_from_line(line) self.assertIsInstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "presence") + self.assertEqual(cmd.instance_name, "master") self.assertIsNone(cmd.token) -- cgit 1.5.1 From 627b0f5f2753e6910adb7a877541d50f5936b8a5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Apr 2020 13:47:49 -0400 Subject: Persist user interactive authentication sessions (#7302) By persisting the user interactive authentication sessions to the database, this fixes situations where a user hits different works throughout their auth session and also allows sessions to persist through restarts of Synapse. --- changelog.d/7302.bugfix | 1 + synapse/app/generic_worker.py | 2 + synapse/handlers/auth.py | 175 +++++-------- synapse/handlers/cas_handler.py | 2 +- synapse/handlers/saml_handler.py | 2 +- synapse/rest/client/v2_alpha/auth.py | 4 +- synapse/rest/client/v2_alpha/register.py | 4 +- synapse/storage/data_stores/main/__init__.py | 2 + .../main/schema/delta/58/03persist_ui_auth.sql | 36 +++ synapse/storage/data_stores/main/ui_auth.py | 279 +++++++++++++++++++++ synapse/storage/engines/sqlite.py | 1 + tests/rest/client/v2_alpha/test_auth.py | 40 +++ tests/utils.py | 8 +- tox.ini | 3 +- 14 files changed, 434 insertions(+), 125 deletions(-) create mode 100644 changelog.d/7302.bugfix create mode 100644 synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql create mode 100644 synapse/storage/data_stores/main/ui_auth.py (limited to 'tests') diff --git a/changelog.d/7302.bugfix b/changelog.d/7302.bugfix new file mode 100644 index 0000000000..820646d1f9 --- /dev/null +++ b/changelog.d/7302.bugfix @@ -0,0 +1 @@ +Persist user interactive authentication sessions across workers and Synapse restarts. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d125327f08..0ace7b787d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.data_stores.main.presence import UserPresenceState +from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer @@ -439,6 +440,7 @@ class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. UserDirectoryStore, + UIAuthWorkerStore, SlavedDeviceInboxStore, SlavedDeviceStore, SlavedReceiptsStore, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index dbe165ce1e..7613e5b6ab 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.http.server import finish_request from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID -from synapse.util.caches.expiringcache import ExpiringCache from ._base import BaseHandler @@ -69,15 +69,6 @@ class AuthHandler(BaseHandler): self.bcrypt_rounds = hs.config.bcrypt_rounds - # This is not a cache per se, but a store of all current sessions that - # expire after N hours - self.sessions = ExpiringCache( - cache_name="register_sessions", - clock=hs.get_clock(), - expiry_ms=self.SESSION_EXPIRE_MS, - reset_expiry_on_get=True, - ) - account_handler = ModuleApi(hs, self) self.password_providers = [ module(config=config, account_handler=account_handler) @@ -119,6 +110,15 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() + # Expire old UI auth sessions after a period of time. + if hs.config.worker_app is None: + self._clock.looping_call( + run_as_background_process, + 5 * 60 * 1000, + "expire_old_sessions", + self._expire_old_sessions, + ) + # Load the SSO HTML templates. # The following template is shown to the user during a client login via SSO, @@ -301,16 +301,21 @@ class AuthHandler(BaseHandler): if "session" in authdict: sid = authdict["session"] + # Convert the URI and method to strings. + uri = request.uri.decode("utf-8") + method = request.uri.decode("utf-8") + # If there's no session ID, create a new session. if not sid: - session = self._create_session( - clientdict, (request.uri, request.method, clientdict), description + session = await self.store.create_ui_auth_session( + clientdict, uri, method, description ) - session_id = session["id"] else: - session = self._get_session_info(sid) - session_id = sid + try: + session = await self.store.get_ui_auth_session(sid) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (sid,)) if not clientdict: # This was designed to allow the client to omit the parameters @@ -322,15 +327,15 @@ class AuthHandler(BaseHandler): # on a homeserver. # Revisit: Assuming the REST APIs do sensible validation, the data # isn't arbitrary. - clientdict = session["clientdict"] + clientdict = session.clientdict # Ensure that the queried operation does not vary between stages of # the UI authentication session. This is done by generating a stable # comparator based on the URI, method, and body (minus the auth dict) # and storing it during the initial query. Subsequent queries ensure # that this comparator has not changed. - comparator = (request.uri, request.method, clientdict) - if session["ui_auth"] != comparator: + comparator = (uri, method, clientdict) + if (session.uri, session.method, session.clientdict) != comparator: raise SynapseError( 403, "Requested operation has changed during the UI authentication session.", @@ -338,11 +343,9 @@ class AuthHandler(BaseHandler): if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session_id) + self._auth_dict_for_flows(flows, session.session_id) ) - creds = session["creds"] - # check auth type currently being presented errordict = {} # type: Dict[str, Any] if "type" in authdict: @@ -350,8 +353,9 @@ class AuthHandler(BaseHandler): try: result = await self._check_auth_dict(authdict, clientip) if result: - creds[login_type] = result - self._save_session(session) + await self.store.mark_ui_auth_stage_complete( + session.session_id, login_type, result + ) except LoginError as e: if login_type == LoginType.EMAIL_IDENTITY: # riot used to have a bug where it would request a new @@ -367,6 +371,7 @@ class AuthHandler(BaseHandler): # so that the client can have another go. errordict = e.error_dict() + creds = await self.store.get_completed_ui_auth_stages(session.session_id) for f in flows: if len(set(f) - set(creds)) == 0: # it's very useful to know what args are stored, but this can @@ -380,9 +385,9 @@ class AuthHandler(BaseHandler): list(clientdict), ) - return creds, clientdict, session_id + return creds, clientdict, session.session_id - ret = self._auth_dict_for_flows(flows, session_id) + ret = self._auth_dict_for_flows(flows, session.session_id) ret["completed"] = list(creds) ret.update(errordict) raise InteractiveAuthIncompleteError(ret) @@ -399,13 +404,11 @@ class AuthHandler(BaseHandler): if "session" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - sess = self._get_session_info(authdict["session"]) - creds = sess["creds"] - result = await self.checkers[stagetype].check_auth(authdict, clientip) if result: - creds[stagetype] = result - self._save_session(sess) + await self.store.mark_ui_auth_stage_complete( + authdict["session"], stagetype, result + ) return True return False @@ -427,7 +430,7 @@ class AuthHandler(BaseHandler): sid = authdict["session"] return sid - def set_session_data(self, session_id: str, key: str, value: Any) -> None: + async def set_session_data(self, session_id: str, key: str, value: Any) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by @@ -438,11 +441,12 @@ class AuthHandler(BaseHandler): key: The key to store the data under value: The data to store """ - sess = self._get_session_info(session_id) - sess["serverdict"][key] = value - self._save_session(sess) + try: + await self.store.set_ui_auth_session_data(session_id, key, value) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - def get_session_data( + async def get_session_data( self, session_id: str, key: str, default: Optional[Any] = None ) -> Any: """ @@ -453,8 +457,18 @@ class AuthHandler(BaseHandler): key: The key to store the data under default: Value to return if the key has not been set """ - sess = self._get_session_info(session_id) - return sess["serverdict"].get(key, default) + try: + return await self.store.get_ui_auth_session_data(session_id, key, default) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) + + async def _expire_old_sessions(self): + """ + Invalidate any user interactive authentication sessions that have expired. + """ + now = self._clock.time_msec() + expiration_time = now - self.SESSION_EXPIRE_MS + await self.store.delete_old_ui_auth_sessions(expiration_time) async def _check_auth_dict( self, authdict: Dict[str, Any], clientip: str @@ -534,67 +548,6 @@ class AuthHandler(BaseHandler): "params": params, } - def _create_session( - self, - clientdict: Dict[str, Any], - ui_auth: Tuple[bytes, bytes, Dict[str, Any]], - description: str, - ) -> dict: - """ - Creates a new user interactive authentication session. - - The session can be used to track data across multiple requests, e.g. for - interactive authentication. - - Each session has the following keys: - - id: - A unique identifier for this session. Passed back to the client - and returned for each stage. - clientdict: - The dictionary from the client root level, not the 'auth' key. - ui_auth: - A tuple which is checked at each stage of the authentication to - ensure that the asked for operation has not changed. - creds: - A map, which maps each auth-type (str) to the relevant identity - authenticated by that auth-type (mostly str, but for captcha, bool). - serverdict: - A map of data that is stored server-side and cannot be modified - by the client. - description: - A string description of the operation that the current - authentication is authorising. - Returns: - The newly created session. - """ - session_id = None - while session_id is None or session_id in self.sessions: - session_id = stringutils.random_string(24) - - self.sessions[session_id] = { - "id": session_id, - "clientdict": clientdict, - "ui_auth": ui_auth, - "creds": {}, - "serverdict": {}, - "description": description, - } - - return self.sessions[session_id] - - def _get_session_info(self, session_id: str) -> dict: - """ - Gets a session given a session ID. - - The session can be used to track data across multiple requests, e.g. for - interactive authentication. - """ - try: - return self.sessions[session_id] - except KeyError: - raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - async def get_access_token_for_user_id( self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] ): @@ -994,13 +947,6 @@ class AuthHandler(BaseHandler): await self.store.user_delete_threepid(user_id, medium, address) return result - def _save_session(self, session: Dict[str, Any]) -> None: - """Update the last used time on the session to now and add it back to the session store.""" - # TODO: Persistent storage - logger.debug("Saving session %s", session) - session["last_used"] = self.hs.get_clock().time_msec() - self.sessions[session["id"]] = session - async def hash(self, password: str) -> str: """Computes a secure hash of password. @@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler): else: return False - def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: """ Get the HTML for the SSO redirect confirmation page. @@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler): Returns: The HTML to render. """ - session = self._get_session_info(session_id) + try: + session = await self.store.get_ui_auth_session(session_id) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) return self._sso_auth_confirm_template.render( - description=session["description"], redirect_url=redirect_url, + description=session.description, redirect_url=redirect_url, ) - def complete_sso_ui_auth( + async def complete_sso_ui_auth( self, registered_user_id: str, session_id: str, request: SynapseRequest, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler): process. """ # Mark the stage of the authentication as successful. - sess = self._get_session_info(session_id) - creds = sess["creds"] - # Save the user who authenticated with SSO, this will be used to ensure # that the account be modified is also the person who logged in. - creds[LoginType.SSO] = registered_user_id - self._save_session(sess) + await self.store.mark_ui_auth_stage_complete( + session_id, LoginType.SSO, registered_user_id + ) # Render the HTML and return. html_bytes = self._sso_auth_success_template.encode("utf-8") diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 5cb3f9d133..64aaa1335c 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -206,7 +206,7 @@ class CasHandler: registered_user_id = await self._auth_handler.check_user_exists(user_id) if session: - self._auth_handler.complete_sso_ui_auth( + await self._auth_handler.complete_sso_ui_auth( registered_user_id, session, request, ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 7c9454b504..96f2dd36ad 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -149,7 +149,7 @@ class SamlHandler: # Complete the interactive auth session or the login. if current_session and current_session.ui_auth_session_id: - self._auth_handler.complete_sso_ui_auth( + await self._auth_handler.complete_sso_ui_auth( user_id, current_session.ui_auth_session_id, request ) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 11599f5005..24dd3d3e96 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet): self._cas_server_url = hs.config.cas_server_url self._cas_service_url = hs.config.cas_service_url - def on_GET(self, request, stagetype): + async def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") @@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet): else: raise SynapseError(400, "Homeserver not configured for SSO.") - html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index d1b5c49989..af08cc6cce 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet): # registered a user for this session, so we could just return the # user here. We carry on and go through the auth checks though, # for paranoia. - registered_user_id = self.auth_handler.get_session_data( + registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) @@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet): # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) - self.auth_handler.set_session_data( + await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index bd7c3a00ea..ceba10882c 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -66,6 +66,7 @@ from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore from .transactions import TransactionStore +from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore from .user_erasure_store import UserErasureStore @@ -112,6 +113,7 @@ class DataStore( StatsStore, RelationsStore, CacheInvalidationStore, + UIAuthStore, ): def __init__(self, database: Database, db_conn, hs): self.hs = hs diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql new file mode 100644 index 0000000000..dcb593fc2d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql @@ -0,0 +1,36 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS ui_auth_sessions( + session_id TEXT NOT NULL, -- The session ID passed to the client. + creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds). + serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse. + clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client. + uri TEXT NOT NULL, -- The URI the UI authentication session is using. + method TEXT NOT NULL, -- The HTTP method the UI authentication session is using. + -- The clientdict, uri, and method make up an tuple that must be immutable + -- throughout the lifetime of the UI Auth session. + description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur. + UNIQUE (session_id) +); + +CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials( + session_id TEXT NOT NULL, -- The corresponding UI Auth session. + stage_type TEXT NOT NULL, -- The stage type. + result TEXT NOT NULL, -- The result of the stage verification, stored as JSON. + UNIQUE (session_id, stage_type), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py new file mode 100644 index 0000000000..c8eebc9378 --- /dev/null +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any, Dict, Optional, Union + +import attr + +import synapse.util.stringutils as stringutils +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore +from synapse.types import JsonDict + + +@attr.s +class UIAuthSessionData: + session_id = attr.ib(type=str) + # The dictionary from the client root level, not the 'auth' key. + clientdict = attr.ib(type=JsonDict) + # The URI and method the session was intiatied with. These are checked at + # each stage of the authentication to ensure that the asked for operation + # has not changed. + uri = attr.ib(type=str) + method = attr.ib(type=str) + # A string description of the operation that the current authentication is + # authorising. + description = attr.ib(type=str) + + +class UIAuthWorkerStore(SQLBaseStore): + """ + Manage user interactive authentication sessions. + """ + + async def create_ui_auth_session( + self, clientdict: JsonDict, uri: str, method: str, description: str, + ) -> UIAuthSessionData: + """ + Creates a new user interactive authentication session. + + The session can be used to track the stages necessary to authenticate a + user across multiple HTTP requests. + + Args: + clientdict: + The dictionary from the client root level, not the 'auth' key. + uri: + The URI this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + method: + The method this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + description: + A string description of the operation that the current + authentication is authorising. + Returns: + The newly created session. + Raises: + StoreError if a unique session ID cannot be generated. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + # autogen a session ID and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + while attempts < 5: + session_id = stringutils.random_string(24) + + try: + await self.db.simple_insert( + table="ui_auth_sessions", + values={ + "session_id": session_id, + "clientdict": clientdict_json, + "uri": uri, + "method": method, + "description": description, + "serverdict": "{}", + "creation_time": self.hs.get_clock().time_msec(), + }, + desc="create_ui_auth_session", + ) + return UIAuthSessionData( + session_id, clientdict, uri, method, description + ) + except self.db.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a session ID.") + + async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: + """Retrieve a UI auth session. + + Args: + session_id: The ID of the session. + Returns: + A dict containing the device information. + Raises: + StoreError if the session is not found. + """ + result = await self.db.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("clientdict", "uri", "method", "description"), + desc="get_ui_auth_session", + ) + + result["clientdict"] = json.loads(result["clientdict"]) + + return UIAuthSessionData(session_id, **result) + + async def mark_ui_auth_stage_complete( + self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], + ): + """ + Mark a session stage as completed. + + Args: + session_id: The ID of the corresponding session. + stage_type: The completed stage type. + result: The result of the stage verification. + Raises: + StoreError if the session cannot be found. + """ + # Add (or update) the results of the current stage to the database. + # + # Note that we need to allow for the same stage to complete multiple + # times here so that registration is idempotent. + try: + await self.db.simple_upsert( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id, "stage_type": stage_type}, + values={"result": json.dumps(result)}, + desc="mark_ui_auth_stage_complete", + ) + except self.db.engine.module.IntegrityError: + raise StoreError(400, "Unknown session ID: %s" % (session_id,)) + + async def get_completed_ui_auth_stages( + self, session_id: str + ) -> Dict[str, Union[str, bool, JsonDict]]: + """ + Retrieve the completed stages of a UI authentication session. + + Args: + session_id: The ID of the session. + Returns: + The completed stages mapped to the result of the verification of + that auth-type. + """ + results = {} + for row in await self.db.simple_select_list( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id}, + retcols=("stage_type", "result"), + desc="get_completed_ui_auth_stages", + ): + results[row["stage_type"]] = json.loads(row["result"]) + + return results + + async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): + """ + Store a key-value pair into the sessions data associated with this + request. This data is stored server-side and cannot be modified by + the client. + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store + Raises: + StoreError if the session cannot be found. + """ + await self.db.runInteraction( + "set_ui_auth_session_data", + self._set_ui_auth_session_data_txn, + session_id, + key, + value, + ) + + def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + # Get the current value. + result = self.db.simple_select_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + ) + + # Update it and add it back to the database. + serverdict = json.loads(result["serverdict"]) + serverdict[key] = value + + self.db.simple_update_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"serverdict": json.dumps(serverdict)}, + ) + + async def get_ui_auth_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: + """ + Retrieve data stored with set_session_data + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set + Raises: + StoreError if the session cannot be found. + """ + result = await self.db.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + desc="get_ui_auth_session_data", + ) + + serverdict = json.loads(result["serverdict"]) + + return serverdict.get(key, default) + + +class UIAuthStore(UIAuthWorkerStore): + def delete_old_ui_auth_sessions(self, expiration_time: int): + """ + Remove sessions which were last used earlier than the expiration time. + + Args: + expiration_time: The latest time that is still considered valid. + This is an epoch time in milliseconds. + + """ + return self.db.runInteraction( + "delete_old_ui_auth_sessions", + self._delete_old_ui_auth_sessions_txn, + expiration_time, + ) + + def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + # Get the expired sessions. + sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" + txn.execute(sql, [expiration_time]) + session_ids = [r[0] for r in txn.fetchall()] + + # Delete the corresponding completed credentials. + self.db.simple_delete_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + + # Finally, delete the sessions. + self.db.simple_delete_many_txn( + txn, + table="ui_auth_sessions", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 3bc2e8b986..215a949442 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): prepare_database(db_conn, self, config=None) db_conn.create_function("rank", 1, _rank) + db_conn.execute("PRAGMA foreign_keys = ON;") def is_deadlock(self, error): return False diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 624bf5ada2..587be7b2e7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(channel.code, 403) + + def test_complete_operation_unknown_session(self): + """ + Attempting to mark an invalid session as complete should error. + """ + + # Make the initial request to register. (Later on a different password + # will be used.) + request, channel = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + self.render(request) + + # Returns a 401 as per the spec + self.assertEqual(request.code, 401) + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + request, channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) + self.render(request) + self.assertEqual(request.code, 200) + + # Attempt to complete an unknown session, which should return an error. + unknown_session = session + "unknown" + request, channel = self.make_request( + "POST", + "auth/m.login.recaptcha/fallback/web?session=" + + unknown_session + + "&g-recaptcha-response=a", + ) + self.render(request) + self.assertEqual(request.code, 400) diff --git a/tests/utils.py b/tests/utils.py index 037cb134f0..f9be62b499 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -512,8 +512,8 @@ class MockClock(object): return t - def looping_call(self, function, interval): - self.loopers.append([function, interval / 1000.0, self.now]) + def looping_call(self, function, interval, *args, **kwargs): + self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: @@ -543,9 +543,9 @@ class MockClock(object): self.timers.append(t) for looped in self.loopers: - func, interval, last = looped + func, interval, last, args, kwargs = looped if last + interval < self.now: - func() + func(*args, **kwargs) looped[2] = self.now def advance_time_msec(self, ms): diff --git a/tox.ini b/tox.ini index 2630857436..eccc44e436 100644 --- a/tox.ini +++ b/tox.ini @@ -200,8 +200,9 @@ commands = mypy \ synapse/replication \ synapse/rest \ synapse/spam_checker_api \ - synapse/storage/engines \ + synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/database.py \ + synapse/storage/engines \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ tests/replication/tcp/streams \ -- cgit 1.5.1 From 6b22921b195c24762cd7c02a8b8fad75791fce70 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 1 May 2020 15:15:36 +0100 Subject: async/await is_server_admin (#7363) --- changelog.d/7363.misc | 1 + synapse/api/auth.py | 9 +- synapse/federation/federation_client.py | 5 +- synapse/groups/groups_server.py | 64 +++++------ synapse/handlers/_base.py | 16 ++- synapse/handlers/directory.py | 51 ++++----- synapse/handlers/federation.py | 21 ++-- synapse/handlers/groups_local.py | 24 ++-- synapse/handlers/message.py | 83 +++++++------- synapse/handlers/profile.py | 39 ++++--- synapse/handlers/register.py | 49 ++++---- synapse/handlers/room.py | 121 ++++++++++---------- synapse/handlers/room_member.py | 127 ++++++++++----------- synapse/server_notices/consent_server_notices.py | 11 +- .../resource_limits_server_notices.py | 35 +++--- synapse/server_notices/server_notices_manager.py | 32 +++--- synapse/server_notices/server_notices_sender.py | 12 +- synapse/storage/data_stores/main/registration.py | 5 +- tests/handlers/test_profile.py | 60 ++++++---- tests/handlers/test_register.py | 29 +++-- .../test_resource_limits_server_notices.py | 48 ++++---- tests/test_federation.py | 6 +- 22 files changed, 410 insertions(+), 438 deletions(-) create mode 100644 changelog.d/7363.misc (limited to 'tests') diff --git a/changelog.d/7363.misc b/changelog.d/7363.misc new file mode 100644 index 0000000000..1e3cddde79 --- /dev/null +++ b/changelog.d/7363.misc @@ -0,0 +1 @@ +Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await. \ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index c1ade1333b..c5d1eb952b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -537,8 +537,7 @@ class Auth(object): return defer.succeed(auth_ids) - @defer.inlineCallbacks - def check_can_change_room_list(self, room_id: str, user: UserID): + async def check_can_change_room_list(self, room_id: str, user: UserID): """Determine whether the user is allowed to edit the room's entry in the published room list. @@ -547,17 +546,17 @@ class Auth(object): user """ - is_admin = yield self.is_server_admin(user) + is_admin = await self.is_server_admin(user) if is_admin: return True user_id = user.to_string() - yield self.check_user_in_room(room_id, user_id) + await self.check_user_in_room(room_id, user_id) # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the # m.room.canonical_alias events - power_level_event = yield self.state.get_current_state( + power_level_event = await self.state.get_current_state( room_id, EventTypes.PowerLevels, "" ) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 58b13da616..687cd841ac 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -976,14 +976,13 @@ class FederationClient(FederationBase): return signed_events - @defer.inlineCallbacks - def forward_third_party_invite(self, destinations, room_id, event_dict): + async def forward_third_party_invite(self, destinations, room_id, event_dict): for destination in destinations: if destination == self.server_name: continue try: - yield self.transport_layer.exchange_third_party_invite( + await self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) return None diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 4f0dc0a209..4acb4fa489 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler): raise NotImplementedError() - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from the group; either a user is leaving or an admin kicked them. """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_kick = False if requester_user_id != user_id: - is_admin = yield self.store.is_user_admin_in_group( + is_admin = await self.store.is_user_admin_in_group( group_id, requester_user_id ) if not is_admin: @@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler): is_kick = True - yield self.store.remove_user_from_group(group_id, user_id) + await self.store.remove_user_from_group(group_id, user_id) if is_kick: if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) if not self.hs.is_mine_id(user_id): - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) # Delete group if the last user has left - users = yield self.store.get_users_in_group(group_id, include_private=True) + users = await self.store.get_users_in_group(group_id, include_private=True) if not users: - yield self.store.delete_group(group_id) + await self.store.delete_group(group_id) return {} - @defer.inlineCallbacks - def create_group(self, group_id, requester_user_id, content): - group = yield self.check_group_is_ours(group_id, requester_user_id) + async def create_group(self, group_id, requester_user_id, content): + group = await self.check_group_is_ours(group_id, requester_user_id) logger.info("Attempting to create group with ID: %r", group_id) @@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) if not is_admin: @@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): long_description = profile.get("long_description") user_profile = content.get("user_profile", {}) - yield self.store.create_group( + await self.store.create_group( group_id, requester_user_id, name=name, @@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if not self.hs.is_mine_id(requester_user_id): remote_attestation = content["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, user_id=requester_user_id, group_id=group_id ) @@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): local_attestation = None remote_attestation = None - yield self.store.add_user_to_group( + await self.store.add_user_to_group( group_id, requester_user_id, is_admin=True, @@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): ) if not self.hs.is_mine_id(requester_user_id): - yield self.store.add_remote_profile_cache( + await self.store.add_remote_profile_cache( requester_user_id, displayname=user_profile.get("displayname"), avatar_url=user_profile.get("avatar_url"), @@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {"group_id": group_id} - @defer.inlineCallbacks - def delete_group(self, group_id, requester_user_id): + async def delete_group(self, group_id, requester_user_id): """Deletes a group, kicking out all current members. Only group admins or server admins can call this request @@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler): Deferred """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) # Only server admins or group admins can delete groups. - is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) + is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id) if not is_admin: - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) @@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler): raise SynapseError(403, "User is not an admin") # Before deleting the group lets kick everyone out of it - users = yield self.store.get_users_in_group(group_id, include_private=True) + users = await self.store.get_users_in_group(group_id, include_private=True) - @defer.inlineCallbacks - def _kick_user_from_group(user_id): + async def _kick_user_from_group(user_id): if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) # We kick users out in the order of: # 1. Non-admins @@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler): else: non_admins.append(u["user_id"]) - yield concurrently_execute(_kick_user_from_group, non_admins, 10) - yield concurrently_execute(_kick_user_from_group, admins, 10) - yield _kick_user_from_group(requester_user_id) + await concurrently_execute(_kick_user_from_group, non_admins, 10) + await concurrently_execute(_kick_user_from_group, admins, 10) + await _kick_user_from_group(requester_user_id) - yield self.store.delete_group(group_id) + await self.store.delete_group(group_id) def _parse_join_policy_from_contents(content): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 51413d910e..3b781d9836 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -126,30 +126,28 @@ class BaseHandler(object): retry_after_ms=int(1000 * (time_allowed - time_now)) ) - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, context=None): + async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids() - current_state = yield self.store.get_events( + current_state_ids = await context.get_current_state_ids() + current_state = await self.store.get_events( list(current_state_ids.values()) ) else: - current_state = yield self.state_handler.get_current_state( + current_state = await self.state_handler.get_current_state( event.room_id ) current_state = list(current_state.values()) logger.info("maybe_kick_guest_users %r", current_state) - yield self.kick_guest_users(current_state) + await self.kick_guest_users(current_state) - @defer.inlineCallbacks - def kick_guest_users(self, current_state): + async def kick_guest_users(self, current_state): for member_event in current_state: try: if member_event.type != EventTypes.Member: @@ -180,7 +178,7 @@ class BaseHandler(object): # homeserver. requester = synapse.types.create_requester(target_user, is_guest=True) handler = self.hs.get_room_member_handler() - yield handler.update_membership( + await handler.update_membership( requester, target_user, member_event.room_id, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 53e5f585d9..f2f16b1e43 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler): room_alias, room_id, servers, creator=creator ) - @defer.inlineCallbacks - def create_association( + async def create_association( self, requester: Requester, room_alias: RoomAlias, @@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler): else: # Server admins are not subject to the same constraints as normal # users when creating an alias (e.g. being in the room). - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) if (self.require_membership and check_membership) and not is_admin: - rooms_for_user = yield self.store.get_rooms_for_user(user_id) + rooms_for_user = await self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( 403, "You must be in the room to create an alias for it" @@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") - can_create = yield self.can_modify_alias(room_alias, user_id=user_id) + can_create = await self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( 400, @@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - yield self._create_association(room_alias, room_id, servers, creator=user_id) + await self._create_association(room_alias, room_id, servers, creator=user_id) - @defer.inlineCallbacks - def delete_association(self, requester: Requester, room_alias: RoomAlias): + async def delete_association(self, requester: Requester, room_alias: RoomAlias): """Remove an alias from the directory (this is only meant for human users; AS users should call @@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler): user_id = requester.user.to_string() try: - can_delete = yield self._user_can_delete_alias(room_alias, user_id) + can_delete = await self._user_can_delete_alias(room_alias, user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown room alias") @@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler): if not can_delete: raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) + can_delete = await self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( 400, @@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - room_id = yield self._delete_association(room_alias) + room_id = await self._delete_association(room_alias) try: - yield self._update_canonical_alias(requester, user_id, room_id, room_alias) + await self._update_canonical_alias(requester, user_id, room_id, room_alias) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - @defer.inlineCallbacks - def _update_canonical_alias( + async def _update_canonical_alias( self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias ): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. """ - alias_event = yield self.state.get_current_state( + alias_event = await self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" ) @@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler): del content["alt_aliases"] if send_update: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler): # either no interested services, or no service with an exclusive lock return defer.succeed(True) - @defer.inlineCallbacks - def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): """Determine whether a user can delete an alias. One of the following must be true: @@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler): for the current room. """ - creator = yield self.store.get_room_alias_creator(alias.to_string()) + creator = await self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: return True # Resolve the alias to the corresponding room. - room_mapping = yield self.get_association(alias) + room_mapping = await self.get_association(alias) room_id = room_mapping["room_id"] if not room_id: return False - res = yield self.auth.check_can_change_room_list( + res = await self.auth.check_can_change_room_list( room_id, UserID.from_string(user_id) ) return res - @defer.inlineCallbacks - def edit_published_room_list( + async def edit_published_room_list( self, requester: Requester, room_id: str, visibility: str ): """Edit the entry of the room in the published room list. @@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler): 403, "This user is not permitted to publish rooms to the room list" ) - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Unknown room") - can_change_room_list = yield self.auth.check_can_change_room_list( + can_change_room_list = await self.auth.check_can_change_room_list( room_id, requester.user ) if not can_change_room_list: @@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler): making_public = visibility == "public" if making_public: - room_aliases = yield self.store.get_aliases_for_room(room_id) - canonical_alias = yield self.store.get_canonical_alias_for_room(room_id) + room_aliases = await self.store.get_aliases_for_room(room_id) + canonical_alias = await self.store.get_canonical_alias_for_room(room_id) if canonical_alias: room_aliases.append(canonical_alias) @@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") - yield self.store.set_room_is_public(room_id, making_public) + await self.store.set_room_is_public(room_id, making_public) @defer.inlineCallbacks def edit_published_appservice_room_list( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 41b96c0a73..4e5c645525 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler): "missing": [e.event_id for e in missing_locals], } - @defer.inlineCallbacks @log_function - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): third_party_invite = {"signed": signed} @@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler): "state_key": target_user_id, } - if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): - room_version = yield self.store.get_room_version_id(room_id) + if await self.auth.check_host_in_room(room_id, self.hs.hostname): + room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) - event, context = yield self.event_creation_handler.create_new_client_event( + event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler): 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - event, context = yield self.add_display_name_to_third_party_invite( + event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) @@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e - yield self._check_signature(event, context) + await self._check_signature(event, context) # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() - yield member_handler.send_membership_event(None, event, context) + await member_handler.send_membership_event(None, event, context) else: destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} - yield self.federation_client.forward_third_party_invite( + await self.federation_client.forward_third_party_invite( destinations, room_id, event_dict ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ad22415782..ca5c83811a 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): set_group_join_policy = _create_rerouter("set_group_join_policy") - @defer.inlineCallbacks - def create_group(self, group_id, user_id, content): + async def create_group(self, group_id, user_id, content): """Create a group """ logger.info("Asking to create group with ID: %r", group_id) if self.is_mine_id(group_id): - res = yield self.groups_server_handler.create_group( + res = await self.groups_server_handler.create_group( group_id, user_id, content ) local_attestation = None @@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): local_attestation = self.attestations.create_attestation(group_id, user_id) content["attestation"] = local_attestation - content["user_profile"] = yield self.profile_handler.get_profile(user_id) + content["user_profile"] = await self.profile_handler.get_profile(user_id) try: - res = yield self.transport_client.create_group( + res = await self.transport_client.create_group( get_domain_from_id(group_id), group_id, user_id, content ) except HttpResponseException as e: @@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): raise SynapseError(502, "Failed to contact group server") remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, group_id=group_id, user_id=user_id, @@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): ) is_publicised = content.get("publicise", False) - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="join", @@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {"state": "invite", "user_profile": user_profile} - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from a group """ if user_id == requester_user_id: - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" ) self.notifier.on_new_event("groups_key", token, users=[user_id]) @@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): # retry if the group server is currently down. if self.is_mine_id(group_id): - res = yield self.groups_server_handler.remove_user_from_group( + res = await self.groups_server_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) else: content["requester_user_id"] = requester_user_id try: - res = yield self.transport_client.remove_user_from_group( + res = await self.transport_client.remove_user_from_group( get_domain_from_id(group_id), group_id, requester_user_id, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 522271eed1..a324f09340 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -626,8 +626,7 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - @defer.inlineCallbacks - def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event(self, requester, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. @@ -647,7 +646,7 @@ class EventCreationHandler(object): assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = yield self.deduplicate_state_event(event, context) + prev_state = await self.deduplicate_state_event(event, context) if prev_state is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", @@ -656,7 +655,7 @@ class EventCreationHandler(object): ) return prev_state - yield self.handle_new_client_event( + await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -683,8 +682,7 @@ class EventCreationHandler(object): return prev_event return - @defer.inlineCallbacks - def create_and_send_nonmember_event( + async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None ): """ @@ -698,8 +696,8 @@ class EventCreationHandler(object): # a situation where event persistence can't keep up, causing # extremities to pile up, which in turn leads to state resolution # taking longer. - with (yield self.limiter.queue(event_dict["room_id"])): - event, context = yield self.create_event( + with (await self.limiter.queue(event_dict["room_id"])): + event, context = await self.create_event( requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) @@ -709,7 +707,7 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - yield self.send_nonmember_event( + await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) return event @@ -770,8 +768,7 @@ class EventCreationHandler(object): return (event, context) @measure_func("handle_new_client_event") - @defer.inlineCallbacks - def handle_new_client_event( + async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] ): """Processes a new event. This includes checking auth, persisting it, @@ -794,9 +791,9 @@ class EventCreationHandler(object): ): room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: - room_version = yield self.store.get_room_version_id(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -805,7 +802,7 @@ class EventCreationHandler(object): ) try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) raise err @@ -818,7 +815,7 @@ class EventCreationHandler(object): logger.exception("Failed to encode content: %r", event.content) raise - yield self.action_generator.handle_push_actions_for_event(event, context) + await self.action_generator.handle_push_actions_for_event(event, context) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we # hack around with a try/finally instead. @@ -826,7 +823,7 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - yield self.send_event_to_master( + await self.send_event_to_master( event_id=event.event_id, store=self.store, requester=requester, @@ -838,7 +835,7 @@ class EventCreationHandler(object): success = True return - yield self.persist_and_notify_client_event( + await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) @@ -883,8 +880,7 @@ class EventCreationHandler(object): Codes.BAD_ALIAS, ) - @defer.inlineCallbacks - def persist_and_notify_client_event( + async def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] ): """Called when we have fully built the event, have already @@ -901,7 +897,7 @@ class EventCreationHandler(object): # user is actually admin or not). is_admin_redaction = False if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, @@ -913,11 +909,11 @@ class EventCreationHandler(object): original_event and event.sender != original_event.sender ) - yield self.base_handler.ratelimit( + await self.base_handler.ratelimit( requester, is_admin_redaction=is_admin_redaction ) - yield self.base_handler.maybe_kick_guest_users(event, context) + await self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Validate a newly added alias or newly added alt_aliases. @@ -927,7 +923,7 @@ class EventCreationHandler(object): original_event_id = event.unsigned.get("replaces_state") if original_event_id: - original_event = yield self.store.get_event(original_event_id) + original_event = await self.store.get_event(original_event_id) if original_event: original_alias = original_event.content.get("alias", None) @@ -937,7 +933,7 @@ class EventCreationHandler(object): room_alias_str = event.content.get("alias", None) directory_handler = self.hs.get_handlers().directory_handler if room_alias_str and room_alias_str != original_alias: - yield self._validate_canonical_alias( + await self._validate_canonical_alias( directory_handler, room_alias_str, event.room_id ) @@ -957,7 +953,7 @@ class EventCreationHandler(object): new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) if new_alt_aliases: for alias_str in new_alt_aliases: - yield self._validate_canonical_alias( + await self._validate_canonical_alias( directory_handler, alias_str, event.room_id ) @@ -969,7 +965,7 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids() + current_state_ids = await context.get_current_state_ids() state_to_include_ids = [ e_id @@ -978,7 +974,7 @@ class EventCreationHandler(object): or k == (EventTypes.Member, event.sender) ] - state_to_include = yield self.store.get_events(state_to_include_ids) + state_to_include = await self.store.get_events(state_to_include_ids) event.unsigned["invite_room_state"] = [ { @@ -996,8 +992,8 @@ class EventCreationHandler(object): # way? If we have been invited by a remote server, we need # to get them to sign the event. - returned_invite = yield defer.ensureDeferred( - federation_handler.send_invite(invitee.domain, event) + returned_invite = await federation_handler.send_invite( + invitee.domain, event ) event.unsigned.pop("room_state", None) @@ -1005,7 +1001,7 @@ class EventCreationHandler(object): event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, @@ -1021,14 +1017,14 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids() - auth_events_ids = yield self.auth.compute_auth_events( + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} - room_version = yield self.store.get_room_version_id(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] if event_auth.check_redaction( @@ -1047,11 +1043,11 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( + event_stream_id, max_stream_id = await self.storage.persistence.persist_event( event, context=context ) @@ -1059,7 +1055,7 @@ class EventCreationHandler(object): # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: @@ -1083,13 +1079,12 @@ class EventCreationHandler(object): except Exception: logger.exception("Error bumping presence active time") - @defer.inlineCallbacks - def _send_dummy_events_to_fill_extremities(self): + async def _send_dummy_events_to_fill_extremities(self): """Background task to send dummy events into rooms that have a large number of extremities """ self._expire_rooms_to_exclude_from_dummy_event_insertion() - room_ids = yield self.store.get_rooms_with_many_extremities( + room_ids = await self.store.get_rooms_with_many_extremities( min_count=10, limit=5, room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), @@ -1099,9 +1094,9 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - latest_event_ids = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - members = yield self.state.get_current_users_in_room( + members = await self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids ) dummy_event_sent = False @@ -1110,7 +1105,7 @@ class EventCreationHandler(object): continue requester = create_requester(user_id) try: - event, context = yield self.create_event( + event, context = await self.create_event( requester, { "type": "org.matrix.dummy_event", @@ -1123,7 +1118,7 @@ class EventCreationHandler(object): event.internal_metadata.proactively_send = False - yield self.send_nonmember_event( + await self.send_nonmember_event( requester, event, context, ratelimit=False ) dummy_event_sent = True diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6aa1c0f5e0..302efc1b9a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler): return result["displayname"] - @defer.inlineCallbacks - def set_displayname(self, target_user, requester, new_displayname, by_admin=False): + async def set_displayname( + self, target_user, requester, new_displayname, by_admin=False + ): """Set the displayname of a user Args: @@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler): raise AuthError(400, "Cannot set another user's displayname") if not by_admin and not self.hs.config.enable_set_displayname: - profile = yield self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( 400, @@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - yield self.store.set_profile_displayname(target_user.localpart, new_displayname) + await self.store.set_profile_displayname(target_user.localpart, new_displayname) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler): return result["avatar_url"] - @defer.inlineCallbacks - def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): + async def set_avatar_url( + self, target_user, requester, new_avatar_url, by_admin=False + ): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): @@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler): raise AuthError(400, "Cannot set another user's avatar_url") if not by_admin and not self.hs.config.enable_set_avatar_url: - profile = yield self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN @@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def on_profile_query(self, args): @@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler): return response - @defer.inlineCallbacks - def _update_join_states(self, requester, target_user): + async def _update_join_states(self, requester, target_user): if not self.hs.is_mine(target_user): return - yield self.ratelimit(requester) + await self.ratelimit(requester) - room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) + room_ids = await self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: handler = self.hs.get_room_member_handler() try: # Assume the target_user isn't a guest, # because we don't let guests set profile or avatar data. - yield handler.update_membership( + await handler.update_membership( requester, target_user, room_id, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3a65b46ecd..1e6bdac0ad 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler): """Registers a new client on the server. Args: - localpart : The local part of the user ID to register. If None, + localpart: The local part of the user ID to register. If None, one will be generated. - password (unicode) : The password to assign to this user so they can + password (unicode): The password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). user_type (str|None): type of user. One of the values from @@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler): fail_count += 1 if not self.hs.config.user_consent_at_registration: - yield self._auto_join_rooms(user_id) + yield defer.ensureDeferred(self._auto_join_rooms(user_id)) else: logger.info( "Skipping auto-join for %s because consent is required at registration", @@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler): return user_id - @defer.inlineCallbacks - def _auto_join_rooms(self, user_id): + async def _auto_join_rooms(self, user_id): """Automatically joins users to auto join rooms - creating the room in the first place if the user is the first to be created. @@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler): # that an auto-generated support or bot user is not a real user and will never be # the user to create the room should_auto_create_rooms = False - is_real_user = yield self.store.is_real_user(user_id) + is_real_user = await self.store.is_real_user(user_id) if self.hs.config.autocreate_auto_join_rooms and is_real_user: - count = yield self.store.count_real_users() + count = await self.store.count_real_users() should_auto_create_rooms = count == 1 for r in self.hs.config.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) @@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler): # getting the RoomCreationHandler during init gives a dependency # loop - yield self.hs.get_room_creation_handler().create_room( + await self.hs.get_room_creation_handler().create_room( fake_requester, config={ "preset": "public_chat", @@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) else: - yield self._join_user_to_room(fake_requester, r) + await self._join_user_to_room(fake_requester, r) except ConsentNotGivenError as e: # Technically not necessary to pull out this error though # moving away from bare excepts is a good thing to do. @@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - @defer.inlineCallbacks - def post_consent_actions(self, user_id): + async def post_consent_actions(self, user_id): """A series of registration actions that can only be carried out once consent has been granted Args: user_id (str): The user to join """ - yield self._auto_join_rooms(user_id) + await self._auto_join_rooms(user_id) @defer.inlineCallbacks def appservice_register(self, user_localpart, as_token): @@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id += 1 return str(id) - @defer.inlineCallbacks - def _join_user_to_room(self, requester, room_identifier): + async def _join_user_to_room(self, requester, room_identifier): room_member_handler = self.hs.get_room_member_handler() if RoomID.is_valid(room_identifier): room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( + room_id, remote_room_hosts = await room_member_handler.lookup_room_alias( room_alias ) room_id = room_id.to_string() @@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler): 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield room_member_handler.update_membership( + await room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler): return (device_id, access_token) - @defer.inlineCallbacks - def post_registration_actions(self, user_id, auth_result, access_token): + async def post_registration_actions(self, user_id, auth_result, access_token): """A user has completed registration Args: @@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler): device, or None if `inhibit_login` enabled. """ if self.hs.config.worker_app: - yield self._post_registration_client( + await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token ) return @@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(user_id) + await self.store.upsert_monthly_active_user(user_id) - yield self._register_email_threepid(user_id, threepid, access_token) + await self._register_email_threepid(user_id, threepid, access_token) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] - yield self._register_msisdn_threepid(user_id, threepid) + await self._register_msisdn_threepid(user_id, threepid) if auth_result and LoginType.TERMS in auth_result: - yield self._on_user_consented(user_id, self.hs.config.user_consent_version) + await self._on_user_consented(user_id, self.hs.config.user_consent_version) - @defer.inlineCallbacks - def _on_user_consented(self, user_id, consent_version): + async def _on_user_consented(self, user_id, consent_version): """A user consented to the terms on registration Args: @@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler): consented to. """ logger.info("%s has consented to the privacy policy", user_id) - yield self.store.user_set_consent_version(user_id, consent_version) - yield self.post_consent_actions(user_id) + await self.store.user_set_consent_version(user_id, consent_version) + await self.post_consent_actions(user_id) @defer.inlineCallbacks def _register_email_threepid(self, user_id, threepid, token): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3d10e4b2d9..da12df7f53 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -148,17 +148,16 @@ class RoomCreationHandler(BaseHandler): return ret - @defer.inlineCallbacks - def _upgrade_room( + async def _upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ): user_id = requester.user.to_string() # start by allocating a new room id - r = yield self.store.get_room(old_room_id) + r = await self.store.get_room(old_room_id) if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) - new_room_id = yield self._generate_room_id( + new_room_id = await self._generate_room_id( creator_id=user_id, is_public=r["is_public"], room_version=new_version, ) @@ -169,7 +168,7 @@ class RoomCreationHandler(BaseHandler): ( tombstone_event, tombstone_context, - ) = yield self.event_creation_handler.create_event( + ) = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Tombstone, @@ -183,12 +182,12 @@ class RoomCreationHandler(BaseHandler): }, token_id=requester.access_token_id, ) - old_room_version = yield self.store.get_room_version_id(old_room_id) - yield self.auth.check_from_context( + old_room_version = await self.store.get_room_version_id(old_room_id) + await self.auth.check_from_context( old_room_version, tombstone_event, tombstone_context ) - yield self.clone_existing_room( + await self.clone_existing_room( requester, old_room_id=old_room_id, new_room_id=new_room_id, @@ -197,32 +196,31 @@ class RoomCreationHandler(BaseHandler): ) # now send the tombstone - yield self.event_creation_handler.send_nonmember_event( + await self.event_creation_handler.send_nonmember_event( requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids() + old_room_state = await tombstone_context.get_current_state_ids() # update any aliases - yield self._move_aliases_to_new_room( + await self._move_aliases_to_new_room( requester, old_room_id, new_room_id, old_room_state ) # Copy over user push rules, tags and migrate room directory state - yield self.room_member_handler.transfer_room_state_on_room_upgrade( + await self.room_member_handler.transfer_room_state_on_room_upgrade( old_room_id, new_room_id ) # finally, shut down the PLs in the old room, and update them in the new # room. - yield self._update_upgraded_room_pls( + await self._update_upgraded_room_pls( requester, old_room_id, new_room_id, old_room_state, ) return new_room_id - @defer.inlineCallbacks - def _update_upgraded_room_pls( + async def _update_upgraded_room_pls( self, requester: Requester, old_room_id: str, @@ -249,7 +247,7 @@ class RoomCreationHandler(BaseHandler): ) return - old_room_pl_state = yield self.store.get_event(old_room_pl_event_id) + old_room_pl_state = await self.store.get_event(old_room_pl_event_id) # we try to stop regular users from speaking by setting the PL required # to send regular events and invites to 'Moderator' level. That's normally @@ -278,7 +276,7 @@ class RoomCreationHandler(BaseHandler): if updated: try: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -292,7 +290,7 @@ class RoomCreationHandler(BaseHandler): except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -304,8 +302,7 @@ class RoomCreationHandler(BaseHandler): ratelimit=False, ) - @defer.inlineCallbacks - def clone_existing_room( + async def clone_existing_room( self, requester: Requester, old_room_id: str, @@ -338,7 +335,7 @@ class RoomCreationHandler(BaseHandler): # Check if old room was non-federatable # Get old room's create event - old_room_create_event = yield self.store.get_create_event_for_room(old_room_id) + old_room_create_event = await self.store.get_create_event_for_room(old_room_id) # Check if the create event specified a non-federatable room if not old_room_create_event.content.get("m.federate", True): @@ -361,11 +358,11 @@ class RoomCreationHandler(BaseHandler): (EventTypes.PowerLevels, ""), ) - old_room_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types(types_to_copy) ) # map from event_id to BaseEvent - old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) + old_room_state_events = await self.store.get_events(old_room_state_ids.values()) for k, old_event_id in iteritems(old_room_state_ids): old_event = old_room_state_events.get(old_event_id) @@ -400,7 +397,7 @@ class RoomCreationHandler(BaseHandler): if current_power_level < needed_power_level: power_levels["users"][user_id] = needed_power_level - yield self._send_events_for_new_room( + await self._send_events_for_new_room( requester, new_room_id, # we expect to override all the presets with initial_state, so this is @@ -412,12 +409,12 @@ class RoomCreationHandler(BaseHandler): ) # Transfer membership events - old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_member_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) ) # map from event_id to BaseEvent - old_room_member_state_events = yield self.store.get_events( + old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) for k, old_event in iteritems(old_room_member_state_events): @@ -426,7 +423,7 @@ class RoomCreationHandler(BaseHandler): "membership" in old_event.content and old_event.content["membership"] == "ban" ): - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester, UserID.from_string(old_event["state_key"]), new_room_id, @@ -438,8 +435,7 @@ class RoomCreationHandler(BaseHandler): # XXX invites/joins # XXX 3pid invites - @defer.inlineCallbacks - def _move_aliases_to_new_room( + async def _move_aliases_to_new_room( self, requester: Requester, old_room_id: str, @@ -448,13 +444,13 @@ class RoomCreationHandler(BaseHandler): ): directory_handler = self.hs.get_handlers().directory_handler - aliases = yield self.store.get_aliases_for_room(old_room_id) + aliases = await self.store.get_aliases_for_room(old_room_id) # check to see if we have a canonical alias. canonical_alias_event = None canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, "")) if canonical_alias_event_id: - canonical_alias_event = yield self.store.get_event(canonical_alias_event_id) + canonical_alias_event = await self.store.get_event(canonical_alias_event_id) # first we try to remove the aliases from the old room (we suppress sending # the room_aliases event until the end). @@ -472,7 +468,7 @@ class RoomCreationHandler(BaseHandler): for alias_str in aliases: alias = RoomAlias.from_string(alias_str) try: - yield directory_handler.delete_association(requester, alias) + await directory_handler.delete_association(requester, alias) removed_aliases.append(alias_str) except SynapseError as e: logger.warning("Unable to remove alias %s from old room: %s", alias, e) @@ -485,7 +481,7 @@ class RoomCreationHandler(BaseHandler): # we can now add any aliases we successfully removed to the new room. for alias in removed_aliases: try: - yield directory_handler.create_association( + await directory_handler.create_association( requester, RoomAlias.from_string(alias), new_room_id, @@ -502,7 +498,7 @@ class RoomCreationHandler(BaseHandler): # alias event for the new room with a copy of the information. try: if canonical_alias_event: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -518,8 +514,9 @@ class RoomCreationHandler(BaseHandler): # we returned the new room to the client at this point. logger.error("Unable to send updated alias events in new room: %s", e) - @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): + async def create_room( + self, requester, config, ratelimit=True, creator_join_profile=None + ): """ Creates a new room. Args: @@ -547,7 +544,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) if ( self._server_notices_mxid is not None @@ -556,11 +553,11 @@ class RoomCreationHandler(BaseHandler): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) # Check whether the third party rules allows/changes the room create # request. - event_allowed = yield self.third_party_event_rules.on_create_room( + event_allowed = await self.third_party_event_rules.on_create_room( requester, config, is_requester_admin=is_requester_admin ) if not event_allowed: @@ -574,7 +571,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") if ratelimit: - yield self.ratelimit(requester) + await self.ratelimit(requester) room_version_id = config.get( "room_version", self.config.default_room_version.identifier @@ -597,7 +594,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(400, "Invalid characters in room alias") room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) - mapping = yield self.store.get_association_from_room_alias(room_alias) + mapping = await self.store.get_association_from_room_alias(room_alias) if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) @@ -612,7 +609,7 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) - yield self.event_creation_handler.assert_accepted_privacy_policy(requester) + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") if ( @@ -631,13 +628,13 @@ class RoomCreationHandler(BaseHandler): visibility = config.get("visibility", None) is_public = visibility == "public" - room_id = yield self._generate_room_id( + room_id = await self._generate_room_id( creator_id=user_id, is_public=is_public, room_version=room_version, ) directory_handler = self.hs.get_handlers().directory_handler if room_alias: - yield directory_handler.create_association( + await directory_handler.create_association( requester=requester, room_id=room_id, room_alias=room_alias, @@ -670,7 +667,7 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - yield self._send_events_for_new_room( + await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -684,7 +681,7 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -698,7 +695,7 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -716,7 +713,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -730,7 +727,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - yield self.hs.get_room_member_handler().do_3pid_invite( + await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -748,8 +745,7 @@ class RoomCreationHandler(BaseHandler): return result - @defer.inlineCallbacks - def _send_events_for_new_room( + async def _send_events_for_new_room( self, creator, # A Requester object. room_id, @@ -769,11 +765,10 @@ class RoomCreationHandler(BaseHandler): return e - @defer.inlineCallbacks - def send(etype, content, **kwargs): + async def send(etype, content, **kwargs): event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) @@ -784,10 +779,10 @@ class RoomCreationHandler(BaseHandler): event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} creation_content.update({"creator": creator_id}) - yield send(etype=EventTypes.Create, content=creation_content) + await send(etype=EventTypes.Create, content=creation_content) logger.debug("Sending %s in new room", EventTypes.Member) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( creator, creator.user, room_id, @@ -800,7 +795,7 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - yield send(etype=EventTypes.PowerLevels, content=pl_content) + await send(etype=EventTypes.PowerLevels, content=pl_content) else: power_level_content = { "users": {creator_id: 100}, @@ -833,33 +828,33 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - yield send(etype=EventTypes.PowerLevels, content=power_level_content) + await send(etype=EventTypes.PowerLevels, content=power_level_content) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - yield send( + await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - yield send( + await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - yield send( + await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - yield send( + await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - yield send(etype=etype, state_key=state_key, content=content) + await send(etype=etype, state_key=state_key, content=content) @defer.inlineCallbacks def _generate_room_id( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c3ee8db4f0..53b49bc15f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -142,8 +142,7 @@ class RoomMemberHandler(object): """ raise NotImplementedError() - @defer.inlineCallbacks - def _local_membership_update( + async def _local_membership_update( self, requester, target, @@ -164,7 +163,7 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" - event, context = yield self.event_creation_handler.create_event( + event, context = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -182,18 +181,18 @@ class RoomMemberHandler(object): ) # Check if this event matches the previous membership event for the user. - duplicate = yield self.event_creation_handler.deduplicate_state_event( + duplicate = await self.event_creation_handler.deduplicate_state_event( event, context ) if duplicate is not None: # Discard the new event since this membership change is a no-op. return duplicate - yield self.event_creation_handler.handle_new_client_event( + await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -203,15 +202,15 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target, room_id) + await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target, room_id) + await self._user_left_room(target, room_id) return event @@ -253,8 +252,7 @@ class RoomMemberHandler(object): for tag, tag_content in room_tags.items(): yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) - @defer.inlineCallbacks - def update_membership( + async def update_membership( self, requester, target, @@ -269,8 +267,8 @@ class RoomMemberHandler(object): ): key = (room_id,) - with (yield self.member_linearizer.queue(key)): - result = yield self._update_membership( + with (await self.member_linearizer.queue(key)): + result = await self._update_membership( requester, target, room_id, @@ -285,8 +283,7 @@ class RoomMemberHandler(object): return result - @defer.inlineCallbacks - def _update_membership( + async def _update_membership( self, requester, target, @@ -321,7 +318,7 @@ class RoomMemberHandler(object): # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. if third_party_signed is not None: - yield self.federation_handler.exchange_third_party_invite( + await self.federation_handler.exchange_third_party_invite( third_party_signed["sender"], target.to_string(), room_id, @@ -332,7 +329,7 @@ class RoomMemberHandler(object): remote_room_hosts = [] if effective_membership_state not in ("leave", "ban"): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -351,7 +348,7 @@ class RoomMemberHandler(object): is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: if self.config.block_non_admin_invites: @@ -370,9 +367,9 @@ class RoomMemberHandler(object): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - latest_event_ids = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - current_state_ids = yield self.state_handler.get_current_state_ids( + current_state_ids = await self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids ) @@ -381,7 +378,7 @@ class RoomMemberHandler(object): # transitions and generic otherwise old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) if old_state_id: - old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None if action == "unban" and old_membership != "ban": raise SynapseError( @@ -413,7 +410,7 @@ class RoomMemberHandler(object): old_membership == Membership.INVITE and effective_membership_state == Membership.LEAVE ): - is_blocked = yield self._is_server_notice_room(room_id) + is_blocked = await self._is_server_notice_room(room_id) if is_blocked: raise SynapseError( http_client.FORBIDDEN, @@ -424,18 +421,18 @@ class RoomMemberHandler(object): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = yield self._is_host_in_room(current_state_ids) + is_host_in_room = await self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(current_state_ids) + guest_can_join = await self._can_guest_join(current_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) @@ -443,13 +440,13 @@ class RoomMemberHandler(object): profile = self.profile_handler if not content_specified: - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) + content["displayname"] = await profile.get_displayname(target) + content["avatar_url"] = await profile.get_avatar_url(target) if requester.is_guest: content["kind"] = "guest" - remote_join_response = yield self._remote_join( + remote_join_response = await self._remote_join( requester, remote_room_hosts, room_id, target, content ) @@ -458,7 +455,7 @@ class RoomMemberHandler(object): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if not inviter: raise SynapseError(404, "Not a known room") @@ -472,12 +469,12 @@ class RoomMemberHandler(object): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - res = yield self._remote_reject_invite( + res = await self._remote_reject_invite( requester, remote_room_hosts, room_id, target, content, ) return res - res = yield self._local_membership_update( + res = await self._local_membership_update( requester=requester, target=target, room_id=room_id, @@ -572,8 +569,7 @@ class RoomMemberHandler(object): ) continue - @defer.inlineCallbacks - def send_membership_event(self, requester, event, context, ratelimit=True): + async def send_membership_event(self, requester, event, context, ratelimit=True): """ Change the membership status of a user in a room. @@ -599,27 +595,27 @@ class RoomMemberHandler(object): else: requester = types.create_requester(target_user) - prev_event = yield self.event_creation_handler.deduplicate_state_event( + prev_event = await self.event_creation_handler.deduplicate_state_event( event, context ) if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(prev_state_ids) + guest_can_join = await self._can_guest_join(prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if event.membership not in (Membership.LEAVE, Membership.BAN): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - yield self.event_creation_handler.handle_new_client_event( + await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) @@ -633,15 +629,15 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target_user, room_id) + await self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target_user, room_id) + await self._user_left_room(target_user, room_id) @defer.inlineCallbacks def _can_guest_join(self, current_state_ids): @@ -699,8 +695,7 @@ class RoomMemberHandler(object): if invite: return UserID.from_string(invite.sender) - @defer.inlineCallbacks - def do_3pid_invite( + async def do_3pid_invite( self, room_id, inviter, @@ -712,7 +707,7 @@ class RoomMemberHandler(object): id_access_token=None, ): if self.config.block_non_admin_invites: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( 403, "Invites have been disabled on this server", Codes.FORBIDDEN @@ -720,9 +715,9 @@ class RoomMemberHandler(object): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - yield self.base_handler.ratelimit(requester) + await self.base_handler.ratelimit(requester) - can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( + can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( medium, address, room_id ) if not can_invite: @@ -737,16 +732,16 @@ class RoomMemberHandler(object): 403, "Looking up third-party identifiers is denied from this server" ) - invitee = yield self.identity_handler.lookup_3pid( + invitee = await self.identity_handler.lookup_3pid( id_server, medium, address, id_access_token ) if invitee: - yield self.update_membership( + await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - yield self._make_and_store_3pid_invite( + await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -757,8 +752,7 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - @defer.inlineCallbacks - def _make_and_store_3pid_invite( + async def _make_and_store_3pid_invite( self, requester, id_server, @@ -769,7 +763,7 @@ class RoomMemberHandler(object): txn_id, id_access_token=None, ): - room_state = yield self.state_handler.get_current_state(room_id) + room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" inviter_avatar_url = "" @@ -807,7 +801,7 @@ class RoomMemberHandler(object): public_keys, fallback_public_key, display_name, - ) = yield self.identity_handler.ask_id_server_for_third_party_invite( + ) = await self.identity_handler.ask_id_server_for_third_party_invite( requester=requester, id_server=id_server, medium=medium, @@ -823,7 +817,7 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return complexity["v1"] > max_complexity - @defer.inlineCallbacks - def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: # Fetch the room complexity - too_complex = yield self._is_remote_room_too_complex( + too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts ) if too_complex is True: @@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - yield defer.ensureDeferred( - self.federation_handler.do_invite_join( - remote_room_hosts, room_id, user.to_string(), content - ) + await self.federation_handler.do_invite_join( + remote_room_hosts, room_id, user.to_string(), content ) - yield self._user_joined_room(user, room_id) + await self._user_joined_room(user, room_id) # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. @@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return # Check again, but with the local state events - too_complex = yield self._is_local_room_too_complex(room_id) + too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. @@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # The room is too large. Leave. requester = types.create_requester(user, None, False, None) - yield self.update_membership( + await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) raise SynapseError( @@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler): def _user_joined_room(self, target, room_id): """Implements RoomMemberHandler._user_joined_room """ - return user_joined_room(self.distributor, target, room_id) + return defer.succeed(user_joined_room(self.distributor, target, room_id)) def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ - return user_left_room(self.distributor, target, room_id) + return defer.succeed(user_left_room(self.distributor, target, room_id)) @defer.inlineCallbacks def forget(self, user, room_id): diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 5736c56032..3bf330da49 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -16,8 +16,6 @@ import logging from six import iteritems, string_types -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder from synapse.config import ConfigError @@ -59,8 +57,7 @@ class ConsentServerNotices(object): self._consent_uri_builder = ConsentURIBuilder(hs.config) - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, and does so if so Args: @@ -78,7 +75,7 @@ class ConsentServerNotices(object): return self._users_in_progress.add(user_id) try: - u = yield self._store.get_user_by_id(user_id) + u = await self._store.get_user_by_id(user_id) if u["is_guest"] and not self._send_to_guests: # don't send to guests @@ -100,8 +97,8 @@ class ConsentServerNotices(object): content = copy_with_str_subst( self._server_notice_content, {"consent_uri": consent_uri} ) - yield self._server_notices_manager.send_notice(user_id, content) - yield self._store.user_set_consent_server_notice_sent( + await self._server_notices_manager.send_notice(user_id, content) + await self._store.user_set_consent_server_notice_sent( user_id, self._current_consent_version ) except SynapseError as e: diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index ce4a828894..971771b8b2 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -50,8 +50,7 @@ class ResourceLimitsServerNotices(object): self._notifier = hs.get_notifier() - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, this will be true in two cases. 1. The server has reached its limit does not reflect this @@ -74,13 +73,13 @@ class ResourceLimitsServerNotices(object): # Don't try and send server notices unless they've been enabled return - timestamp = yield self._store.user_last_seen_monthly_active(user_id) + timestamp = await self._store.user_last_seen_monthly_active(user_id) if timestamp is None: # This user will be blocked from receiving the notice anyway. # In practice, not sure we can ever get here return - room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user( + room_id = await self._server_notices_manager.get_or_create_notice_room_for_user( user_id ) @@ -88,10 +87,10 @@ class ResourceLimitsServerNotices(object): logger.warning("Failed to get server notices room") return - yield self._check_and_set_tags(user_id, room_id) + await self._check_and_set_tags(user_id, room_id) # Determine current state of room - currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + currently_blocked, ref_events = await self._is_room_currently_blocked(room_id) limit_msg = None limit_type = None @@ -99,7 +98,7 @@ class ResourceLimitsServerNotices(object): # Normally should always pass in user_id to check_auth_blocking # if you have it, but in this case are checking what would happen # to other users if they were to arrive. - yield self._auth.check_auth_blocking() + await self._auth.check_auth_blocking() except ResourceLimitError as e: limit_msg = e.msg limit_type = e.limit_type @@ -112,22 +111,21 @@ class ResourceLimitsServerNotices(object): # We have hit the MAU limit, but MAU alerting is disabled: # reset room if necessary and return if currently_blocked: - self._remove_limit_block_notification(user_id, ref_events) + await self._remove_limit_block_notification(user_id, ref_events) return if currently_blocked and not limit_msg: # Room is notifying of a block, when it ought not to be. - yield self._remove_limit_block_notification(user_id, ref_events) + await self._remove_limit_block_notification(user_id, ref_events) elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. - yield self._apply_limit_block_notification( + await self._apply_limit_block_notification( user_id, limit_msg, limit_type ) except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) - @defer.inlineCallbacks - def _remove_limit_block_notification(self, user_id, ref_events): + async def _remove_limit_block_notification(self, user_id, ref_events): """Utility method to remove limit block notifications from the server notices room. @@ -137,12 +135,13 @@ class ResourceLimitsServerNotices(object): limit blocking and need to be preserved. """ content = {"pinned": ref_events} - yield self._server_notices_manager.send_notice( + await self._server_notices_manager.send_notice( user_id, content, EventTypes.Pinned, "" ) - @defer.inlineCallbacks - def _apply_limit_block_notification(self, user_id, event_body, event_limit_type): + async def _apply_limit_block_notification( + self, user_id, event_body, event_limit_type + ): """Utility method to apply limit block notifications in the server notices room. @@ -159,12 +158,12 @@ class ResourceLimitsServerNotices(object): "admin_contact": self._config.admin_contact, "limit_type": event_limit_type, } - event = yield self._server_notices_manager.send_notice( + event = await self._server_notices_manager.send_notice( user_id, content, EventTypes.Message ) content = {"pinned": [event.event_id]} - yield self._server_notices_manager.send_notice( + await self._server_notices_manager.send_notice( user_id, content, EventTypes.Pinned, "" ) @@ -198,7 +197,7 @@ class ResourceLimitsServerNotices(object): room_id(str): The room id of the server notices room Returns: - + Deferred[Tuple[bool, List]]: bool: Is the room currently blocked list: The list of pinned events that are unrelated to limit blocking This list can be used as a convenience in the case where the block diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index bf0943f265..999c621b92 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -14,11 +14,9 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.types import UserID, create_requester -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -51,8 +49,7 @@ class ServerNoticesManager(object): """ return self._config.server_notices_mxid is not None - @defer.inlineCallbacks - def send_notice( + async def send_notice( self, user_id, event_content, type=EventTypes.Message, state_key=None ): """Send a notice to the given user @@ -68,8 +65,8 @@ class ServerNoticesManager(object): Returns: Deferred[FrozenEvent] """ - room_id = yield self.get_or_create_notice_room_for_user(user_id) - yield self.maybe_invite_user_to_room(user_id, room_id) + room_id = await self.get_or_create_notice_room_for_user(user_id) + await self.maybe_invite_user_to_room(user_id, room_id) system_mxid = self._config.server_notices_mxid requester = create_requester(system_mxid) @@ -86,13 +83,13 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = yield self._event_creation_handler.create_and_send_nonmember_event( + res = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) return res - @cachedInlineCallbacks() - def get_or_create_notice_room_for_user(self, user_id): + @cached() + async def get_or_create_notice_room_for_user(self, user_id): """Get the room for notices for a given user If we have not yet created a notice room for this user, create it, but don't @@ -109,7 +106,7 @@ class ServerNoticesManager(object): assert self._is_mine_id(user_id), "Cannot send server notices to remote users" - rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) for room in rooms: @@ -118,7 +115,7 @@ class ServerNoticesManager(object): # be joined. This is kinda deliberate, in that if somebody somehow # manages to invite the system user to a room, that doesn't make it # the server notices room. - user_ids = yield self._store.get_users_in_room(room.room_id) + user_ids = await self._store.get_users_in_room(room.room_id) if self.server_notices_mxid in user_ids: # we found a room which our user shares with the system notice # user @@ -146,7 +143,7 @@ class ServerNoticesManager(object): } requester = create_requester(self.server_notices_mxid) - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, @@ -158,7 +155,7 @@ class ServerNoticesManager(object): ) room_id = info["room_id"] - max_id = yield self._store.add_tag_to_room( + max_id = await self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) @@ -166,8 +163,7 @@ class ServerNoticesManager(object): logger.info("Created server notices room %s for %s", room_id, user_id) return room_id - @defer.inlineCallbacks - def maybe_invite_user_to_room(self, user_id: str, room_id: str): + async def maybe_invite_user_to_room(self, user_id: str, room_id: str): """Invite the given user to the given server room, unless the user has already joined or been invited to it. @@ -179,14 +175,14 @@ class ServerNoticesManager(object): # Check whether the user has already joined or been invited to this room. If # that's the case, there is no need to re-invite them. - joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) for room in joined_rooms: if room.room_id == room_id: return - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( requester=requester, target=UserID.from_string(user_id), room_id=room_id, diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 652bab58e3..be74e86641 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, @@ -36,18 +34,16 @@ class ServerNoticesSender(object): ResourceLimitsServerNotices(hs), ) - @defer.inlineCallbacks - def on_user_syncing(self, user_id): + async def on_user_syncing(self, user_id): """Called when the user performs a sync operation. Args: user_id (str): mxid of user who synced """ for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) - @defer.inlineCallbacks - def on_user_ip(self, user_id): + async def on_user_ip(self, user_id): """Called on the master when a worker process saw a client request. Args: @@ -57,4 +53,4 @@ class ServerNoticesSender(object): # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 3e53c8568a..efcdd2100b 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="delete_account_validity_for_user", ) - @defer.inlineCallbacks - def is_server_admin(self, user): + async def is_server_admin(self, user): """Determines if a user is an admin of this homeserver. Args: @@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self.db.simple_select_one_onecol( + res = await self.db.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index be665262c6..8aa56f1496 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name(self): - yield self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), "Frank Jr.", ) # Set displayname again - yield self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank" + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank" + ) ) self.assertEquals( @@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase): ) # Setting displayname a second time is forbidden - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) ) yield self.assertFailure(d, SynapseError) @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + ) ) yield self.assertFailure(d, AuthError) @@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): - yield self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) ) self.assertEquals( @@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase): ) # Set avatar again - yield self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/me.png", + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/me.png", + ) ) self.assertEquals( @@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase): ) # Set avatar a second time is forbidden - d = self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", + d = defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) ) yield self.assertFailure(d, SynapseError) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index f1dc51d6c9..1b7935cef2 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=False) + self.store.is_real_user = Mock(return_value=defer.succeed(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=1) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(1)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=2) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(2)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) - @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, password_hash=None): + async def get_or_create_user( + self, requester, localpart, displayname, password_hash=None + ): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self.hs.get_auth().check_auth_blocking() + await self.hs.get_auth().check_auth_blocking() need_register = True try: - yield self.handler.check_username(localpart) + await self.handler.check_username(localpart) except SynapseError as e: if e.errcode == Codes.USER_IN_USE: need_register = False @@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): token = self.macaroon_generator.generate_access_token(user_id) if need_register: - yield self.handler.register_with_store( + await self.handler.register_with_store( user_id=user_id, password_hash=password_hash, create_profile_with_displayname=user.localpart, ) else: - yield defer.ensureDeferred( - self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - ) + await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id=user_id, token=token, device_id=None, valid_until_ms=None ) if displayname is not None: # logger.info("setting user display name: %s -> %s", user_id, displayname) - yield self.hs.get_profile_handler().set_displayname( + await self.hs.get_profile_handler().set_displayname( user, requester, displayname, by_admin=True ) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 93eb053b8c..987addad9b 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -55,25 +55,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) ) - self._send_notice = self._rlsn._server_notices_manager.send_notice - self._rlsn._server_notices_manager.send_notice = Mock() - self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_events = Mock(return_value=defer.succeed({})) - + self._rlsn._server_notices_manager.send_notice = Mock( + return_value=defer.succeed(Mock()) + ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.hs.config.limit_usage_by_mau = True self.user_id = "@user_id:test" - # self.server_notices_mxid = "@server:test" - # self.server_notices_mxid_display_name = None - # self.server_notices_mxid_avatar_url = None - # self.server_notices_room_name = "Server Notices" - self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( - returnValue="" + return_value=defer.succeed("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock() + self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) self._rlsn._store.get_tags_for_room = Mock(return_value={}) self.hs.config.admin_contact = "mailto:user@test.com" @@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) - self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event self._send_notice.assert_called_once() @@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") ) mock_event = Mock( @@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) @@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): an alert message is not sent into the room """ self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER - ) + ), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) - self.assertTrue(self._send_notice.call_count == 0) + self.assertEqual(self._send_notice.call_count, 0) def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED - ) + ), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self.hs.config.mau_limit_alerting = False self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER - ) + ), ) + self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( return_value=defer.succeed((True, [])) ) @@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) - self.store.user_last_seen_monthly_active = Mock(return_value=1000) + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(1000) + ) # Call the function multiple times to ensure we only send the notice once self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/test_federation.py b/tests/test_federation.py index 9b5cf562f3..f297de95f1 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = room_creator.create_room( - our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + room = ensureDeferred( + room_creator.create_room( + our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + ) ) self.reactor.advance(0.1) self.room_id = self.successResultOf(room)["room_id"] -- cgit 1.5.1 From 3085cde577216519d789c8160262831cb2029972 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 May 2020 15:21:35 +0100 Subject: Use `stream.current_token()` and remove `stream_positions()` (#7172) We move the processing of typing and federation replication traffic into their handlers so that `Stream.current_token()` points to a valid token. This allows us to remove `get_streams_to_replicate()` and `stream_positions()`. --- changelog.d/7172.misc | 1 + synapse/app/generic_worker.py | 16 ------------ synapse/replication/slave/storage/_base.py | 15 +----------- synapse/replication/slave/storage/account_data.py | 8 ------ synapse/replication/slave/storage/deviceinbox.py | 5 ---- synapse/replication/slave/storage/devices.py | 10 -------- synapse/replication/slave/storage/events.py | 6 ----- synapse/replication/slave/storage/groups.py | 5 ---- synapse/replication/slave/storage/presence.py | 9 ------- synapse/replication/slave/storage/push_rule.py | 5 ---- synapse/replication/slave/storage/pushers.py | 5 ---- synapse/replication/slave/storage/receipts.py | 5 ---- synapse/replication/slave/storage/room.py | 5 ---- synapse/replication/tcp/client.py | 19 +------------- synapse/replication/tcp/handler.py | 10 +------- tests/replication/tcp/streams/_base.py | 30 ++++++++--------------- tests/replication/tcp/streams/test_events.py | 24 ++++++++++++------ tests/replication/tcp/streams/test_receipts.py | 3 --- tests/replication/tcp/streams/test_typing.py | 3 --- 19 files changed, 30 insertions(+), 154 deletions(-) create mode 100644 changelog.d/7172.misc (limited to 'tests') diff --git a/changelog.d/7172.misc b/changelog.d/7172.misc new file mode 100644 index 0000000000..ffecdf97fe --- /dev/null +++ b/changelog.d/7172.misc @@ -0,0 +1 @@ +Use `stream.current_token()` and remove `stream_positions()`. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 0ace7b787d..97b9b81237 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -413,12 +413,6 @@ class GenericWorkerTyping(object): # map room IDs to sets of users currently typing self._room_typing = {} - def stream_positions(self): - # We must update this typing token from the response of the previous - # sync. In particular, the stream id may "reset" back to zero/a low - # value which we *must* use for the next replication request. - return {"typing": self._latest_room_serial} - def process_replication_rows(self, token, rows): if self._latest_room_serial > token: # The master has gone backwards. To prevent inconsistent data, just @@ -658,13 +652,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): ) await self.process_and_notify(stream_name, token, rows) - def get_streams_to_replicate(self): - args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate() - args.update(self.typing_handler.stream_positions()) - if self.send_handler: - args.update(self.send_handler.stream_positions()) - return args - async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: @@ -799,9 +786,6 @@ class FederationSenderHandler(object): def wake_destination(self, server: str): self.federation_sender.wake_destination(server) - def stream_positions(self): - return {"federation": self.federation_position} - async def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 751c799d94..5d7c8871a4 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Optional +from typing import Optional import six @@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): self.hs = hs - def stream_positions(self) -> Dict[str, int]: - """ - Get the current positions of all the streams this store wants to subscribe to - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - pos = {} - if self._cache_id_gen: - pos["caches"] = self._cache_id_gen.get_current_token() - return pos - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index ebe94909cb..65e54b1c71 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedAccountDataStore, self).stream_positions() - position = self._account_data_id_gen.get_current_token() - result["user_account_data"] = position - result["room_account_data"] = position - result["tag_account_data"] = position - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "tag_account_data": self._account_data_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 0c237c6e0f..c923751e50 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - def stream_positions(self): - result = super(SlavedDeviceInboxStore, self).stream_positions() - result["to_device"] = self._device_inbox_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "to_device": self._device_inbox_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 23b1650e41..58fb0eaae3 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) - def stream_positions(self): - result = super(SlavedDeviceStore, self).stream_positions() - # The user signature stream uses the same stream ID generator as the - # device list stream, so set them both to the device list ID - # generator's current token. - current_token = self._device_list_id_gen.get_current_token() - result[DeviceListsStream.NAME] = current_token - result[UserSignatureStream.NAME] = current_token - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index e73342c657..15011259df 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,12 +93,6 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedEventStore, self).stream_positions() - result["events"] = self._stream_id_gen.get_current_token() - result["backfill"] = -self._backfill_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "events": self._stream_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 2d4fd08cf5..01bcf0e882 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedGroupServerStore, self).stream_positions() - result["groups"] = self._group_updates_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index ad8f0c15a9..fae3125072 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPresenceStore, self).stream_positions() - - if self.hs.config.use_presence: - position = self._presence_id_gen.get_current_token() - result["presence"] = position - - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index eebd5a1fb6..6138796da4 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPushRuleStore, self).stream_positions() - result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "push_rules": self._push_rules_stream_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index bce8a3d115..67be337945 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) - def stream_positions(self): - result = super(SlavedPusherStore, self).stream_positions() - result["pushers"] = self._pushers_id_gen.get_current_token() - return result - def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index d40dc6e1f5..993432edcb 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedReceiptsStore, self).stream_positions() - result["receipts"] = self._receipts_id_gen.get_current_token() - return result - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate_many((room_id,)) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 3a20f45316..10dda8708f 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def stream_positions(self): - result = super(RoomStore, self).stream_positions() - result["public_rooms"] = self._public_room_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2d07b8b2d0..5c28fd4ac3 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,7 +16,7 @@ """ import logging -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from twisted.internet.protocol import ReconnectingClientFactory @@ -100,23 +100,6 @@ class ReplicationDataHandler: """ self.store.process_replication_rows(stream_name, token, rows) - def get_streams_to_replicate(self) -> Dict[str, int]: - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - args = self.store.stream_positions() - user_account_data = args.pop("user_account_data", None) - room_account_data = args.pop("room_account_data", None) - if user_account_data: - args["account_data"] = user_account_data - elif room_account_data: - args["account_data"] = room_account_data - return args - async def on_position(self, stream_name: str, token: int): self.store.process_replication_rows(stream_name, token, []) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6f7054d5af..d72f3d0cf9 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -314,15 +314,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(cmd.stream_name, []) # Find where we previously streamed up to. - current_token = self._replication_data_handler.get_streams_to_replicate().get( - cmd.stream_name - ) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", - cmd.stream_name, - ) - return + current_token = stream.current_token() # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 83e16cfe3d..8c104f8d1d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple import attr @@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel -from synapse.app.generic_worker import GenericWorkerServer +from synapse.app.generic_worker import ( + GenericWorkerReplicationHandler, + GenericWorkerServer, +) from synapse.http.site import SynapseRequest -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.server import HomeServer from synapse.util import Clock from tests import unittest @@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._server_transport = None def _build_replication_data_handler(self): - return TestReplicationDataHandler(self.worker_hs.get_datastore()) + return TestReplicationDataHandler(self.worker_hs) def reconnect(self): if self._client_transport: @@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") -class TestReplicationDataHandler(ReplicationDataHandler): +class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, store: BaseSlavedStore): - super().__init__(store) - - # streams to subscribe to: map from stream id to position - self.stream_positions = {} # type: Dict[str, int] + def __init__(self, hs: HomeServer): + super().__init__(hs) # list of received (stream_name, token, row) tuples self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] - def get_streams_to_replicate(self): - return self.stream_positions - async def on_rdata(self, stream_name, token, rows): await super().on_rdata(stream_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) - if ( - stream_name in self.stream_positions - and token > self.stream_positions[stream_name] - ): - self.stream_positions[stream_name] = token - @attr.s() class OneShotRequestFactory: diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 1fa28084f9..8bd67bb9f1 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase): self.user_tok = self.login("u1", "pass") self.reconnect() - self.test_handler.stream_positions["events"] = 0 self.room_id = self.helper.create_room_as(tok=self.user_tok) self.test_handler.received_rdata_rows.clear() @@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + for event in events: stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) @@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # now we should have received all the expected rows in the right order. + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) # # we expect: # @@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # of the states that got reverted. # - two rows for state2 - received_rows = self.test_handler.received_rdata_rows + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] # first check the first two rows, which should be state1 @@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] self.assertGreaterEqual(len(received_rows), len(events)) for i in range(NUM_USERS): # for each user, we expect the PL event row, followed by state rows for diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index c122b8589c..df332ee679 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.stream_positions.update({"receipts": 0}) - # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 4d354a9db8..e8d17ca68a 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the typing stream - self.test_handler.stream_positions.update({"typing": 0}) - typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) self.reactor.advance(0) -- cgit 1.5.1 From 0e719f23981b8294df66ba7f38b8c7cc99fad228 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 May 2020 17:19:56 +0100 Subject: Thread through instance name to replication client. (#7369) For in memory streams when fetching updates on workers we need to query the source of the stream, which currently is hard coded to be master. This PR threads through the source instance we received via `POSITION` through to the update function in each stream, which can then be passed to the replication client for in memory streams. --- changelog.d/7369.misc | 1 + synapse/app/generic_worker.py | 10 +++--- synapse/replication/http/_base.py | 19 +++++++++- synapse/replication/http/streams.py | 4 ++- synapse/replication/tcp/client.py | 12 ++++--- synapse/replication/tcp/handler.py | 20 ++++++++--- synapse/replication/tcp/streams/_base.py | 50 +++++++++++++++++++------- synapse/replication/tcp/streams/events.py | 10 ++++-- synapse/replication/tcp/streams/federation.py | 4 +-- tests/replication/tcp/streams/_base.py | 4 +-- tests/replication/tcp/streams/test_receipts.py | 4 +-- tests/replication/tcp/streams/test_typing.py | 4 +-- 12 files changed, 101 insertions(+), 41 deletions(-) create mode 100644 changelog.d/7369.misc (limited to 'tests') diff --git a/changelog.d/7369.misc b/changelog.d/7369.misc new file mode 100644 index 0000000000..060b09c888 --- /dev/null +++ b/changelog.d/7369.misc @@ -0,0 +1 @@ +Thread through instance name to replication client. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 97b9b81237..667ad20428 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -646,13 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): else: self.send_handler = None - async def on_rdata(self, stream_name, token, rows): - await super(GenericWorkerReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - await self.process_and_notify(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + await self._process_and_notify(stream_name, instance_name, token, rows) - async def process_and_notify(self, stream_name, token, rows): + async def _process_and_notify(self, stream_name, instance_name, token, rows): try: if self.send_handler: await self.send_handler.process_replication_rows( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 1be1ccbdf3..f88c80ae84 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,6 +16,7 @@ import abc import logging import re +from inspect import signature from typing import Dict, List, Tuple from six import raise_from @@ -60,6 +61,8 @@ class ReplicationEndpoint(object): must call `register` to register the path with the HTTP server. Requests can be sent by calling the client returned by `make_client`. + Requests are sent to master process by default, but can be sent to other + named processes by specifying an `instance_name` keyword argument. Attributes: NAME (str): A name for the endpoint, added to the path as well as used @@ -91,6 +94,16 @@ class ReplicationEndpoint(object): hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) + # We reserve `instance_name` as a parameter to sending requests, so we + # assert here that sub classes don't try and use the name. + assert ( + "instance_name" not in self.PATH_ARGS + ), "`instance_name` is a reserved paramater name" + assert ( + "instance_name" + not in signature(self.__class__._serialize_payload).parameters + ), "`instance_name` is a reserved paramater name" + assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod @@ -135,7 +148,11 @@ class ReplicationEndpoint(object): @trace(opname="outgoing_replication_request") @defer.inlineCallbacks - def send_request(**kwargs): + def send_request(instance_name="master", **kwargs): + # Currently we only support sending requests to master process. + if instance_name != "master": + raise Exception("Unknown instance") + data = yield cls._serialize_payload(**kwargs) url_args = [ diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index f35cebc710..0459f582bf 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) + self._instance_name = hs.get_instance_name() + # We pull the streams from the replication steamer (if we try and make # them ourselves we end up in an import loop). self.streams = hs.get_replication_streamer().get_streams() @@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): upto_token = parse_integer(request, "upto_token", required=True) updates, upto_token, limited = await stream.get_updates_since( - from_token, upto_token + self._instance_name, from_token, upto_token ) return ( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 5c28fd4ac3..3bbf3c3569 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -86,17 +86,19 @@ class ReplicationDataHandler: def __init__(self, store: BaseSlavedStore): self.store = store - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to handle more. Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. + stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d72f3d0cf9..2d1d119c7c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -278,19 +278,24 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. Args: stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - await self._replication_data_handler.on_rdata(stream_name, token, rows) + await self._replication_data_handler.on_rdata( + stream_name, instance_name, token, rows + ) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: @@ -325,7 +330,9 @@ class ReplicationCommandHandler: updates, current_token, missing_updates, - ) = await stream.get_updates_since(current_token, cmd.token) + ) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) # TODO: add some tests for this @@ -334,7 +341,10 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, token, [stream.parse_row(row) for row in rows], + cmd.stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4af1afd119..b0f87c365b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr @@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # # The arguments are: # +# * instance_name: the writer of the stream # * from_token: the previous stream token: the starting point for fetching the # updates # * to_token: the new stream token: the point to get updates up to @@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # If there are more updates available, it should set `limited` in the result, and # it will be called again to get the next batch. # -UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): @@ -93,6 +94,7 @@ class Stream(object): def __init__( self, + local_instance_name: str, current_token_function: Callable[[], Token], update_function: UpdateFunction, ): @@ -108,9 +110,11 @@ class Stream(object): stream tokens. See the UpdateFunction type definition for more info. Args: + local_instance_name: The instance name of the current process current_token_function: callback to get the current token, as above update_function: callback go get stream updates, as above """ + self.local_instance_name = local_instance_name self.current_token = current_token_function self.update_function = update_function @@ -135,14 +139,14 @@ class Stream(object): """ current_token = self.current_token() updates, current_token, limited = await self.get_updates_since( - self.last_token, current_token + self.local_instance_name, self.last_token, current_token ) self.last_token = current_token return updates, current_token, limited async def get_updates_since( - self, from_token: Token, upto_token: Token + self, instance_name: str, from_token: Token, upto_token: Token ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates @@ -160,19 +164,19 @@ class Stream(object): return [], upto_token, False updates, upto_token, limited = await self.update_function( - from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]] + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: """Wraps a db query function which returns a list of rows to make it suitable for use as an `update_function` for the Stream class """ - async def update_function(from_token, upto_token, limit): + async def update_function(instance_name, from_token, upto_token, limit): rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] limited = False @@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction: client = ReplicationGetStreamUpdates.make_client(hs) async def update_function( - from_token: int, upto_token: int, limit: int + instance_name: str, from_token: int, upto_token: int, limit: int ) -> StreamUpdateResult: result = await client( - stream_name=stream_name, from_token=from_token, upto_token=upto_token, + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, ) return result["updates"], result["upto_token"], result["limited"] @@ -226,6 +233,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_backfill_token, db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -261,7 +269,9 @@ class PresenceStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(store.get_current_presence_token, update_function) + super().__init__( + hs.get_instance_name(), store.get_current_presence_token, update_function + ) class TypingStream(Stream): @@ -284,7 +294,9 @@ class TypingStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(typing_handler.get_current_token, update_function) + super().__init__( + hs.get_instance_name(), typing_handler.get_current_token, update_function + ) class ReceiptsStream(Stream): @@ -305,6 +317,7 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_receipt_stream_id, db_query_to_update_function(store.get_all_updated_receipts), ) @@ -322,14 +335,16 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super(PushRulesStream, self).__init__( - self._current_token, self._update_function + hs.get_instance_name(), self._current_token, self._update_function ) def _current_token(self) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def _update_function(self, from_token: Token, to_token: Token, limit: int): + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) limited = False @@ -356,6 +371,7 @@ class PushersStream(Stream): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_pushers_stream_token, db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -387,6 +403,7 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_cache_stream_token, db_query_to_update_function(store.get_all_updated_caches), ) @@ -412,6 +429,7 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_public_room_stream_id, db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -432,6 +450,7 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -449,6 +468,7 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_to_device_stream_token, db_query_to_update_function(store.get_all_new_device_messages), ) @@ -468,6 +488,7 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_account_data_stream_id, db_query_to_update_function(store.get_all_updated_tags), ) @@ -487,6 +508,7 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super().__init__( + hs.get_instance_name(), self.store.get_max_account_data_stream_id, db_query_to_update_function(self._update_function), ) @@ -517,6 +539,7 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_group_stream_token, db_query_to_update_function(store.get_all_groups_changes), ) @@ -534,6 +557,7 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function( store.get_all_user_signature_changes_for_remotes diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 52df81b1bd..890e75d827 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -118,11 +118,17 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() super().__init__( - self._store.get_current_events_token, self._update_function, + hs.get_instance_name(), + self._store.get_current_events_token, + self._update_function, ) async def _update_function( - self, from_token: Token, current_token: Token, target_row_count: int + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, ) -> StreamUpdateResult: # the events stream merges together three separate sources: diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 75133d7e40..e8bd52e389 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -48,8 +48,8 @@ class FederationStream(Stream): current_token = lambda: 0 update_function = self._stub_update_function - super().__init__(current_token, update_function) + super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - async def _stub_update_function(from_token, upto_token, limit): + async def _stub_update_function(instance_name, from_token, upto_token, limit): return [], upto_token, False diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 8c104f8d1d..7b56d2028d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -183,8 +183,8 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler): # list of received (stream_name, token, row) tuples self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] - async def on_rdata(self, stream_name, token, rows): - await super().on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index df332ee679..5853314fd4 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -41,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # there should be one RDATA command self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow @@ -71,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # We should now have caught up and get the missing data self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index e8d17ca68a..d25a7b194e 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow @@ -74,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] -- cgit 1.5.1 From 032e5a2acaa26511a00801eedd71e35d21c791fa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 1 May 2020 15:28:59 -0400 Subject: Convert synapse.server_notices to async/await. (#7394) --- changelog.d/7394.misc | 1 + synapse/server_notices/resource_limits_server_notices.py | 16 ++++++---------- .../test_resource_limits_server_notices.py | 2 +- 3 files changed, 8 insertions(+), 11 deletions(-) create mode 100644 changelog.d/7394.misc (limited to 'tests') diff --git a/changelog.d/7394.misc b/changelog.d/7394.misc new file mode 100644 index 0000000000..f1390308b3 --- /dev/null +++ b/changelog.d/7394.misc @@ -0,0 +1 @@ +Convert synapse.server_notices to async/await. diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 971771b8b2..d97166351e 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -16,8 +16,6 @@ import logging from six import iteritems -from twisted.internet import defer - from synapse.api.constants import ( EventTypes, LimitBlockingTypes, @@ -167,8 +165,7 @@ class ResourceLimitsServerNotices(object): user_id, content, EventTypes.Pinned, "" ) - @defer.inlineCallbacks - def _check_and_set_tags(self, user_id, room_id): + async def _check_and_set_tags(self, user_id, room_id): """ Since server notices rooms were originally not with tags, important to check that tags have been set correctly @@ -176,20 +173,19 @@ class ResourceLimitsServerNotices(object): user_id(str): the user in question room_id(str): the server notices room for that user """ - tags = yield self._store.get_tags_for_room(user_id, room_id) + tags = await self._store.get_tags_for_room(user_id, room_id) need_to_set_tag = True if tags: if SERVER_NOTICE_ROOM_TAG in tags: # tag already present, nothing to do here need_to_set_tag = False if need_to_set_tag: - max_id = yield self._store.add_tag_to_room( + max_id = await self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) - @defer.inlineCallbacks - def _is_room_currently_blocked(self, room_id): + async def _is_room_currently_blocked(self, room_id): """ Determines if the room is currently blocked @@ -207,7 +203,7 @@ class ResourceLimitsServerNotices(object): currently_blocked = False pinned_state_event = None try: - pinned_state_event = yield self._state.get_current_state( + pinned_state_event = await self._state.get_current_state( room_id, event_type=EventTypes.Pinned ) except AuthError: @@ -218,7 +214,7 @@ class ResourceLimitsServerNotices(object): if pinned_state_event is not None: referenced_events = list(pinned_state_event.content.get("pinned", [])) - events = yield self._store.get_events(referenced_events) + events = await self._store.get_events(referenced_events) for event_id, event in iteritems(events): if event.type != EventTypes.Message: continue diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 987addad9b..406f29a7c0 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=defer.succeed("!something:localhost") ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_tags_for_room = Mock(return_value={}) + self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) self.hs.config.admin_contact = "mailto:user@test.com" def test_maybe_send_server_notice_to_user_flag_off(self): -- cgit 1.5.1 From 7941a70fa8b297b0dec320a9b7dda01df3efe1e4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 May 2020 14:17:27 +0100 Subject: Fix bug in EventContext.deserialize. (#7393) This caused `prev_state_ids` to be incorrect if the state event was not replacing an existing state entry. --- changelog.d/7393.bugfix | 1 + synapse/events/snapshot.py | 7 ++- tests/events/test_snapshot.py | 100 ++++++++++++++++++++++++++++++++++++ tests/test_utils/event_injection.py | 26 +++++++--- 4 files changed, 126 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7393.bugfix create mode 100644 tests/events/test_snapshot.py (limited to 'tests') diff --git a/changelog.d/7393.bugfix b/changelog.d/7393.bugfix new file mode 100644 index 0000000000..74419af858 --- /dev/null +++ b/changelog.d/7393.bugfix @@ -0,0 +1 @@ +Fix bug in `EventContext.deserialize`. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 9ea85e93e6..7c5f620d09 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -322,11 +322,14 @@ class _AsyncEventContextImpl(EventContext): self._current_state_ids = yield self._storage.state.get_state_ids_for_group( self.state_group ) - if self._prev_state_id and self._event_state_key is not None: + if self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) key = (self._event_type, self._event_state_key) - self._prev_state_ids[key] = self._prev_state_id + if self._prev_state_id: + self._prev_state_ids[key] = self._prev_state_id + else: + self._prev_state_ids.pop(key, None) else: self._prev_state_ids = self._current_state_ids diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py new file mode 100644 index 0000000000..640f5f3bce --- /dev/null +++ b/tests/events/test_snapshot.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.events.snapshot import EventContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest +from tests.test_utils.event_injection import create_event + + +class TestEventContext(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.storage = hs.get_storage() + + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + self.room_id = self.helper.create_room_as(tok=self.user_tok) + + def test_serialize_deserialize_msg(self): + """Test that an EventContext for a message event is the same after + serialize/deserialize. + """ + + event, context = create_event( + self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + ) + + self._check_serialize_deserialize(event, context) + + def test_serialize_deserialize_state_no_prev(self): + """Test that an EventContext for a state event (with not previous entry) + is the same after serialize/deserialize. + """ + event, context = create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + state_key="", + ) + + self._check_serialize_deserialize(event, context) + + def test_serialize_deserialize_state_prev(self): + """Test that an EventContext for a state event (which replaces a + previous entry) is the same after serialize/deserialize. + """ + event, context = create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.user_id, + state_key=self.user_id, + content={"membership": "leave"}, + ) + + self._check_serialize_deserialize(event, context) + + def _check_serialize_deserialize(self, event, context): + serialized = self.get_success(context.serialize(event, self.store)) + + d_context = EventContext.deserialize(self.storage, serialized) + + self.assertEqual(context.state_group, d_context.state_group) + self.assertEqual(context.rejected, d_context.rejected) + self.assertEqual( + context.state_group_before_event, d_context.state_group_before_event + ) + self.assertEqual(context.prev_group, d_context.prev_group) + self.assertEqual(context.delta_ids, d_context.delta_ids) + self.assertEqual(context.app_service, d_context.app_service) + + self.assertEqual( + self.get_success(context.get_current_state_ids()), + self.get_success(d_context.get_current_state_ids()), + ) + self.assertEqual( + self.get_success(context.get_prev_state_ids()), + self.get_success(d_context.get_prev_state_ids()), + ) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8f6872761a..431e9f8e5e 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import synapse.server from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.types import Collection from tests.test_utils import get_awaitable_result @@ -75,6 +76,23 @@ def inject_event( """ test_reactor = hs.get_reactor() + event, context = create_event(hs, room_version, prev_event_ids, **kwargs) + + d = hs.get_storage().persistence.persist_event(event, context) + test_reactor.advance(0) + get_awaitable_result(d) + + return event + + +def create_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> Tuple[EventBase, EventContext]: + test_reactor = hs.get_reactor() + if room_version is None: d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) test_reactor.advance(0) @@ -89,8 +107,4 @@ def inject_event( test_reactor.advance(0) event, context = get_awaitable_result(d) - d = hs.get_storage().persistence.persist_event(event, context) - test_reactor.advance(0) - get_awaitable_result(d) - - return event + return event, context -- cgit 1.5.1 From aee9130a83eec7a295071a7e3ea87ce380c4521c Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 6 May 2020 15:54:58 +0100 Subject: Stop Auth methods from polling the config on every req. (#7420) --- changelog.d/7420.misc | 1 + synapse/api/auth.py | 83 +++++----------------------------- synapse/api/auth_blocking.py | 104 +++++++++++++++++++++++++++++++++++++++++++ tests/api/test_auth.py | 36 ++++++++------- tests/handlers/test_auth.py | 23 ++++++---- tests/handlers/test_sync.py | 13 +++--- tests/test_mau.py | 14 ++++-- 7 files changed, 168 insertions(+), 106 deletions(-) create mode 100644 changelog.d/7420.misc create mode 100644 synapse/api/auth_blocking.py (limited to 'tests') diff --git a/changelog.d/7420.misc b/changelog.d/7420.misc new file mode 100644 index 0000000000..e834a9163e --- /dev/null +++ b/changelog.d/7420.misc @@ -0,0 +1 @@ +Prevent methods in `synapse.handlers.auth` from polling the homeserver config every request. \ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index c5d1eb952b..1ad5ff9410 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -26,16 +26,15 @@ from twisted.internet import defer import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes +from synapse.api.auth_blocking import AuthBlocking +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, InvalidClientTokenError, MissingClientTokenError, - ResourceLimitError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.config.server import is_threepid_reserved from synapse.events import EventBase from synapse.types import StateMap, UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache @@ -77,7 +76,11 @@ class Auth(object): self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) register_cache("cache", "token_cache", self.token_cache) + self._auth_blocking = AuthBlocking(self.hs) + self._account_validity = hs.config.account_validity + self._track_appservice_user_ips = hs.config.track_appservice_user_ips + self._macaroon_secret_key = hs.config.macaroon_secret_key @defer.inlineCallbacks def check_from_context(self, room_version: str, event, context, do_sig_check=True): @@ -191,7 +194,7 @@ class Auth(object): opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("appservice_id", app_service.id) - if ip_addr and self.hs.config.track_appservice_user_ips: + if ip_addr and self._track_appservice_user_ips: yield self.store.insert_client_ip( user_id=user_id, access_token=access_token, @@ -454,7 +457,7 @@ class Auth(object): # access_tokens include a nonce for uniqueness: any value is acceptable v.satisfy_general(lambda c: c.startswith("nonce = ")) - v.verify(macaroon, self.hs.config.macaroon_secret_key) + v.verify(macaroon, self._macaroon_secret_key) def _verify_expiry(self, caveat): prefix = "time < " @@ -663,71 +666,5 @@ class Auth(object): % (user_id, room_id), ) - @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): - """Checks if the user should be rejected for some external reason, - such as monthly active user limiting or global disable flag - - Args: - user_id(str|None): If present, checks for presence against existing - MAU cohort - - threepid(dict|None): If present, checks for presence against configured - reserved threepid. Used in cases where the user is trying register - with a MAU blocked server, normally they would be rejected but their - threepid is on the reserved list. user_id and - threepid should never be set at the same time. - - user_type(str|None): If present, is used to decide whether to check against - certain blocking reasons like MAU. - """ - - # Never fail an auth check for the server notices users or support user - # This can be a problem where event creation is prohibited due to blocking - if user_id is not None: - if user_id == self.hs.config.server_notices_mxid: - return - if (yield self.store.is_support_user(user_id)): - return - - if self.hs.config.hs_disabled: - raise ResourceLimitError( - 403, - self.hs.config.hs_disabled_message, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - admin_contact=self.hs.config.admin_contact, - limit_type=LimitBlockingTypes.HS_DISABLED, - ) - if self.hs.config.limit_usage_by_mau is True: - assert not (user_id and threepid) - - # If the user is already part of the MAU cohort or a trial user - if user_id: - timestamp = yield self.store.user_last_seen_monthly_active(user_id) - if timestamp: - return - - is_trial = yield self.store.is_trial_user(user_id) - if is_trial: - return - elif threepid: - # If the user does not exist yet, but is signing up with a - # reserved threepid then pass auth check - if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid - ): - return - elif user_type == UserTypes.SUPPORT: - # If the user does not exist yet and is of type "support", - # allow registration. Support users are excluded from MAU checks. - return - # Else if there is no room in the MAU bucket, bail - current_mau = yield self.store.get_monthly_active_count() - if current_mau >= self.hs.config.max_mau_value: - raise ResourceLimitError( - 403, - "Monthly Active User Limit Exceeded", - admin_contact=self.hs.config.admin_contact, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, - ) + def check_auth_blocking(self, *args, **kwargs): + return self._auth_blocking.check_auth_blocking(*args, **kwargs) diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py new file mode 100644 index 0000000000..5c499b6b4e --- /dev/null +++ b/synapse/api/auth_blocking.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import LimitBlockingTypes, UserTypes +from synapse.api.errors import Codes, ResourceLimitError +from synapse.config.server import is_threepid_reserved + +logger = logging.getLogger(__name__) + + +class AuthBlocking(object): + def __init__(self, hs): + self.store = hs.get_datastore() + + self._server_notices_mxid = hs.config.server_notices_mxid + self._hs_disabled = hs.config.hs_disabled + self._hs_disabled_message = hs.config.hs_disabled_message + self._admin_contact = hs.config.admin_contact + self._max_mau_value = hs.config.max_mau_value + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids + + @defer.inlineCallbacks + def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + """Checks if the user should be rejected for some external reason, + such as monthly active user limiting or global disable flag + + Args: + user_id(str|None): If present, checks for presence against existing + MAU cohort + + threepid(dict|None): If present, checks for presence against configured + reserved threepid. Used in cases where the user is trying register + with a MAU blocked server, normally they would be rejected but their + threepid is on the reserved list. user_id and + threepid should never be set at the same time. + + user_type(str|None): If present, is used to decide whether to check against + certain blocking reasons like MAU. + """ + + # Never fail an auth check for the server notices users or support user + # This can be a problem where event creation is prohibited due to blocking + if user_id is not None: + if user_id == self._server_notices_mxid: + return + if (yield self.store.is_support_user(user_id)): + return + + if self._hs_disabled: + raise ResourceLimitError( + 403, + self._hs_disabled_message, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact=self._admin_contact, + limit_type=LimitBlockingTypes.HS_DISABLED, + ) + if self._limit_usage_by_mau is True: + assert not (user_id and threepid) + + # If the user is already part of the MAU cohort or a trial user + if user_id: + timestamp = yield self.store.user_last_seen_monthly_active(user_id) + if timestamp: + return + + is_trial = yield self.store.is_trial_user(user_id) + if is_trial: + return + elif threepid: + # If the user does not exist yet, but is signing up with a + # reserved threepid then pass auth check + if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid): + return + elif user_type == UserTypes.SUPPORT: + # If the user does not exist yet and is of type "support", + # allow registration. Support users are excluded from MAU checks. + return + # Else if there is no room in the MAU bucket, bail + current_mau = yield self.store.get_monthly_active_count() + if current_mau >= self._max_mau_value: + raise ResourceLimitError( + 403, + "Monthly Active User Limit Exceeded", + admin_contact=self._admin_contact, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, + ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index cc0b10e7f6..0bfb86bf1f 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.auth._auth_blocking + self.test_user = "@foo:bar" self.test_token = b"_test_token_" @@ -321,15 +325,15 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_blocking_mau(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = False + self.auth_blocking._max_mau_value = 50 lots_of_users = 100 small_number_of_users = 1 # Ensure no error thrown yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(lots_of_users) @@ -349,8 +353,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_blocking_mau__depending_on_user_type(self): - self.hs.config.max_mau_value = 50 - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Support users allowed @@ -370,12 +374,12 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_reserved_threepid(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 1 + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 1 self.store.get_monthly_active_count = lambda: defer.succeed(2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} - self.hs.config.mau_limits_reserved_threepids = [threepid] + self.auth_blocking._mau_limits_reserved_threepids = [threepid] with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred(self.auth.check_auth_blocking()) @@ -389,8 +393,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_hs_disabled(self): - self.hs.config.hs_disabled = True - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._hs_disabled = True + self.auth_blocking._hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) @@ -404,10 +408,10 @@ class AuthTestCase(unittest.TestCase): """ # this should be the default, but we had a bug where the test was doing the wrong # thing, so let's make it explicit - self.hs.config.server_notices_mxid = None + self.auth_blocking._server_notices_mxid = None - self.hs.config.hs_disabled = True - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._hs_disabled = True + self.auth_blocking._hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) @@ -416,8 +420,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_server_notices_mxid_special_cased(self): - self.hs.config.hs_disabled = True + self.auth_blocking._hs_disabled = True user = "@user:server" - self.hs.config.server_notices_mxid = user - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._server_notices_mxid = user + self.auth_blocking._hs_disabled_message = "Reason for being disabled" yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 52c4ac8b11..c01b04e1dc 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() + # MAU tests - self.hs.config.max_mau_value = 50 + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking + self.auth_blocking._max_mau_value = 50 + self.small_number_of_users = 1 self.large_number_of_users = 100 @@ -119,7 +124,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_disabled(self): - self.hs.config.limit_usage_by_mau = False + self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -135,7 +140,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_exceeded_large(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) @@ -159,11 +164,11 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_parity(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -173,7 +178,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -186,7 +191,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -197,7 +202,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( @@ -207,7 +212,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 4cbe9784ed..e178d7765b 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() - def test_wait_for_sync_for_user_auth_blocking(self): + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking + def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:test" user_id2 = "@user2:test" sync_config = self._generate_sync_config(user_id1) self.reactor.advance(100) # So we get not 0 time - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 1 + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 1 # Check that the happy case does not throw errors self.get_success(self.store.upsert_monthly_active_user(user_id1)) self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) # Test that global lock works - self.hs.config.hs_disabled = True + self.auth_blocking._hs_disabled = True e = self.get_failure( self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.hs.config.hs_disabled = False + self.auth_blocking._hs_disabled = False sync_config = self._generate_sync_config(user_id2) diff --git a/tests/test_mau.py b/tests/test_mau.py index 1fbe0d51ff..eb159e3ba5 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -19,6 +19,7 @@ import json from mock import Mock +from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.rest.client.v2_alpha import register, sync @@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.hs_disabled = False self.hs.config.max_mau_value = 2 - self.hs.config.mau_trial_days = 0 self.hs.config.server_notices_mxid = "@server:red" self.hs.config.server_notices_mxid_display_name = None self.hs.config.server_notices_mxid_avatar_url = None self.hs.config.server_notices_room_name = "Test Server Notice Room" + self.hs.config.mau_trial_days = 0 + + # AuthBlocking reads config options during hs creation. Recreate the + # hs' copy of AuthBlocking after we've updated config values above + self.auth_blocking = AuthBlocking(self.hs) + self.hs.get_auth()._auth_blocking = self.auth_blocking + return self.hs def test_simple_deny_mau(self): @@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_trial_users_cant_come_back(self): + self.auth_blocking._mau_trial_days = 1 self.hs.config.mau_trial_days = 1 # We should be able to register more than the limit initially @@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_tracked_but_not_limited(self): - self.hs.config.max_mau_value = 1 # should not matter - self.hs.config.limit_usage_by_mau = False + self.auth_blocking._max_mau_value = 1 # should not matter + self.auth_blocking._limit_usage_by_mau = False self.hs.config.mau_stats_only = True # Simply being able to create 2 users indicates that the -- cgit 1.5.1 From 0ad6d28b0dec06d5e7478984280b4e81ef0f0256 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 8 May 2020 16:08:58 -0400 Subject: Rework UI Auth session validation for registration (#7455) Be less strict about validation of UI authentication sessions during registration to match client expecations. --- changelog.d/7455.bugfix | 1 + synapse/handlers/auth.py | 54 +++-- synapse/rest/client/v2_alpha/register.py | 1 + synapse/storage/data_stores/main/ui_auth.py | 21 ++ tests/rest/client/v2_alpha/test_auth.py | 304 ++++++++++++++++++++-------- tox.ini | 1 + 6 files changed, 280 insertions(+), 102 deletions(-) create mode 100644 changelog.d/7455.bugfix (limited to 'tests') diff --git a/changelog.d/7455.bugfix b/changelog.d/7455.bugfix new file mode 100644 index 0000000000..d1693a7f22 --- /dev/null +++ b/changelog.d/7455.bugfix @@ -0,0 +1 @@ +Ensure that a user inteactive authentication session is tied to a single request. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7613e5b6ab..9c71702371 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -252,6 +252,7 @@ class AuthHandler(BaseHandler): clientdict: Dict[str, Any], clientip: str, description: str, + validate_clientdict: bool = True, ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration @@ -277,6 +278,10 @@ class AuthHandler(BaseHandler): description: A human readable string to be displayed to the user that describes the operation happening on their account. + validate_clientdict: Whether to validate that the operation happening + on the account has not changed. If this is false, + the client dict is persisted instead of validated. + Returns: A tuple of (creds, params, session_id). @@ -317,30 +322,51 @@ class AuthHandler(BaseHandler): except StoreError: raise SynapseError(400, "Unknown session ID: %s" % (sid,)) + # If the client provides parameters, update what is persisted, + # otherwise use whatever was last provided. + # + # This was designed to allow the client to omit the parameters + # and just supply the session in subsequent calls so it split + # auth between devices by just sharing the session, (eg. so you + # could continue registration from your phone having clicked the + # email auth link on there). It's probably too open to abuse + # because it lets unauthenticated clients store arbitrary objects + # on a homeserver. + # + # Revisit: Assuming the REST APIs do sensible validation, the data + # isn't arbitrary. + # + # Note that the registration endpoint explicitly removes the + # "initial_device_display_name" parameter if it is provided + # without a "password" parameter. See the changes to + # synapse.rest.client.v2_alpha.register.RegisterRestServlet.on_POST + # in commit 544722bad23fc31056b9240189c3cbbbf0ffd3f9. if not clientdict: - # This was designed to allow the client to omit the parameters - # and just supply the session in subsequent calls so it split - # auth between devices by just sharing the session, (eg. so you - # could continue registration from your phone having clicked the - # email auth link on there). It's probably too open to abuse - # because it lets unauthenticated clients store arbitrary objects - # on a homeserver. - # Revisit: Assuming the REST APIs do sensible validation, the data - # isn't arbitrary. clientdict = session.clientdict # Ensure that the queried operation does not vary between stages of # the UI authentication session. This is done by generating a stable - # comparator based on the URI, method, and body (minus the auth dict) - # and storing it during the initial query. Subsequent queries ensure - # that this comparator has not changed. - comparator = (uri, method, clientdict) - if (session.uri, session.method, session.clientdict) != comparator: + # comparator based on the URI, method, and client dict (minus the + # auth dict) and storing it during the initial query. Subsequent + # queries ensure that this comparator has not changed. + if validate_clientdict: + session_comparator = (session.uri, session.method, session.clientdict) + comparator = (uri, method, clientdict) + else: + session_comparator = (session.uri, session.method) # type: ignore + comparator = (uri, method) # type: ignore + + if session_comparator != comparator: raise SynapseError( 403, "Requested operation has changed during the UI authentication session.", ) + # For backwards compatibility the registration endpoint persists + # changes to the client dict instead of validating them. + if not validate_clientdict: + await self.store.set_ui_auth_clientdict(sid, clientdict) + if not authdict: raise InteractiveAuthIncompleteError( self._auth_dict_for_flows(flows, session.session_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index af08cc6cce..e77dd6bf92 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -516,6 +516,7 @@ class RegisterRestServlet(RestServlet): body, self.hs.get_ip_from_request(request), "register a new account", + validate_clientdict=False, ) # Check that we're not trying to register a denied 3pid. diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py index c8eebc9378..1d8ee22fb1 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -172,6 +172,27 @@ class UIAuthWorkerStore(SQLBaseStore): return results + async def set_ui_auth_clientdict( + self, session_id: str, clientdict: JsonDict + ) -> None: + """ + Store an updated clientdict for a given session ID. + + Args: + session_id: The ID of this session as returned from check_auth + clientdict: + The dictionary from the client root level, not the 'auth' key. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + self.db.simple_update_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"clientdict": clientdict_json}, + desc="set_ui_auth_client_dict", + ) + async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): """ Store a key-value pair into the sessions data associated with this diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 587be7b2e7..a56c50a5b7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -12,16 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import List, Union from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client.v2_alpha import auth, register +from synapse.http.site import SynapseRequest +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import auth, devices, register +from synapse.types import JsonDict from tests import unittest +from tests.server import FakeChannel class DummyRecaptchaChecker(UserInteractiveAuthChecker): @@ -34,11 +38,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): return succeed(True) +class DummyPasswordChecker(UserInteractiveAuthChecker): + def check_auth(self, authdict, clientip): + return succeed(authdict["identifier"]["user"]) + + class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ auth.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False @@ -59,79 +67,84 @@ class FallbackAuthTests(unittest.HomeserverTestCase): auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - @unittest.INFO - def test_fallback_captcha(self): - + def register(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Make a register request.""" request, channel = self.make_request( - "POST", - "register", - {"username": "user", "type": "m.login.password", "password": "bar"}, - ) + "POST", "register", body + ) # type: SynapseRequest, FakeChannel self.render(request) - # Returns a 401 as per the spec - self.assertEqual(request.code, 401) - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) + self.assertEqual(request.code, expected_response) + return channel + + def recaptcha( + self, session: str, expected_post_response: int, post_session: str = None + ) -> None: + """Get and respond to a fallback recaptcha. Returns the second request.""" + if post_session is None: + post_session = session request, channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) + ) # type: SynapseRequest, FakeChannel self.render(request) self.assertEqual(request.code, 200) request, channel = self.make_request( "POST", "auth/m.login.recaptcha/fallback/web?session=" - + session + + post_session + "&g-recaptcha-response=a", ) self.render(request) - self.assertEqual(request.code, 200) + self.assertEqual(request.code, expected_post_response) # The recaptcha handler is called with the response given attempts = self.recaptcha_checker.recaptcha_attempts self.assertEqual(len(attempts), 1) self.assertEqual(attempts[0][0]["response"], "a") - # also complete the dummy auth - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + @unittest.INFO + def test_fallback_captcha(self): + """Ensure that fallback auth via a captcha works.""" + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"}, ) - self.render(request) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Complete the recaptcha step. + self.recaptcha(session, 200) + + # also complete the dummy auth + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session}} - ) - self.render(request) - self.assertEqual(channel.code, 200) + channel = self.register(200, {"auth": {"session": session}}) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_cannot_change_operation(self): + def test_legacy_registration(self): """ - The initial requested operation cannot be modified during the user interactive authentication session. + Registration allows the parameters to vary through the process. """ # Make the initial request to register. (Later on a different password # will be used.) - request, channel = self.make_request( - "POST", - "register", - {"username": "user", "type": "m.login.password", "password": "bar"}, + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"}, ) - self.render(request) - # Returns a 401 as per the spec - self.assertEqual(request.code, 401) # Grab the session session = channel.json_body["session"] # Assert our configured public key is being given @@ -139,65 +152,39 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" ) - request, channel = self.make_request( - "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) - self.render(request) - self.assertEqual(request.code, 200) - - request, channel = self.make_request( - "POST", - "auth/m.login.recaptcha/fallback/web?session=" - + session - + "&g-recaptcha-response=a", - ) - self.render(request) - self.assertEqual(request.code, 200) - - # The recaptcha handler is called with the response given - attempts = self.recaptcha_checker.recaptcha_attempts - self.assertEqual(len(attempts), 1) - self.assertEqual(attempts[0][0]["response"], "a") + # Complete the recaptcha step. + self.recaptcha(session, 200) # also complete the dummy auth - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} - ) - self.render(request) + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step. Make the initial request again, but - # with a different password. This causes the request to fail since the - # operaiton was modified during the ui auth session. - request, channel = self.make_request( - "POST", - "register", + # with a changed password. This still completes. + channel = self.register( + 200, { "username": "user", "type": "m.login.password", - "password": "foo", # Note this doesn't match the original request. + "password": "foo", # Note that this is different. "auth": {"session": session}, }, ) - self.render(request) - self.assertEqual(channel.code, 403) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") def test_complete_operation_unknown_session(self): """ Attempting to mark an invalid session as complete should error. """ - # Make the initial request to register. (Later on a different password # will be used.) - request, channel = self.make_request( - "POST", - "register", - {"username": "user", "type": "m.login.password", "password": "bar"}, + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"} ) - self.render(request) - # Returns a 401 as per the spec - self.assertEqual(request.code, 401) # Grab the session session = channel.json_body["session"] # Assert our configured public key is being given @@ -205,19 +192,160 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" ) + # Attempt to complete the recaptcha step with an unknown session. + # This results in an error. + self.recaptcha(session, 400, session + "unknown") + + +class UIAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + devices.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs) + + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + self.user_tok = self.login("test", self.user_pass) + + def get_device_ids(self) -> List[str]: + # Get the list of devices so one can be deleted. request, channel = self.make_request( - "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) + "GET", "devices", access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel self.render(request) + + # Get the ID of the device. self.assertEqual(request.code, 200) + return [d["device_id"] for d in channel.json_body["devices"]] - # Attempt to complete an unknown session, which should return an error. - unknown_session = session + "unknown" + def delete_device( + self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" + ) -> FakeChannel: + """Delete an individual device.""" request, channel = self.make_request( - "POST", - "auth/m.login.recaptcha/fallback/web?session=" - + unknown_session - + "&g-recaptcha-response=a", - ) + "DELETE", "devices/" + device, body, access_token=self.user_tok + ) # type: SynapseRequest, FakeChannel self.render(request) - self.assertEqual(request.code, 400) + + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) + + return channel + + def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Delete 1 or more devices.""" + # Note that this uses the delete_devices endpoint so that we can modify + # the payload half-way through some tests. + request, channel = self.make_request( + "POST", "delete_devices", body, access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel + self.render(request) + + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) + + return channel + + def test_ui_auth(self): + """ + Test user interactive authentication outside of registration. + """ + device_id = self.get_device_ids()[0] + + # Attempt to delete this device. + # Returns a 401 as per the spec + channel = self.delete_device(device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + device_id, + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_body(self): + """ + The initial requested client dict cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_devices(401, {"devices": [device_ids[0]]}) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + self.delete_devices( + 403, + { + "devices": [device_ids[1]], + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_uri(self): + """ + The initial requested URI cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_device(device_ids[0], 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + self.delete_device( + device_ids[1], + 403, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) diff --git a/tox.ini b/tox.ini index eccc44e436..8aef52021d 100644 --- a/tox.ini +++ b/tox.ini @@ -207,6 +207,7 @@ commands = mypy \ synapse/util/caches/stream_change_cache.py \ tests/replication/tcp/streams \ tests/test_utils \ + tests/rest/client/v2_alpha/test_auth.py \ tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: -- cgit 1.5.1 From edd3b0747cc651d224fc1bf81ae7fbd6a25a2ea5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 May 2020 08:24:50 -0400 Subject: Fix new flake8 errors (#7489) This is a cherry-pick of 1a1da60ad2c9172fe487cd38a164b39df60f4cb5 (#7470) to the release-v1.13.0 branch. --- changelog.d/7470.misc | 1 + synapse/app/_base.py | 5 +++-- synapse/config/server.py | 2 +- synapse/notifier.py | 10 ++++++---- synapse/push/mailer.py | 7 +++++-- synapse/storage/database.py | 4 ++-- tests/config/test_load.py | 2 +- 7 files changed, 19 insertions(+), 12 deletions(-) create mode 100644 changelog.d/7470.misc (limited to 'tests') diff --git a/changelog.d/7470.misc b/changelog.d/7470.misc new file mode 100644 index 0000000000..45e66ecf48 --- /dev/null +++ b/changelog.d/7470.misc @@ -0,0 +1 @@ +Fix linting errors in new version of Flake8. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 628292b890..dedff81af3 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -22,6 +22,7 @@ import sys import traceback from daemonize import Daemonize +from typing_extensions import NoReturn from twisted.internet import defer, error, reactor from twisted.protocols.tls import TLSMemoryBIOFactory @@ -139,9 +140,9 @@ def start_reactor( run() -def quit_with_error(error_string): +def quit_with_error(error_string: str) -> NoReturn: message_lines = error_string.split("\n") - line_length = max(len(l) for l in message_lines if len(l) < 80) + 2 + line_length = max(len(line) for line in message_lines if len(line) < 80) + 2 sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) diff --git a/synapse/config/server.py b/synapse/config/server.py index 6d88231843..ed28da3deb 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -522,7 +522,7 @@ class ServerConfig(Config): ) def has_tls_listener(self) -> bool: - return any(l["tls"] for l in self.listeners) + return any(listener["tls"] for listener in self.listeners) def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs diff --git a/synapse/notifier.py b/synapse/notifier.py index 71d9ed62b0..87c120a59c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from typing import Callable, List +from typing import Callable, Iterable, List, TypeVar from prometheus_client import Counter @@ -42,12 +42,14 @@ users_woken_by_stream_counter = Counter( "synapse_notifier_users_woken_by_stream", "", ["stream"] ) +T = TypeVar("T") + # TODO(paul): Should be shared somewhere -def count(func, l): - """Return the number of items in l for which func returns true.""" +def count(func: Callable[[T], bool], it: Iterable[T]) -> int: + """Return the number of items in it for which func returns true.""" n = 0 - for x in l: + for x in it: if func(x): n += 1 return n diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 73580c1c6c..ab33abbeed 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -19,6 +19,7 @@ import logging import time from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from typing import Iterable, List, TypeVar from six.moves import urllib @@ -41,6 +42,8 @@ from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) +T = TypeVar("T") + MESSAGE_FROM_PERSON_IN_ROOM = ( "You have a message on %(app)s from %(person)s in the %(room)s room..." @@ -638,10 +641,10 @@ def safe_text(raw_text): ) -def deduped_ordered_list(l): +def deduped_ordered_list(it: Iterable[T]) -> List[T]: seen = set() ret = [] - for item in l: + for item in it: if item not in seen: seen.add(item) ret.append(item) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a7cd97b0b0..50f475bfd3 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -212,9 +212,9 @@ class LoggingTransaction: def executemany(self, sql: str, *args: Any): self._do_execute(self.txn.executemany, sql, *args) - def _make_sql_one_line(self, sql): + def _make_sql_one_line(self, sql: str) -> str: "Strip newlines out of SQL so that the loggers in the DB are on one line" - return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + return " ".join(line.strip() for line in sql.splitlines() if line.strip()) def _do_execute(self, func, sql, *args): sql = self._make_sql_one_line(sql) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index b3e557bd6a..734a9983e8 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -122,7 +122,7 @@ class ConfigLoadingTestCase(unittest.TestCase): with open(self.file, "r") as f: contents = f.readlines() - contents = [l for l in contents if needle not in l] + contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: f.write("".join(contents)) -- cgit 1.5.1 From 5d64fefd6c7790dac0209c6c32cdb97cd6cd8820 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 May 2020 14:26:44 -0400 Subject: Do not validate that the client dict is stable during UI Auth. (#7483) This backs out some of the validation for the client dictionary and logs if this changes during a user interactive authentication session instead. --- changelog.d/7483.bugfix | 1 + synapse/handlers/auth.py | 37 +++++++++++---------- synapse/rest/client/v2_alpha/register.py | 1 - tests/rest/client/v2_alpha/test_auth.py | 55 ++++++-------------------------- 4 files changed, 29 insertions(+), 65 deletions(-) create mode 100644 changelog.d/7483.bugfix (limited to 'tests') diff --git a/changelog.d/7483.bugfix b/changelog.d/7483.bugfix new file mode 100644 index 0000000000..e1bc324617 --- /dev/null +++ b/changelog.d/7483.bugfix @@ -0,0 +1 @@ +Restore compatibility with non-compliant clients during the user interactive authentication process. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 9c71702371..5c20e29171 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -252,7 +252,6 @@ class AuthHandler(BaseHandler): clientdict: Dict[str, Any], clientip: str, description: str, - validate_clientdict: bool = True, ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration @@ -278,10 +277,6 @@ class AuthHandler(BaseHandler): description: A human readable string to be displayed to the user that describes the operation happening on their account. - validate_clientdict: Whether to validate that the operation happening - on the account has not changed. If this is false, - the client dict is persisted instead of validated. - Returns: A tuple of (creds, params, session_id). @@ -346,26 +341,30 @@ class AuthHandler(BaseHandler): # Ensure that the queried operation does not vary between stages of # the UI authentication session. This is done by generating a stable - # comparator based on the URI, method, and client dict (minus the - # auth dict) and storing it during the initial query. Subsequent + # comparator and storing it during the initial query. Subsequent # queries ensure that this comparator has not changed. - if validate_clientdict: - session_comparator = (session.uri, session.method, session.clientdict) - comparator = (uri, method, clientdict) - else: - session_comparator = (session.uri, session.method) # type: ignore - comparator = (uri, method) # type: ignore - - if session_comparator != comparator: + # + # The comparator is based on the requested URI and HTTP method. The + # client dict (minus the auth dict) should also be checked, but some + # clients are not spec compliant, just warn for now if the client + # dict changes. + if (session.uri, session.method) != (uri, method): raise SynapseError( 403, "Requested operation has changed during the UI authentication session.", ) - # For backwards compatibility the registration endpoint persists - # changes to the client dict instead of validating them. - if not validate_clientdict: - await self.store.set_ui_auth_clientdict(sid, clientdict) + if session.clientdict != clientdict: + logger.warning( + "Requested operation has changed during the UI " + "authentication session. A future version of Synapse " + "will remove this capability." + ) + + # For backwards compatibility, changes to the client dict are + # persisted as clients modify them throughout their user interactive + # authentication flow. + await self.store.set_ui_auth_clientdict(sid, clientdict) if not authdict: raise InteractiveAuthIncompleteError( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e77dd6bf92..af08cc6cce 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -516,7 +516,6 @@ class RegisterRestServlet(RestServlet): body, self.hs.get_ip_from_request(request), "register a new account", - validate_clientdict=False, ) # Check that we're not trying to register a denied 3pid. diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index a56c50a5b7..293ccfba2b 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -133,47 +133,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_legacy_registration(self): - """ - Registration allows the parameters to vary through the process. - """ - - # Make the initial request to register. (Later on a different password - # will be used.) - # Returns a 401 as per the spec - channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"}, - ) - - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) - - # Complete the recaptcha step. - self.recaptcha(session, 200) - - # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) - - # Now we should have fulfilled a complete auth flow, including - # the recaptcha fallback step. Make the initial request again, but - # with a changed password. This still completes. - channel = self.register( - 200, - { - "username": "user", - "type": "m.login.password", - "password": "foo", # Note that this is different. - "auth": {"session": session}, - }, - ) - - # We're given a registered user. - self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_complete_operation_unknown_session(self): """ Attempting to mark an invalid session as complete should error. @@ -282,9 +241,15 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_cannot_change_body(self): + def test_can_change_body(self): """ - The initial requested client dict cannot be modified during the user interactive authentication session. + The client dict can be modified during the user interactive authentication session. + + Note that it is not spec compliant to modify the client dict during a + user interactive authentication session, but many clients currently do. + + When Synapse is updated to be spec compliant, the call to re-use the + session ID should be rejected. """ # Create a second login. self.login("test", self.user_pass) @@ -302,9 +267,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) # Make another request providing the UI auth flow, but try to delete the - # second device. This results in an error. + # second device. self.delete_devices( - 403, + 200, { "devices": [device_ids[1]], "auth": { -- cgit 1.5.1 From a0e063387d5c8390b19cf69359a90924ecd4fbda Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 14 May 2020 10:07:54 +0100 Subject: Stop `get_joined_users` corruption from custom statuses (#7376) Fix a bug where the `get_joined_users` cache could be corrupted by custom status events (or other state events with a state_key matching the user ID). The bug was introduced by #2229, but has largely gone unnoticed since then. Fixes #7099, #7373. --- changelog.d/7376.bugfix | 1 + synapse/storage/data_stores/main/roommember.py | 3 +- tests/storage/test_roommember.py | 50 +++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7376.bugfix (limited to 'tests') diff --git a/changelog.d/7376.bugfix b/changelog.d/7376.bugfix new file mode 100644 index 0000000000..b7b67e328e --- /dev/null +++ b/changelog.d/7376.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause messages not to be sent over federation, when state events with state keys matching user IDs (such as custom user statuses) were received. diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index d5bd0cb5cf..e626b7f6f7 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -576,7 +576,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): if key[0] == EventTypes.Member ] for etype, state_key in context.delta_ids: - users_in_room.pop(state_key, None) + if etype == EventTypes.Member: + users_in_room.pop(state_key, None) # We check if we have any of the member event ids in the event cache # before we ask the DB diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 00df0ea68e..5dd46005e6 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -22,6 +22,8 @@ from synapse.rest.client.v1 import login, room from synapse.types import Requester, UserID from tests import unittest +from tests.test_utils import event_injection +from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -38,7 +40,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor, clock, hs: TestHomeServer): # We can't test the RoomMemberStore on its own without the other event # storage logic @@ -114,6 +116,52 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) + def test_get_joined_users_from_context(self): + room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + bob_event = event_injection.inject_member_event( + self.hs, room, self.u_bob, Membership.JOIN + ) + + # first, create a regular event + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[bob_event.event_id], + type="m.test.1", + content={}, + ) + + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + + # Regression test for #7376: create a state event whose key matches bob's + # user_id, but which is *not* a membership event, and persist that; then check + # that `get_joined_users_from_context` returns the correct users for the next event. + non_member_event = event_injection.inject_event( + self.hs, + room_id=room, + sender=self.u_bob, + prev_event_ids=[bob_event.event_id], + type="m.test.2", + state_key=self.u_bob, + content={}, + ) + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[non_member_event.event_id], + type="m.test.3", + content={}, + ) + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): -- cgit 1.5.1