From c3c6c0e6222cc1bc8ae35a66389dc428d0ddbc92 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 28 Feb 2020 11:15:11 +0000 Subject: Add 'device_lists_outbound_pokes' as extra table. This makes sure we check all the relevant tables to get the current max stream ID. Currently not doing so isn't problematic as the max stream ID in `device_lists_outbound_pokes` is the same as in `device_lists_stream`, however that will change. --- synapse/replication/slave/storage/devices.py | 8 +++++++- synapse/storage/data_stores/main/__init__.py | 5 ++++- 2 files changed, 11 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 1c77687eea..bf46cc4f8a 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self.hs = hs self._device_list_id_gen = SlavedIdTracker( - db_conn, "device_lists_stream", "stream_id" + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ], ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index acca079f23..649e835303 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -144,7 +144,10 @@ class DataStore( db_conn, "device_lists_stream", "stream_id", - extra_tables=[("user_signature_stream", "stream_id")], + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ], ) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id" -- cgit 1.4.1 From f5caa1864e3d3c24c691b3a3bff723f77def129e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 28 Feb 2020 11:21:25 +0000 Subject: Change device lists stream to have one row per id. This will make it possible to process the streams more incrementally, avoiding having to process large chunks at once. --- synapse/storage/data_stores/main/devices.py | 59 ++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 18 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index d55733a4cd..3299607910 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -1017,29 +1017,41 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. 
""" - with self._device_list_id_gen.get_next() as stream_id: + if not device_ids: + return + + with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: yield self.db.runInteraction( - "add_device_change_to_streams", - self._add_device_change_txn, + "add_device_change_to_stream", + self._add_device_change_to_stream_txn, + user_id, + device_ids, + stream_ids, + ) + + if not hosts: + return stream_ids[-1] + + context = get_active_span_text_map() + with self._device_list_id_gen.get_next_mult( + len(hosts) * len(device_ids) + ) as stream_ids: + yield self.db.runInteraction( + "add_device_outbound_poke_to_stream", + self._add_device_outbound_poke_to_stream_txn, user_id, device_ids, hosts, - stream_id, + stream_ids, + context, ) - return stream_id - def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): - now = self._clock.time_msec() + return stream_ids[-1] + def _add_device_change_to_stream_txn(self, txn, user_id, device_ids, stream_ids): txn.call_after( - self._device_list_stream_cache.entity_has_changed, user_id, stream_id + self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], ) - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_id, - ) # Delete older entries in the table, as we really only care about # when the latest change happened. @@ -1048,7 +1060,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? """, - [(user_id, device_id, stream_id) for device_id in device_ids], + [(user_id, device_id, stream_ids[0]) for device_id in device_ids], ) self.db.simple_insert_many_txn( @@ -1056,11 +1068,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_stream", values=[ {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} - for device_id in device_ids + for stream_id, device_id in zip(stream_ids, device_ids) ], ) - context = get_active_span_text_map() + def _add_device_outbound_poke_to_stream_txn( + self, txn, user_id, device_ids, hosts, stream_ids, context, + ): + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) + + now = self._clock.time_msec() + next_stream_id = iter(stream_ids) self.db.simple_insert_many_txn( txn, @@ -1068,7 +1091,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=[ { "destination": destination, - "stream_id": stream_id, + "stream_id": next(next_stream_id), "user_id": user_id, "device_id": device_id, "sent": False, -- cgit 1.4.1 From 9ce4e344a808e15a36a2d9ea03b77ebfc6ac7fe2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 28 Feb 2020 11:24:05 +0000 Subject: Change device list replication to match new semantics. Instead of sending down batches of user ID/host tuples, send down a row per entity (user ID or host). --- synapse/app/generic_worker.py | 2 +- synapse/replication/slave/storage/devices.py | 25 +++++++++++++------------ synapse/replication/tcp/streams/_base.py | 13 +++++++++---- synapse/storage/data_stores/main/devices.py | 15 +++++++++------ 4 files changed, 32 insertions(+), 23 deletions(-) (limited to 'synapse') diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index b2c764bfe8..561a6f4b22 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -774,7 +774,7 @@ class FederationSenderHandler(object): # ... 
as well as device updates and messages elif stream_name == DeviceListsStream.NAME: - hosts = {row.destination for row in rows} + hosts = {row.entity for row in rows if not row.entity.startswith("@")} for host in hosts: self.federation_sender.send_device_messages(host) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index bf46cc4f8a..01a4f85884 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -61,23 +61,24 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def process_replication_rows(self, stream_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) - for row in rows: - self._invalidate_caches_for_devices(token, row.user_id, row.destination) + self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) - def _invalidate_caches_for_devices(self, token, user_id, destination): - self._device_list_stream_cache.entity_has_changed(user_id, token) - - if destination: - self._device_list_federation_stream_cache.entity_has_changed( - destination, token - ) + def _invalidate_caches_for_devices(self, token, rows): + for row in rows: + if row.entity.startswith("@"): + self._device_list_stream_cache.entity_has_changed(row.entity, token) + self.get_cached_devices_for_user.invalidate((row.entity,)) + self._get_cached_user_device.invalidate_many((row.entity,)) + self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) - self.get_cached_devices_for_user.invalidate((user_id,)) - self._get_cached_user_device.invalidate_many((user_id,)) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) + else: + self._device_list_federation_stream_cache.entity_has_changed( + row.entity, token + ) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 208e8a667b..7a8b6e9df1 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -94,9 +94,13 @@ PublicRoomsStreamRow = namedtuple( "network_id", # str, optional ), ) -DeviceListsStreamRow = namedtuple( - "DeviceListsStreamRow", ("user_id", "destination") # str # str -) + + +@attr.s +class DeviceListsStreamRow: + entity = attr.ib(type=str) + + ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str TagAccountDataStreamRow = namedtuple( "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict @@ -363,7 +367,8 @@ class PublicRoomsStream(Stream): class DeviceListsStream(Stream): - """Someone added/changed/removed a device + """Either a user has updated their devices or a remote server needs to be + told about a device update. """ NAME = "device_lists" diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 3299607910..768afe7a6c 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -612,15 +612,18 @@ class DeviceWorkerStore(SQLBaseStore): combined list of changes to devices, and which destinations need to be poked. `destination` may be None if no destinations need to be poked. 
""" - # We do a group by here as there can be a large number of duplicate - # entries, since we throw away device IDs. + + # This query Does The Right Thing where it'll correctly apply the + # bounds to the inner queries. sql = """ - SELECT MAX(stream_id) AS stream_id, user_id, destination - FROM device_lists_stream - LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + SELECT stream_id, entity FROM ( + SELECT stream_id, user_id AS entity FROM device_lists_stream + UNION ALL + SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes + ) AS e WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id, destination """ + return self.db.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) -- cgit 1.4.1 From f70f44abc73689a66d0a05dc703ca38241092174 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 28 Feb 2020 11:45:35 +0000 Subject: Remove handling of multiple rows per ID --- synapse/storage/data_stores/main/devices.py | 35 +--------------------- tests/storage/test_devices.py | 45 ----------------------------- 2 files changed, 1 insertion(+), 79 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 768afe7a6c..06e1d9f033 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -112,23 +112,13 @@ class DeviceWorkerStore(SQLBaseStore): if not has_changed: return now_stream_id, [] - # We retrieve n+1 devices from the list of outbound pokes where n is - # our outbound device update limit. We then check if the very last - # device has the same stream_id as the second-to-last device. If so, - # then we ignore all devices with that stream_id and only send the - # devices with a lower stream_id. - # - # If when culling the list we end up with no devices afterwards, we - # consider the device update to be too large, and simply skip the - # stream_id; the rationale being that such a large device list update - # is likely an error. updates = yield self.db.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, from_stream_id, now_stream_id, - limit + 1, + limit, ) # Return an empty list if there are no updates @@ -166,14 +156,6 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - # if we have exceeded the limit, we need to exclude any results with the - # same stream_id as the last row. - if len(updates) > limit: - stream_id_cutoff = updates[-1][2] - now_stream_id = stream_id_cutoff - 1 - else: - stream_id_cutoff = None - # Perform the equivalent of a GROUP BY # # Iterate through the updates list and copy non-duplicate @@ -192,10 +174,6 @@ class DeviceWorkerStore(SQLBaseStore): query_map = {} cross_signing_keys_by_user = {} for user_id, device_id, update_stream_id, update_context in updates: - if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff: - # Stop processing updates - break - if ( user_id in master_key_by_user and device_id == master_key_by_user[user_id]["device_id"] @@ -218,17 +196,6 @@ class DeviceWorkerStore(SQLBaseStore): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) - # If we didn't find any updates with a stream_id lower than the cutoff, it - # means that there are more than limit updates all of which have the same - # steam_id. 
- - # That should only happen if a client is spamming the server with new - # devices, in which case E2E isn't going to work well anyway. We'll just - # skip that stream_id and return an empty list, and continue with the next - # stream_id next time. - if not query_map and not cross_signing_keys_by_user: - return stream_id_cutoff, [] - results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 6f8d990959..c2539b353a 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) - @defer.inlineCallbacks - def test_get_device_updates_by_remote_limited(self): - # Test breaking the update limit in 1, 101, and 1 device_id segments - - # first add one device - device_ids1 = ["device_id0"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"] - ) - - # then add 101 - device_ids2 = ["device_id" + str(i + 1) for i in range(101)] - yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"] - ) - - # then one more - device_ids3 = ["newdevice"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"] - ) - - # - # now read them back. - # - - # first we should get a single update - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", -1, limit=100 - ) - self._check_devices_in_updates(device_ids1, device_updates) - - # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self.assertEqual(len(device_updates), 0) - - # The 101 devices should've been cleared, so we should now just get one device - # update - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self._check_devices_in_updates(device_ids3, device_updates) - def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) -- cgit 1.4.1 From e53744c737527ebb2af94b677b359743473b0434 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 2 Mar 2020 12:52:28 +0000 Subject: Fix worker handling --- synapse/app/generic_worker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 561a6f4b22..d596852419 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -676,8 +676,9 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): elif stream_name == "device_lists": all_room_ids = set() for row in rows: - room_ids = await self.store.get_rooms_for_user(row.user_id) - all_room_ids.update(room_ids) + if row.entity.startswith("@"): + room_ids = await self.store.get_rooms_for_user(row.entity) + all_room_ids.update(room_ids) self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) elif stream_name == "presence": await self.presence_handler.process_replication_rows(token, rows) -- cgit 1.4.1 From 6e6476ef07c2d72fbea85603f2eb2a61a6866732 Mon Sep 17 00:00:00 
2001 From: Erik Johnston Date: Wed, 18 Mar 2020 10:13:55 +0000 Subject: Comments from review --- synapse/app/generic_worker.py | 3 +++ synapse/replication/slave/storage/devices.py | 3 +++ synapse/storage/data_stores/main/devices.py | 27 +++++++++++++++++++-------- 3 files changed, 25 insertions(+), 8 deletions(-) (limited to 'synapse') diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d596852419..cdc078cf11 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -775,6 +775,9 @@ class FederationSenderHandler(object): # ... as well as device updates and messages elif stream_name == DeviceListsStream.NAME: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. hosts = {row.entity for row in rows if not row.entity.startswith("@")} for host in hosts: self.federation_sender.send_device_messages(host) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 01a4f85884..23b1650e41 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -72,6 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def _invalidate_caches_for_devices(self, token, rows): for row in rows: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. if row.entity.startswith("@"): self._device_list_stream_cache.entity_has_changed(row.entity, token) self.get_cached_devices_for_user.invalidate((row.entity,)) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 06e1d9f033..4c19c02bbc 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Tuple from six import iteritems @@ -31,7 +32,7 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage.database import Database, LoggingTransaction from synapse.types import Collection, get_verify_key_from_cross_signing_key from synapse.util.caches.descriptors import ( Cache, @@ -574,10 +575,12 @@ class DeviceWorkerStore(SQLBaseStore): else: return set() - def get_all_device_list_changes_for_remotes(self, from_key, to_key): - """Return a list of `(stream_id, user_id, destination)` which is the - combined list of changes to devices, and which destinations need to be - poked. `destination` may be None if no destinations need to be poked. + async def get_all_device_list_changes_for_remotes( + self, from_key: int, to_key: int + ) -> List[Tuple[int, str]]: + """Return a list of `(stream_id, entity)` which is the combined list of + changes to devices and which destinations need to be poked. Entity is + either a user ID (starting with '@') or a remote destination. """ # This query Does The Right Thing where it'll correctly apply the @@ -591,7 +594,7 @@ class DeviceWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? 
""" - return self.db.execute( + return await self.db.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @@ -1018,11 +1021,19 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] - def _add_device_change_to_stream_txn(self, txn, user_id, device_ids, stream_ids): + def _add_device_change_to_stream_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + stream_ids: List[str], + ): txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], ) + min_stream_id = stream_ids[0] + # Delete older entries in the table, as we really only care about # when the latest change happened. txn.executemany( @@ -1030,7 +1041,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? """, - [(user_id, device_id, stream_ids[0]) for device_id in device_ids], + [(user_id, device_id, min_stream_id) for device_id in device_ids], ) self.db.simple_insert_many_txn( -- cgit 1.4.1 From c2db6599c820d97e3c8a02d782e90af80121c903 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 19 Mar 2020 08:22:56 -0400 Subject: Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors (#7089). --- changelog.d/7089.bugfix | 1 + synapse/federation/federation_base.py | 24 +++++++++--------------- synapse/federation/federation_client.py | 19 ++++++++----------- synapse/federation/federation_server.py | 8 ++++---- 4 files changed, 22 insertions(+), 30 deletions(-) create mode 100644 changelog.d/7089.bugfix (limited to 'synapse') diff --git a/changelog.d/7089.bugfix b/changelog.d/7089.bugfix new file mode 100644 index 0000000000..f1f440f23a --- /dev/null +++ b/changelog.d/7089.bugfix @@ -0,0 +1 @@ +Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 5c991e5412..b0b0eba41e 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -25,11 +25,7 @@ from twisted.python.failure import Failure from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError -from synapse.api.room_versions import ( - KNOWN_ROOM_VERSIONS, - EventFormatVersions, - RoomVersion, -) +from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.keyring import Keyring from synapse.events import EventBase, make_event_from_dict @@ -55,13 +51,15 @@ class FederationBase(object): self.store = hs.get_datastore() self._clock = hs.get_clock() - def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred: + def _check_sigs_and_hash( + self, room_version: RoomVersion, pdu: EventBase + ) -> Deferred: return make_deferred_yieldable( self._check_sigs_and_hashes(room_version, [pdu])[0] ) def _check_sigs_and_hashes( - self, room_version: str, pdus: List[EventBase] + self, room_version: RoomVersion, pdus: List[EventBase] ) -> List[Deferred]: """Checks that each of the received events is correctly signed by the sending server. 
@@ -146,7 +144,7 @@ class PduToCheckSig( def _check_sigs_on_pdus( - keyring: Keyring, room_version: str, pdus: Iterable[EventBase] + keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase] ) -> List[Deferred]: """Check that the given events are correctly signed @@ -191,10 +189,6 @@ def _check_sigs_on_pdus( for p in pdus ] - v = KNOWN_ROOM_VERSIONS.get(room_version) - if not v: - raise RuntimeError("Unrecognized room version %s" % (room_version,)) - # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] @@ -204,7 +198,7 @@ def _check_sigs_on_pdus( ( p.sender_domain, p.redacted_pdu_json, - p.pdu.origin_server_ts if v.enforce_key_validity else 0, + p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, p.pdu.event_id, ) for p in pdus_to_check_sender @@ -227,7 +221,7 @@ def _check_sigs_on_pdus( # event id's domain (normally only the case for joins/leaves), and add additional # checks. Only do this if the room version has a concept of event ID domain # (ie, the room version uses old-style non-hash event IDs). - if v.event_format == EventFormatVersions.V1: + if room_version.event_format == EventFormatVersions.V1: pdus_to_check_event_id = [ p for p in pdus_to_check @@ -239,7 +233,7 @@ def _check_sigs_on_pdus( ( get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, - p.pdu.origin_server_ts if v.enforce_key_validity else 0, + p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, p.pdu.event_id, ) for p in pdus_to_check_event_id diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 8c6b839478..a0071fec94 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -220,8 +220,7 @@ class FederationClient(FederationBase): # FIXME: We should handle signature failures more gracefully. pdus[:] = await make_deferred_yieldable( defer.gatherResults( - self._check_sigs_and_hashes(room_version.identifier, pdus), - consumeErrors=True, + self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True, ).addErrback(unwrapFirstError) ) @@ -291,9 +290,7 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. 
- signed_pdu = await self._check_sigs_and_hash( - room_version.identifier, pdu - ) + signed_pdu = await self._check_sigs_and_hash(room_version, pdu) break @@ -350,7 +347,7 @@ class FederationClient(FederationBase): self, origin: str, pdus: List[EventBase], - room_version: str, + room_version: RoomVersion, outlier: bool = False, include_none: bool = False, ) -> List[EventBase]: @@ -396,7 +393,7 @@ class FederationClient(FederationBase): self.get_pdu( destinations=[pdu.origin], event_id=pdu.event_id, - room_version=room_version, # type: ignore + room_version=room_version, outlier=outlier, timeout=10000, ) @@ -434,7 +431,7 @@ class FederationClient(FederationBase): ] signed_auth = await self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version.identifier + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -661,7 +658,7 @@ class FederationClient(FederationBase): destination, list(pdus.values()), outlier=True, - room_version=room_version.identifier, + room_version=room_version, ) valid_pdus_map = {p.event_id: p for p in valid_pdus} @@ -756,7 +753,7 @@ class FederationClient(FederationBase): pdu = event_from_pdu_json(pdu_dict, room_version) # Check signatures are correct. - pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) # FIXME: We should handle signature failures more gracefully. @@ -948,7 +945,7 @@ class FederationClient(FederationBase): ] signed_events = await self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version.identifier + destination, events, outlier=False, room_version=room_version ) except HttpResponseException as e: if not e.code == 400: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 275b9c99d7..89d521bc31 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -409,7 +409,7 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) - pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} @@ -425,7 +425,7 @@ class FederationServer(FederationBase): logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() @@ -455,7 +455,7 @@ class FederationServer(FederationBase): logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) await self.handler.on_send_leave_request(origin, pdu) return {} @@ -611,7 +611,7 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = await self.store.get_room_version_id(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. 
try: -- cgit 1.4.1 From caec7d4fa0041697b7714e638477772f0a827ff6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 20 Mar 2020 07:20:02 -0400 Subject: Convert some of the media REST code to async/await (#7110) --- changelog.d/7110.misc | 1 + synapse/rest/media/v1/media_repository.py | 110 ++++++++++++-------------- synapse/rest/media/v1/preview_url_resource.py | 37 ++++----- synapse/rest/media/v1/thumbnail_resource.py | 54 ++++++------- 4 files changed, 91 insertions(+), 111 deletions(-) create mode 100644 changelog.d/7110.misc (limited to 'synapse') diff --git a/changelog.d/7110.misc b/changelog.d/7110.misc new file mode 100644 index 0000000000..fac5bc0403 --- /dev/null +++ b/changelog.d/7110.misc @@ -0,0 +1 @@ +Convert some of synapse.rest.media to async/await. diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 490b1b45a8..fd10d42f2f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -24,7 +24,6 @@ from six import iteritems import twisted.internet.error import twisted.web.http -from twisted.internet import defer from twisted.web.resource import Resource from synapse.api.errors import ( @@ -114,15 +113,14 @@ class MediaRepository(object): "update_recently_accessed_media", self._update_recently_accessed ) - @defer.inlineCallbacks - def _update_recently_accessed(self): + async def _update_recently_accessed(self): remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() local_media = self.recently_accessed_locals self.recently_accessed_locals = set() - yield self.store.update_cached_last_access_time( + await self.store.update_cached_last_access_time( local_media, remote_media, self.clock.time_msec() ) @@ -138,8 +136,7 @@ class MediaRepository(object): else: self.recently_accessed_locals.add(media_id) - @defer.inlineCallbacks - def create_content( + async def create_content( self, media_type, upload_name, content, content_length, auth_user ): """Store uploaded content for a local user and return the mxc URL @@ -158,11 +155,11 @@ class MediaRepository(object): file_info = FileInfo(server_name=None, file_id=media_id) - fname = yield self.media_storage.store_file(content, file_info) + fname = await self.media_storage.store_file(content, file_info) logger.info("Stored local media in file %r", fname) - yield self.store.store_local_media( + await self.store.store_local_media( media_id=media_id, media_type=media_type, time_now_ms=self.clock.time_msec(), @@ -171,12 +168,11 @@ class MediaRepository(object): user_id=auth_user, ) - yield self._generate_thumbnails(None, media_id, media_id, media_type) + await self._generate_thumbnails(None, media_id, media_id, media_type) return "mxc://%s/%s" % (self.server_name, media_id) - @defer.inlineCallbacks - def get_local_media(self, request, media_id, name): + async def get_local_media(self, request, media_id, name): """Responds to reqests for local media, if exists, or returns 404. 
Args: @@ -190,7 +186,7 @@ class MediaRepository(object): Deferred: Resolves once a response has successfully been written to request """ - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: respond_404(request) return @@ -204,13 +200,12 @@ class MediaRepository(object): file_info = FileInfo(None, media_id, url_cache=url_cache) - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder( + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder( request, responder, media_type, media_length, upload_name ) - @defer.inlineCallbacks - def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media(self, request, server_name, media_id, name): """Respond to requests for remote media. Args: @@ -236,8 +231,8 @@ class MediaRepository(object): # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (yield self.remote_media_linearizer.queue(key)): - responder, media_info = yield self._get_remote_media_impl( + with (await self.remote_media_linearizer.queue(key)): + responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -246,14 +241,13 @@ class MediaRepository(object): media_type = media_info["media_type"] media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] - yield respond_with_responder( + await respond_with_responder( request, responder, media_type, media_length, upload_name ) else: respond_404(request) - @defer.inlineCallbacks - def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name, media_id): """Gets the media info associated with the remote file, downloading if necessary. @@ -274,8 +268,8 @@ class MediaRepository(object): # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (yield self.remote_media_linearizer.queue(key)): - responder, media_info = yield self._get_remote_media_impl( + with (await self.remote_media_linearizer.queue(key)): + responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -286,8 +280,7 @@ class MediaRepository(object): return media_info - @defer.inlineCallbacks - def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl(self, server_name, media_id): """Looks for media in local cache, if not there then attempt to download from remote server. @@ -299,7 +292,7 @@ class MediaRepository(object): Returns: Deferred[(Responder, media_info)] """ - media_info = yield self.store.get_cached_remote_media(server_name, media_id) + media_info = await self.store.get_cached_remote_media(server_name, media_id) # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new @@ -317,19 +310,18 @@ class MediaRepository(object): logger.info("Media is quarantined") raise NotFoundError() - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: return responder, media_info # Failed to find the file anywhere, lets download it. 
- media_info = yield self._download_remote_file(server_name, media_id, file_id) + media_info = await self._download_remote_file(server_name, media_id, file_id) - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) return responder, media_info - @defer.inlineCallbacks - def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file(self, server_name, media_id, file_id): """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -351,7 +343,7 @@ class MediaRepository(object): ("/_matrix/media/v1/download", server_name, media_id) ) try: - length, headers = yield self.client.get_file( + length, headers = await self.client.get_file( server_name, request_path, output_stream=f, @@ -397,7 +389,7 @@ class MediaRepository(object): ) raise SynapseError(502, "Failed to fetch remote media") - yield finish() + await finish() media_type = headers[b"Content-Type"][0].decode("ascii") upload_name = get_filename_from_headers(headers) @@ -405,7 +397,7 @@ class MediaRepository(object): logger.info("Stored remote media in file %r", fname) - yield self.store.store_cached_remote_media( + await self.store.store_cached_remote_media( origin=server_name, media_id=media_id, media_type=media_type, @@ -423,7 +415,7 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_thumbnails(server_name, media_id, file_id, media_type) + await self._generate_thumbnails(server_name, media_id, file_id, media_type) return media_info @@ -458,16 +450,15 @@ class MediaRepository(object): return t_byte_source - @defer.inlineCallbacks - def generate_local_exact_thumbnail( + async def generate_local_exact_thumbnail( self, media_id, t_width, t_height, t_method, t_type, url_cache ): - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -490,7 +481,7 @@ class MediaRepository(object): thumbnail_type=t_type, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -500,22 +491,21 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) - yield self.store.store_local_thumbnail( + await self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) return output_path - @defer.inlineCallbacks - def generate_remote_exact_thumbnail( + async def generate_remote_exact_thumbnail( self, server_name, file_id, media_id, t_width, t_height, t_method, t_type ): - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=False) ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -537,7 +527,7 @@ class MediaRepository(object): thumbnail_type=t_type, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -547,7 +537,7 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) - 
yield self.store.store_remote_media_thumbnail( + await self.store.store_remote_media_thumbnail( server_name, media_id, file_id, @@ -560,8 +550,7 @@ class MediaRepository(object): return output_path - @defer.inlineCallbacks - def _generate_thumbnails( + async def _generate_thumbnails( self, server_name, media_id, file_id, media_type, url_cache=False ): """Generate and store thumbnails for an image. @@ -582,7 +571,7 @@ class MediaRepository(object): if not requirements: return - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) ) @@ -600,7 +589,7 @@ class MediaRepository(object): return if thumbnailer.transpose_method is not None: - m_width, m_height = yield defer_to_thread( + m_width, m_height = await defer_to_thread( self.hs.get_reactor(), thumbnailer.transpose ) @@ -620,11 +609,11 @@ class MediaRepository(object): for (t_width, t_height, t_type), t_method in iteritems(thumbnails): # Generate the thumbnail if t_method == "crop": - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: @@ -646,7 +635,7 @@ class MediaRepository(object): url_cache=url_cache, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -656,7 +645,7 @@ class MediaRepository(object): # Write to database if server_name: - yield self.store.store_remote_media_thumbnail( + await self.store.store_remote_media_thumbnail( server_name, media_id, file_id, @@ -667,15 +656,14 @@ class MediaRepository(object): t_len, ) else: - yield self.store.store_local_thumbnail( + await self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) return {"width": m_width, "height": m_height} - @defer.inlineCallbacks - def delete_old_remote_media(self, before_ts): - old_media = yield self.store.get_remote_media_before(before_ts) + async def delete_old_remote_media(self, before_ts): + old_media = await self.store.get_remote_media_before(before_ts) deleted = 0 @@ -689,7 +677,7 @@ class MediaRepository(object): # TODO: Should we delete from the backup store - with (yield self.remote_media_linearizer.queue(key)): + with (await self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: os.remove(full_path) @@ -705,7 +693,7 @@ class MediaRepository(object): ) shutil.rmtree(thumbnail_dir, ignore_errors=True) - yield self.store.delete_remote_media(origin, media_id) + await self.store.delete_remote_media(origin, media_id) deleted += 1 return {"deleted": deleted} diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 07e395cfd1..c46676f8fc 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -165,8 +165,7 @@ class PreviewUrlResource(DirectServeResource): og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) respond_with_json_bytes(request, 200, og, send_cors=True) - @defer.inlineCallbacks - def _do_preview(self, url, user, ts): + async def _do_preview(self, url, user, ts): """Check the db, and download the URL and build a preview 
Args: @@ -179,7 +178,7 @@ class PreviewUrlResource(DirectServeResource): """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) - cache_result = yield self.store.get_url_cache(url, ts) + cache_result = await self.store.get_url_cache(url, ts) if ( cache_result and cache_result["expires_ts"] > ts @@ -192,13 +191,13 @@ class PreviewUrlResource(DirectServeResource): og = og.encode("utf8") return og - media_info = yield self._download_url(url, user) + media_info = await self._download_url(url, user) logger.debug("got media_info of '%s'", media_info) if _is_media(media_info["media_type"]): file_id = media_info["filesystem_id"] - dims = yield self.media_repo._generate_thumbnails( + dims = await self.media_repo._generate_thumbnails( None, file_id, file_id, media_info["media_type"], url_cache=True ) @@ -248,14 +247,14 @@ class PreviewUrlResource(DirectServeResource): # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. if "og:image" in og and og["og:image"]: - image_info = yield self._download_url( + image_info = await self._download_url( _rebase_url(og["og:image"], media_info["uri"]), user ) if _is_media(image_info["media_type"]): # TODO: make sure we don't choke on white-on-transparent images file_id = image_info["filesystem_id"] - dims = yield self.media_repo._generate_thumbnails( + dims = await self.media_repo._generate_thumbnails( None, file_id, file_id, image_info["media_type"], url_cache=True ) if dims: @@ -293,7 +292,7 @@ class PreviewUrlResource(DirectServeResource): jsonog = json.dumps(og) # store OG in history-aware DB cache - yield self.store.store_url_cache( + await self.store.store_url_cache( url, media_info["response_code"], media_info["etag"], @@ -305,8 +304,7 @@ class PreviewUrlResource(DirectServeResource): return jsonog.encode("utf8") - @defer.inlineCallbacks - def _download_url(self, url, user): + async def _download_url(self, url, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -318,7 +316,7 @@ class PreviewUrlResource(DirectServeResource): with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: logger.debug("Trying to get url '%s'", url) - length, headers, uri, code = yield self.client.get_file( + length, headers, uri, code = await self.client.get_file( url, output_stream=f, max_size=self.max_spider_size ) except SynapseError: @@ -345,7 +343,7 @@ class PreviewUrlResource(DirectServeResource): % (traceback.format_exception_only(sys.exc_info()[0], e),), Codes.UNKNOWN, ) - yield finish() + await finish() try: if b"Content-Type" in headers: @@ -356,7 +354,7 @@ class PreviewUrlResource(DirectServeResource): download_name = get_filename_from_headers(headers) - yield self.store.store_local_media( + await self.store.store_local_media( media_id=file_id, media_type=media_type, time_now_ms=self.clock.time_msec(), @@ -393,8 +391,7 @@ class PreviewUrlResource(DirectServeResource): "expire_url_cache_data", self._expire_url_cache_data ) - @defer.inlineCallbacks - def _expire_url_cache_data(self): + async def _expire_url_cache_data(self): """Clean up expired url cache content, media and thumbnails. 
""" # TODO: Delete from backup media store @@ -403,12 +400,12 @@ class PreviewUrlResource(DirectServeResource): logger.info("Running url preview cache expiry") - if not (yield self.store.db.updates.has_completed_background_updates()): + if not (await self.store.db.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return # First we delete expired url cache entries - media_ids = yield self.store.get_expired_url_cache(now) + media_ids = await self.store.get_expired_url_cache(now) removed_media = [] for media_id in media_ids: @@ -430,7 +427,7 @@ class PreviewUrlResource(DirectServeResource): except Exception: pass - yield self.store.delete_url_cache(removed_media) + await self.store.delete_url_cache(removed_media) if removed_media: logger.info("Deleted %d entries from url cache", len(removed_media)) @@ -440,7 +437,7 @@ class PreviewUrlResource(DirectServeResource): # may have a room open with a preview url thing open). # So we wait a couple of days before deleting, just in case. expire_before = now - 2 * 24 * 60 * 60 * 1000 - media_ids = yield self.store.get_url_cache_media_before(expire_before) + media_ids = await self.store.get_url_cache_media_before(expire_before) removed_media = [] for media_id in media_ids: @@ -478,7 +475,7 @@ class PreviewUrlResource(DirectServeResource): except Exception: pass - yield self.store.delete_url_cache_media(removed_media) + await self.store.delete_url_cache_media(removed_media) logger.info("Deleted %d media from url cache", len(removed_media)) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d57480f761..0b87220234 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.http.server import ( DirectServeResource, set_cors_headers, @@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource): ) self.media_repo.mark_recently_accessed(server_name, media_id) - @defer.inlineCallbacks - def _respond_local_thumbnail( + async def _respond_local_thumbnail( self, request, media_id, width, height, method, m_type ): - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info: respond_404(request) @@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource): respond_404(request) return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) if thumbnail_infos: thumbnail_info = self._select_thumbnail( @@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder(request, responder, t_type, t_length) + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder(request, responder, t_type, t_length) else: logger.info("Couldn't find any generated thumbnails") respond_404(request) - @defer.inlineCallbacks - def _select_or_generate_local_thumbnail( + async def _select_or_generate_local_thumbnail( self, request, media_id, @@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource): desired_method, desired_type, ): - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info: 
respond_404(request) @@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource): respond_404(request) return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info["thumbnail_width"] == desired_width t_h = info["thumbnail_height"] == desired_height @@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: - yield respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder(request, responder, t_type, t_length) return logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self.media_repo.generate_local_exact_thumbnail( + file_path = await self.media_repo.generate_local_exact_thumbnail( media_id, desired_width, desired_height, @@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource): ) if file_path: - yield respond_with_file(request, desired_type, file_path) + await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") respond_404(request) - @defer.inlineCallbacks - def _select_or_generate_remote_thumbnail( + async def _select_or_generate_remote_thumbnail( self, request, server_name, @@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource): desired_method, desired_type, ): - media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info(server_name, media_id) - thumbnail_infos = yield self.store.get_remote_media_thumbnails( + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: - yield respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder(request, responder, t_type, t_length) return logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self.media_repo.generate_remote_exact_thumbnail( + file_path = await self.media_repo.generate_remote_exact_thumbnail( server_name, file_id, media_id, @@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource): ) if file_path: - yield respond_with_file(request, desired_type, file_path) + await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") respond_404(request) - @defer.inlineCallbacks - def _respond_remote_thumbnail( + async def _respond_remote_thumbnail( self, request, server_name, media_id, width, height, method, m_type ): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. 
- media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info(server_name, media_id) - thumbnail_infos = yield self.store.get_remote_media_thumbnails( + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder(request, responder, t_type, t_length) + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder(request, responder, t_type, t_length) else: logger.info("Failed to find any generated thumbnails") respond_404(request) -- cgit 1.4.1 From fdb13447167da0670dd6ad95fdf4a99cde450eb9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Mar 2020 14:40:47 +0000 Subject: Remove concept of a non-limited stream. (#7011) --- changelog.d/7011.misc | 1 + synapse/handlers/presence.py | 4 +- synapse/handlers/typing.py | 11 +++- synapse/replication/tcp/resource.py | 9 +-- synapse/replication/tcp/streams/_base.py | 66 +++++++++------------- synapse/storage/data_stores/main/devices.py | 10 +++- .../storage/data_stores/main/end_to_end_keys.py | 14 +++-- synapse/storage/data_stores/main/presence.py | 23 ++++---- 8 files changed, 71 insertions(+), 67 deletions(-) create mode 100644 changelog.d/7011.misc (limited to 'synapse') diff --git a/changelog.d/7011.misc b/changelog.d/7011.misc new file mode 100644 index 0000000000..41c3b37574 --- /dev/null +++ b/changelog.d/7011.misc @@ -0,0 +1 @@ +Remove concept of a non-limited stream. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 5526015ddb..6912165622 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -747,7 +747,7 @@ class PresenceHandler(object): return False - async def get_all_presence_updates(self, last_id, current_id): + async def get_all_presence_updates(self, last_id, current_id, limit): """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -762,7 +762,7 @@ class PresenceHandler(object): """ # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = await self.store.get_all_presence_updates(last_id, current_id) + rows = await self.store.get_all_presence_updates(last_id, current_id, limit) return rows def notify_new_event(self): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 391bceb0c4..c7bc14c623 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import List from twisted.internet import defer @@ -257,7 +258,13 @@ class TypingHandler(object): "typing_key", self._latest_room_serial, rooms=[member.room_id] ) - async def get_all_typing_updates(self, last_id, current_id): + async def get_all_typing_updates( + self, last_id: int, current_id: int, limit: int + ) -> List[dict]: + """Get up to `limit` typing updates between the given tokens, earliest + updates first. 
+ """ + if last_id == current_id: return [] @@ -275,7 +282,7 @@ class TypingHandler(object): typing = self._room_typing[room_id] rows.append((serial, room_id, list(typing))) rows.sort() - return rows + return rows[:limit] def get_current_token(self): return self._latest_room_serial diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ce9d1fae12..6e2ebaf614 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -166,11 +166,6 @@ class ReplicationStreamer(object): self.pending_updates = False with Measure(self.clock, "repl.stream.get_updates"): - # First we tell the streams that they should update their - # current tokens. - for stream in self.streams: - stream.advance_current_token() - all_streams = self.streams if self._replication_torture_level is not None: @@ -180,7 +175,7 @@ class ReplicationStreamer(object): random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.upto_token: + if stream.last_token == stream.current_token(): continue if self._replication_torture_level: @@ -192,7 +187,7 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, - stream.upto_token, + stream.current_token(), ) try: updates, current_token = await stream.get_updates() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 7a8b6e9df1..abf5c6c6a8 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -17,10 +17,12 @@ import itertools import logging from collections import namedtuple -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple import attr +from synapse.types import JsonDict + logger = logging.getLogger(__name__) @@ -119,13 +121,12 @@ class Stream(object): """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last - time it was called up until the point `advance_current_token` was called. + time it was called. """ NAME = None # type: str # The name of the stream # The type of the row. Used by the default impl of parse_row. ROW_TYPE = None # type: Any - _LIMITED = True # Whether the update function takes a limit @classmethod def parse_row(cls, row): @@ -146,26 +147,15 @@ class Stream(object): # The token from which we last asked for updates self.last_token = self.current_token() - # The token that we will get updates up to - self.upto_token = self.current_token() - - def advance_current_token(self): - """Updates `upto_token` to "now", which updates up until which point - get_updates[_since] will fetch rows till. - """ - self.upto_token = self.current_token() - def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.upto_token = self.current_token() - self.last_token = self.upto_token + self.last_token = self.current_token() async def get_updates(self): """Gets all updates since the last time this function was called (or - since the stream was constructed if it hadn't been called before), - until the `upto_token` + since the stream was constructed if it hadn't been called before). 
Returns: Deferred[Tuple[List[Tuple[int, Any]], int]: @@ -178,44 +168,45 @@ class Stream(object): return updates, current_token - async def get_updates_since(self, from_token): + async def get_updates_since( + self, from_token: int + ) -> Tuple[List[Tuple[int, JsonDict]], int]: """Like get_updates except allows specifying from when we should stream updates Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + Resolves to a pair `(updates, new_last_token)`, where `updates` is + a list of `(token, row)` entries and `new_last_token` is the new + position in stream. """ + if from_token in ("NOW", "now"): - return [], self.upto_token + return [], self.current_token() - current_token = self.upto_token + current_token = self.current_token() from_token = int(from_token) if from_token == current_token: return [], current_token - logger.info("get_updates_since: %s", self.__class__) - if self._LIMITED: - rows = await self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 - ) + rows = await self.update_function( + from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 + ) - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - else: - rows = await self.update_function(from_token, current_token) + # never turn more than MAX_EVENTS_BEHIND + 1 into updates. + rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) updates = [(row[0], row[1:]) for row in rows] # check we didn't get more rows than the limit. # doing it like this allows the update_function to be a generator. - if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: + if len(updates) >= MAX_EVENTS_BEHIND: raise Exception("stream %s has fallen behind" % (self.NAME)) + # The update function didn't hit the limit, so we must have got all + # the updates to `current_token`, and can return that as our new + # stream position. return updates, current_token def current_token(self): @@ -227,9 +218,8 @@ class Stream(object): """ raise NotImplementedError() - def update_function(self, from_token, current_token, limit=None): - """Get updates between from_token and to_token. If Stream._LIMITED is - True then limit is provided, otherwise it's not. + def update_function(self, from_token, current_token, limit): + """Get updates between from_token and to_token. 
Returns: Deferred(list(tuple)): the first entry in the tuple is the token for @@ -257,7 +247,6 @@ class BackfillStream(Stream): class PresenceStream(Stream): NAME = "presence" - _LIMITED = False ROW_TYPE = PresenceStreamRow def __init__(self, hs): @@ -272,7 +261,6 @@ class PresenceStream(Stream): class TypingStream(Stream): NAME = "typing" - _LIMITED = False ROW_TYPE = TypingStreamRow def __init__(self, hs): @@ -372,7 +360,6 @@ class DeviceListsStream(Stream): """ NAME = "device_lists" - _LIMITED = False ROW_TYPE = DeviceListsStreamRow def __init__(self, hs): @@ -462,7 +449,6 @@ class UserSignatureStream(Stream): """ NAME = "user_signature" - _LIMITED = False ROW_TYPE = UserSignatureStreamRow def __init__(self, hs): diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 4c19c02bbc..2d47cfd131 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -576,7 +576,7 @@ class DeviceWorkerStore(SQLBaseStore): return set() async def get_all_device_list_changes_for_remotes( - self, from_key: int, to_key: int + self, from_key: int, to_key: int, limit: int, ) -> List[Tuple[int, str]]: """Return a list of `(stream_id, entity)` which is the combined list of changes to devices and which destinations need to be poked. Entity is @@ -592,10 +592,16 @@ class DeviceWorkerStore(SQLBaseStore): SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes ) AS e WHERE ? < stream_id AND stream_id <= ? + LIMIT ? """ return await self.db.execute( - "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key + "get_all_device_list_changes_for_remotes", + None, + sql, + from_key, + to_key, + limit, ) @cached(max_entries=10000) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 001a53f9b4..bcf746b7ef 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return result - def get_all_user_signature_changes_for_remotes(self, from_key, to_key): + def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their device with their user-signing key, which is not published to other @@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Deferred[list[(int,str)]] a list of `(stream_id, user_id)` """ sql = """ - SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id + SELECT stream_id, from_user_id AS user_id FROM user_signature_stream WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id + ORDER BY stream_id ASC + LIMIT ? 
""" return self.db.execute( - "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key + "get_all_user_signature_changes_for_remotes", + None, + sql, + from_key, + to_key, + limit, ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index 604c8b7ddd..dab31e0c2d 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore): "status_msg": state.status_msg, "currently_active": state.currently_active, } - for state in presence_states + for stream_id, state in zip(stream_orderings, presence_states) ], ) @@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore): ) txn.execute(sql + clause, [stream_id] + list(args)) - def get_all_presence_updates(self, last_id, current_id): + def get_all_presence_updates(self, last_id, current_id, limit): if last_id == current_id: return defer.succeed([]) def get_all_presence_updates_txn(txn): - sql = ( - "SELECT stream_id, user_id, state, last_active_ts," - " last_federation_update_ts, last_user_sync_ts, status_msg," - " currently_active" - " FROM presence_stream" - " WHERE ? < stream_id AND stream_id <= ?" - ) - txn.execute(sql, (last_id, current_id)) + sql = """ + SELECT stream_id, user_id, state, last_active_ts, + last_federation_update_ts, last_user_sync_ts, + status_msg, + currently_active + FROM presence_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() return self.db.runInteraction( -- cgit 1.4.1 From c165c1233b8ef244fadca97c7d465fdcf473d077 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 20 Mar 2020 16:24:22 +0100 Subject: Improve database configuration docs (#6988) Attempts to clarify the sample config for databases, and add some stuff about tcp keepalives to `postgres.md`. --- changelog.d/6988.doc | 1 + docs/postgres.md | 42 ++++++++++++++----- docs/sample_config.yaml | 43 +++++++++++++++++--- synapse/config/_base.py | 2 - synapse/config/database.py | 93 +++++++++++++++++++++++++++---------------- tests/config/test_database.py | 22 +--------- 6 files changed, 132 insertions(+), 71 deletions(-) create mode 100644 changelog.d/6988.doc (limited to 'synapse') diff --git a/changelog.d/6988.doc b/changelog.d/6988.doc new file mode 100644 index 0000000000..b6f71bb966 --- /dev/null +++ b/changelog.d/6988.doc @@ -0,0 +1 @@ +Improve the documentation for database configuration. diff --git a/docs/postgres.md b/docs/postgres.md index e0793ecee8..16a630c3d1 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -105,19 +105,41 @@ of free memory the database host has available. When you are ready to start using PostgreSQL, edit the `database` section in your config file to match the following lines: - database: - name: psycopg2 - args: - user: - password: - database: - host: - cp_min: 5 - cp_max: 10 +```yaml +database: + name: psycopg2 + args: + user: + password: + database: + host: + cp_min: 5 + cp_max: 10 +``` All key, values in `args` are passed to the `psycopg2.connect(..)` function, except keys beginning with `cp_`, which are consumed by the -twisted adbapi connection pool. +twisted adbapi connection pool. See the [libpq +documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS) +for a list of options which can be passed. 
+ +You should consider tuning the `args.keepalives_*` options if there is any danger of +the connection between your homeserver and database dropping, otherwise Synapse +may block for an extended period while it waits for a response from the +database server. Example values might be: + +```yaml +# seconds of inactivity after which TCP should send a keepalive message to the server +keepalives_idle: 10 + +# the number of seconds after which a TCP keepalive message that is not +# acknowledged by the server should be retransmitted +keepalives_interval: 10 + +# the number of TCP keepalives that can be lost before the client's connection +# to the server is considered dead +keepalives_count: 3 +``` ## Porting from SQLite diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2ff0dd05a2..276e43b732 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -578,13 +578,46 @@ acme: ## Database ## +# The 'database' setting defines the database that synapse uses to store all of +# its data. +# +# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or +# 'psycopg2' (for PostgreSQL). +# +# 'args' gives options which are passed through to the database engine, +# except for options starting 'cp_', which are used to configure the Twisted +# connection pool. For a reference to valid arguments, see: +# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect +# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS +# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__ +# +# +# Example SQLite configuration: +# +#database: +# name: sqlite3 +# args: +# database: /path/to/homeserver.db +# +# +# Example Postgres configuration: +# +#database: +# name: psycopg2 +# args: +# user: synapse +# password: secretpassword +# database: synapse +# host: localhost +# cp_min: 5 +# cp_max: 10 +# +# For more information on using Synapse with Postgres, see `docs/postgres.md`. +# database: - # The database engine name - name: "sqlite3" - # Arguments to pass to the engine + name: sqlite3 args: - # Path to the database - database: "DATADIR/homeserver.db" + database: DATADIR/homeserver.db # Number of events to cache in memory. # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index ba846042c4..efe2af5504 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -294,7 +294,6 @@ class RootConfig(object): report_stats=None, open_private_ports=False, listeners=None, - database_conf=None, tls_certificate_path=None, tls_private_key_path=None, acme_domain=None, @@ -367,7 +366,6 @@ class RootConfig(object): report_stats=report_stats, open_private_ports=open_private_ports, listeners=listeners, - database_conf=database_conf, tls_certificate_path=tls_certificate_path, tls_private_key_path=tls_private_key_path, acme_domain=acme_domain, diff --git a/synapse/config/database.py b/synapse/config/database.py index 219b32f670..b8ab2f86ac 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +15,60 @@ # limitations under the License. 
import logging import os -from textwrap import indent - -import yaml from synapse.config._base import Config, ConfigError logger = logging.getLogger(__name__) +DEFAULT_CONFIG = """\ +## Database ## + +# The 'database' setting defines the database that synapse uses to store all of +# its data. +# +# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or +# 'psycopg2' (for PostgreSQL). +# +# 'args' gives options which are passed through to the database engine, +# except for options starting 'cp_', which are used to configure the Twisted +# connection pool. For a reference to valid arguments, see: +# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect +# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS +# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__ +# +# +# Example SQLite configuration: +# +#database: +# name: sqlite3 +# args: +# database: /path/to/homeserver.db +# +# +# Example Postgres configuration: +# +#database: +# name: psycopg2 +# args: +# user: synapse +# password: secretpassword +# database: synapse +# host: localhost +# cp_min: 5 +# cp_max: 10 +# +# For more information on using Synapse with Postgres, see `docs/postgres.md`. +# +database: + name: sqlite3 + args: + database: %(database_path)s + +# Number of events to cache in memory. +# +#event_cache_size: 10K +""" + class DatabaseConnectionConfig: """Contains the connection config for a particular database. @@ -36,10 +83,12 @@ class DatabaseConnectionConfig: """ def __init__(self, name: str, db_config: dict): - if db_config["name"] not in ("sqlite3", "psycopg2"): - raise ConfigError("Unsupported database type %r" % (db_config["name"],)) + db_engine = db_config.get("name", "sqlite3") - if db_config["name"] == "sqlite3": + if db_engine not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_engine,)) + + if db_engine == "sqlite3": db_config.setdefault("args", {}).update( {"cp_min": 1, "cp_max": 1, "check_same_thread": False} ) @@ -97,34 +146,10 @@ class DatabaseConfig(Config): self.set_databasepath(config.get("database_path")) - def generate_config_section(self, data_dir_path, database_conf, **kwargs): - if not database_conf: - database_path = os.path.join(data_dir_path, "homeserver.db") - database_conf = ( - """# The database engine name - name: "sqlite3" - # Arguments to pass to the engine - args: - # Path to the database - database: "%(database_path)s" - """ - % locals() - ) - else: - database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip() - - return ( - """\ - ## Database ## - - database: - %(database_conf)s - # Number of events to cache in memory. 
- # - #event_cache_size: 10K - """ - % locals() - ) + def generate_config_section(self, data_dir_path, **kwargs): + return DEFAULT_CONFIG % { + "database_path": os.path.join(data_dir_path, "homeserver.db") + } def read_arguments(self, args): self.set_databasepath(args.database_path) diff --git a/tests/config/test_database.py b/tests/config/test_database.py index 151d3006ac..f675bde68e 100644 --- a/tests/config/test_database.py +++ b/tests/config/test_database.py @@ -21,9 +21,9 @@ from tests import unittest class DatabaseConfigTestCase(unittest.TestCase): - def test_database_configured_correctly_no_database_conf_param(self): + def test_database_configured_correctly(self): conf = yaml.safe_load( - DatabaseConfig().generate_config_section("/data_dir_path", None) + DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path") ) expected_database_conf = { @@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase): } self.assertEqual(conf["database"], expected_database_conf) - - def test_database_configured_correctly_database_conf_param(self): - - database_conf = { - "name": "my super fast datastore", - "args": { - "user": "matrix", - "password": "synapse_database_password", - "host": "synapse_database_host", - "database": "matrix", - }, - } - - conf = yaml.safe_load( - DatabaseConfig().generate_config_section("/data_dir_path", database_conf) - ) - - self.assertEqual(conf["database"], database_conf) -- cgit 1.4.1 From 477c4f5b1c2c7733d4b2cf578dc9aa8e048011b0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 20 Mar 2020 16:22:47 -0400 Subject: Clean-up some auth/login REST code (#7115) --- changelog.d/7115.misc | 1 + synapse/rest/client/v1/login.py | 8 ------ synapse/rest/client/v2_alpha/auth.py | 53 ++++++++++++++---------------------- 3 files changed, 21 insertions(+), 41 deletions(-) create mode 100644 changelog.d/7115.misc (limited to 'synapse') diff --git a/changelog.d/7115.misc b/changelog.d/7115.misc new file mode 100644 index 0000000000..7d4a011e3e --- /dev/null +++ b/changelog.d/7115.misc @@ -0,0 +1 @@ +De-duplicate / remove unused REST code for login and auth. 
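The change below drops unused template loading from the SSO login handler and collapses the repeated response-writing boilerplate in the fallback auth servlet into a single shared tail after the per-stage branches. A minimal sketch of that shared tail, written as a standalone helper purely for illustration (the commit itself just moves the existing lines rather than introducing a new function):

```python
from synapse.http.server import finish_request


def _write_html_response(request, html: str) -> None:
    # Illustrative helper only: the single rendering path that each fallback
    # auth stage now funnels into, instead of repeating these lines per stage.
    html_bytes = html.encode("utf8")
    request.setResponseCode(200)
    request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
    request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
    request.write(html_bytes)
    finish_request(request)
```
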
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d0d4999795..31551524f8 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -28,7 +28,6 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import load_jinja2_templates from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID, map_username_to_mxid_localpart @@ -548,13 +547,6 @@ class SSOAuthHandler(object): self._registration_handler = hs.get_registration_handler() self._macaroon_gen = hs.get_macaroon_generator() - # Load the redirect page HTML template - self._template = load_jinja2_templates( - hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], - )[0] - - self._server_name = hs.config.server_name - # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 50e080673b..85cf5a14c6 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -142,14 +142,6 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None elif stagetype == LoginType.TERMS: html = TERMS_TEMPLATE % { "session": session, @@ -158,17 +150,19 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None else: raise SynapseError(404, "Unknown auth stage type") + # Render the HTML and return. 
+ html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + async def on_POST(self, request, stagetype): session = parse_string(request, "session") @@ -196,15 +190,6 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - - return None elif stagetype == LoginType.TERMS: authdict = {"session": session} @@ -225,17 +210,19 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None else: raise SynapseError(404, "Unknown auth stage type") + # Render the HTML and return. + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + def on_OPTIONS(self, _): return 200, {} -- cgit 1.4.1 From 96071eea8f5e18282c07da3a61e4b3431f694cc5 Mon Sep 17 00:00:00 2001 From: Dionysis Grigoropoulos Date: Mon, 23 Mar 2020 11:48:28 +0200 Subject: Set Referrer-Policy to no-referrer for media (#7009) --- changelog.d/7009.feature | 1 + synapse/rest/media/v1/download_resource.py | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/7009.feature (limited to 'synapse') diff --git a/changelog.d/7009.feature b/changelog.d/7009.feature new file mode 100644 index 0000000000..cd2705d5ba --- /dev/null +++ b/changelog.d/7009.feature @@ -0,0 +1 @@ +Set `Referrer-Policy` header to `no-referrer` on media downloads. 
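The change below adds a `Referrer-Policy: no-referrer` header to media download responses, alongside the existing Content-Security-Policy header, presumably so that browsers rendering downloaded content do not leak the media URL via the `Referer` header. A quick way to confirm the header on a running homeserver; the hostname and media ID below are placeholders, and the check assumes the `requests` library is available:

```python
import requests

# Placeholder URL, not a real media item.
resp = requests.get(
    "https://matrix.example.com/_matrix/media/r0/download/example.com/somemediaid",
    allow_redirects=False,
)
print(resp.headers.get("Referrer-Policy"))  # expected: no-referrer
```
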
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 66a01559e1..24d3ae5bbc 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource): b" media-src 'self';" b" object-src 'self';", ) + request.setHeader( + b"Referrer-Policy", b"no-referrer", + ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: await self.media_repo.get_local_media(request, media_id, name) -- cgit 1.4.1 From b3cee0ce670ada582b2a4b36c377f160c7ee1d09 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 23 Mar 2020 11:39:36 +0000 Subject: Fix processing of `groups` stream, and use symbolic names for streams (#7117) `groups` != `receipts` Introduced in #6964 --- changelog.d/7117.bugfix | 1 + synapse/app/generic_worker.py | 35 ++++++++++----- synapse/replication/tcp/streams/__init__.py | 70 +++++++++++++++++++++-------- 3 files changed, 76 insertions(+), 30 deletions(-) create mode 100644 changelog.d/7117.bugfix (limited to 'synapse') diff --git a/changelog.d/7117.bugfix b/changelog.d/7117.bugfix new file mode 100644 index 0000000000..1896d7ad49 --- /dev/null +++ b/changelog.d/7117.bugfix @@ -0,0 +1 @@ +Fix a bug which meant that groups updates were not correctly replicated between workers. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index cdc078cf11..136babe6ce 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -65,12 +65,23 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams._base import ( +from synapse.replication.tcp.streams import ( + AccountDataStream, DeviceListsStream, + GroupServerStream, + PresenceStream, + PushersStream, + PushRulesStream, ReceiptsStream, + TagAccountDataStream, ToDeviceStream, + TypingStream, +) +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamEventRow, + EventsStreamRow, ) -from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client.v1 import events from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet @@ -626,7 +637,7 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): if self.send_handler: self.send_handler.process_replication_rows(stream_name, token, rows) - if stream_name == "events": + if stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. 
for row in rows: @@ -649,44 +660,44 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): ) await self.pusher_pool.on_new_notifications(token, token) - elif stream_name == "push_rules": + elif stream_name == PushRulesStream.NAME: self.notifier.on_new_event( "push_rules_key", token, users=[row.user_id for row in rows] ) - elif stream_name in ("account_data", "tag_account_data"): + elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME): self.notifier.on_new_event( "account_data_key", token, users=[row.user_id for row in rows] ) - elif stream_name == "receipts": + elif stream_name == ReceiptsStream.NAME: self.notifier.on_new_event( "receipt_key", token, rooms=[row.room_id for row in rows] ) await self.pusher_pool.on_new_receipts( token, token, {row.room_id for row in rows} ) - elif stream_name == "typing": + elif stream_name == TypingStream.NAME: self.typing_handler.process_replication_rows(token, rows) self.notifier.on_new_event( "typing_key", token, rooms=[row.room_id for row in rows] ) - elif stream_name == "to_device": + elif stream_name == ToDeviceStream.NAME: entities = [row.entity for row in rows if row.entity.startswith("@")] if entities: self.notifier.on_new_event("to_device_key", token, users=entities) - elif stream_name == "device_lists": + elif stream_name == DeviceListsStream.NAME: all_room_ids = set() for row in rows: if row.entity.startswith("@"): room_ids = await self.store.get_rooms_for_user(row.entity) all_room_ids.update(room_ids) self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) - elif stream_name == "presence": + elif stream_name == PresenceStream.NAME: await self.presence_handler.process_replication_rows(token, rows) - elif stream_name == "receipts": + elif stream_name == GroupServerStream.NAME: self.notifier.on_new_event( "groups_key", token, users=[row.user_id for row in rows] ) - elif stream_name == "pushers": + elif stream_name == PushersStream.NAME: for row in rows: if row.deleted: self.stop_pusher(row.user_id, row.app_id, row.pushkey) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 5f52264e84..29199f5b46 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -24,27 +24,61 @@ Each stream is defined by the following information: current_token: The function that returns the current token for the stream update_function: The function that returns a list of updates between two tokens """ - -from . 
import _base, events, federation +from synapse.replication.tcp.streams._base import ( + AccountDataStream, + BackfillStream, + CachesStream, + DeviceListsStream, + GroupServerStream, + PresenceStream, + PublicRoomsStream, + PushersStream, + PushRulesStream, + ReceiptsStream, + TagAccountDataStream, + ToDeviceStream, + TypingStream, + UserSignatureStream, +) +from synapse.replication.tcp.streams.events import EventsStream +from synapse.replication.tcp.streams.federation import FederationStream STREAMS_MAP = { stream.NAME: stream for stream in ( - events.EventsStream, - _base.BackfillStream, - _base.PresenceStream, - _base.TypingStream, - _base.ReceiptsStream, - _base.PushRulesStream, - _base.PushersStream, - _base.CachesStream, - _base.PublicRoomsStream, - _base.DeviceListsStream, - _base.ToDeviceStream, - federation.FederationStream, - _base.TagAccountDataStream, - _base.AccountDataStream, - _base.GroupServerStream, - _base.UserSignatureStream, + EventsStream, + BackfillStream, + PresenceStream, + TypingStream, + ReceiptsStream, + PushRulesStream, + PushersStream, + CachesStream, + PublicRoomsStream, + DeviceListsStream, + ToDeviceStream, + FederationStream, + TagAccountDataStream, + AccountDataStream, + GroupServerStream, + UserSignatureStream, ) } + +__all__ = [ + "STREAMS_MAP", + "BackfillStream", + "PresenceStream", + "TypingStream", + "ReceiptsStream", + "PushRulesStream", + "PushersStream", + "CachesStream", + "PublicRoomsStream", + "DeviceListsStream", + "ToDeviceStream", + "TagAccountDataStream", + "AccountDataStream", + "GroupServerStream", + "UserSignatureStream", +] -- cgit 1.4.1 From a564b92d37625855940fe599c730a9958c33f973 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 23 Mar 2020 13:59:11 +0000 Subject: Convert `*StreamRow` classes to inner classes (#7116) This just helps keep the rows closer to their streams, so that it's easier to see what the format of each stream is. --- changelog.d/7116.misc | 1 + synapse/app/generic_worker.py | 2 +- synapse/federation/send_queue.py | 2 +- synapse/replication/tcp/streams/_base.py | 181 +++++++++++++------------ synapse/replication/tcp/streams/federation.py | 16 +-- tests/replication/tcp/streams/test_receipts.py | 4 +- 6 files changed, 106 insertions(+), 100 deletions(-) create mode 100644 changelog.d/7116.misc (limited to 'synapse') diff --git a/changelog.d/7116.misc b/changelog.d/7116.misc new file mode 100644 index 0000000000..89d90bd49e --- /dev/null +++ b/changelog.d/7116.misc @@ -0,0 +1 @@ +Convert `*StreamRow` classes to inner classes. 
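The diffs below move each `*StreamRow` namedtuple from module level into the stream class that owns it, so a stream's wire format is declared next to its name and update function. A trimmed illustration of the before/after shape using the receipts stream (the real classes also have `__init__` and the `Stream` base class, omitted here):

```python
from collections import namedtuple

# Before: the row type lived at module level, away from its stream.
ReceiptsStreamRow = namedtuple(
    "ReceiptsStreamRow", ("room_id", "receipt_type", "user_id", "event_id", "data")
)


# After: the row type is an inner class attribute, referenced by callers as
# ReceiptsStream.ReceiptsStreamRow.
class ReceiptsStream:
    ReceiptsStreamRow = namedtuple(
        "ReceiptsStreamRow",
        ("room_id", "receipt_type", "user_id", "event_id", "data"),
    )

    NAME = "receipts"
    ROW_TYPE = ReceiptsStreamRow
```
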
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 136babe6ce..c8fd8909a4 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -804,7 +804,7 @@ class FederationSenderHandler(object): async def _on_new_receipts(self, rows): """ Args: - rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]): + rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]): new receipts to be processed """ for receipt in rows: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 876fb0e245..e1700ca8aa 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -477,7 +477,7 @@ def process_rows_for_federation(transaction_queue, rows): Args: transaction_queue (FederationSender) - rows (list(synapse.replication.tcp.streams.FederationStreamRow)) + rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow)) """ # The federation stream contains a bunch of different types of diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index abf5c6c6a8..32d9514883 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -28,94 +28,6 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 500000 -BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), -) -PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), -) -TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) -) -ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), -) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str -PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool -) - - -@attr.s -class CachesStreamRow: - """Stream to inform workers they should invalidate their cache. - - Attributes: - cache_func: Name of the cached function. - keys: The entry in the cache to invalidate. If None then will - invalidate all. - invalidation_ts: Timestamp of when the invalidation took place. - """ - - cache_func = attr.ib(type=str) - keys = attr.ib(type=Optional[List[Any]]) - invalidation_ts = attr.ib(type=int) - - -PublicRoomsStreamRow = namedtuple( - "PublicRoomsStreamRow", - ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional - ), -) - - -@attr.s -class DeviceListsStreamRow: - entity = attr.ib(type=str) - - -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str -TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict -) -AccountDataStreamRow = namedtuple( - "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str -) -GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict -) -UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str - class Stream(object): """Base class for the streams. 
@@ -234,6 +146,18 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ + BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), + ) + NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -246,6 +170,19 @@ class BackfillStream(Stream): class PresenceStream(Stream): + PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), + ) + NAME = "presence" ROW_TYPE = PresenceStreamRow @@ -260,6 +197,10 @@ class PresenceStream(Stream): class TypingStream(Stream): + TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) + ) + NAME = "typing" ROW_TYPE = TypingStreamRow @@ -273,6 +214,17 @@ class TypingStream(Stream): class ReceiptsStream(Stream): + ReceiptsStreamRow = namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), + ) + NAME = "receipts" ROW_TYPE = ReceiptsStreamRow @@ -289,6 +241,8 @@ class PushRulesStream(Stream): """A user has changed their push rules """ + PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -309,6 +263,11 @@ class PushersStream(Stream): """A user has added/changed/removed a pusher """ + PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool + ) + NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -326,6 +285,21 @@ class CachesStream(Stream): the cache on the workers """ + @attr.s + class CachesStreamRow: + """Stream to inform workers they should invalidate their cache. + + Attributes: + cache_func: Name of the cached function. + keys: The entry in the cache to invalidate. If None then will + invalidate all. + invalidation_ts: Timestamp of when the invalidation took place. + """ + + cache_func = attr.ib(type=str) + keys = attr.ib(type=Optional[List[Any]]) + invalidation_ts = attr.ib(type=int) + NAME = "caches" ROW_TYPE = CachesStreamRow @@ -342,6 +316,16 @@ class PublicRoomsStream(Stream): """The public rooms list changed """ + PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), + ) + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow @@ -359,6 +343,10 @@ class DeviceListsStream(Stream): told about a device update. 
""" + @attr.s + class DeviceListsStreamRow: + entity = attr.ib(type=str) + NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow @@ -375,6 +363,8 @@ class ToDeviceStream(Stream): """New to_device messages for a client """ + ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -391,6 +381,10 @@ class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict + ) + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -407,6 +401,10 @@ class AccountDataStream(Stream): """Global or per room account data was changed """ + AccountDataStreamRow = namedtuple( + "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str + ) + NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -432,6 +430,11 @@ class AccountDataStream(Stream): class GroupServerStream(Stream): + GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict + ) + NAME = "groups" ROW_TYPE = GroupsStreamRow @@ -448,6 +451,8 @@ class UserSignatureStream(Stream): """A user has signed their own device with their user-signing key """ + UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 615f3dc9ac..f5f9336430 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -17,20 +17,20 @@ from collections import namedtuple from ._base import Stream -FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), -) - class FederationStream(Stream): """Data to be sent over federation. Only available when master has federation sending disabled. """ + FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), + ) + NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index d5a99f6caa..fa2493cad6 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.replication.tcp.streams._base import ReceiptsStreamRow +from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -38,7 +38,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): rdata_rows = self.test_handler.received_rdata_rows self.assertEqual(1, len(rdata_rows)) self.assertEqual(rdata_rows[0][0], "receipts") - row = rdata_rows[0][2] # type: ReceiptsStreamRow + row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow self.assertEqual(ROOM_ID, row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) -- cgit 1.4.1 From 190ab593b7a2c0d79569758c0faa4d2442bc2c5f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Mar 2020 15:21:54 -0400 Subject: Use the proper error code when a canonical alias that does not exist is used. (#7109) --- changelog.d/7109.bugfix | 1 + synapse/handlers/message.py | 57 ++++++++++++++++++++++++++++++--------------- 2 files changed, 39 insertions(+), 19 deletions(-) create mode 100644 changelog.d/7109.bugfix (limited to 'synapse') diff --git a/changelog.d/7109.bugfix b/changelog.d/7109.bugfix new file mode 100644 index 0000000000..268de9978e --- /dev/null +++ b/changelog.d/7109.bugfix @@ -0,0 +1 @@ +Return the proper error (M_BAD_ALIAS) when a non-existant canonical alias is provided. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b743fc2dcc..522271eed1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -851,6 +851,38 @@ class EventCreationHandler(object): self.store.remove_push_actions_from_staging, event.event_id ) + @defer.inlineCallbacks + def _validate_canonical_alias( + self, directory_handler, room_alias_str, expected_room_id + ): + """ + Ensure that the given room alias points to the expected room ID. + + Args: + directory_handler: The directory handler object. + room_alias_str: The room alias to check. + expected_room_id: The room ID that the alias should point to. + """ + room_alias = RoomAlias.from_string(room_alias_str) + try: + mapping = yield directory_handler.get_association(room_alias) + except SynapseError as e: + # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. + if e.errcode == Codes.NOT_FOUND: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, + ) + raise + + if mapping["room_id"] != expected_room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, + ) + @defer.inlineCallbacks def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] @@ -905,15 +937,9 @@ class EventCreationHandler(object): room_alias_str = event.content.get("alias", None) directory_handler = self.hs.get_handlers().directory_handler if room_alias_str and room_alias_str != original_alias: - room_alias = RoomAlias.from_string(room_alias_str) - mapping = yield directory_handler.get_association(room_alias) - - if mapping["room_id"] != event.room_id: - raise SynapseError( - 400, - "Room alias %s does not point to the room" % (room_alias_str,), - Codes.BAD_ALIAS, - ) + yield self._validate_canonical_alias( + directory_handler, room_alias_str, event.room_id + ) # Check that alt_aliases is the proper form. 
alt_aliases = event.content.get("alt_aliases", []) @@ -931,16 +957,9 @@ class EventCreationHandler(object): new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) if new_alt_aliases: for alias_str in new_alt_aliases: - room_alias = RoomAlias.from_string(alias_str) - mapping = yield directory_handler.get_association(room_alias) - - if mapping["room_id"] != event.room_id: - raise SynapseError( - 400, - "Room alias %s does not point to the room" - % (room_alias_str,), - Codes.BAD_ALIAS, - ) + yield self._validate_canonical_alias( + directory_handler, alias_str, event.room_id + ) federation_handler = self.hs.get_handlers().federation_handler -- cgit 1.4.1 From c816072d47c16d2840116418698d95a855f0f24c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 Mar 2020 10:35:00 +0000 Subject: Fix starting workers when federation sending not split out. --- synapse/app/generic_worker.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'synapse') diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index b2c764bfe8..5363642d64 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -860,6 +860,9 @@ def start(config_options): # Force the appservice to start since they will be disabled in the main config config.notify_appservices = True + else: + # For other worker types we force this to off. + config.notify_appservices = False if config.worker_app == "synapse.app.pusher": if config.start_pushers: @@ -873,6 +876,9 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.start_pushers = True + else: + # For other worker types we force this to off. + config.start_pushers = False if config.worker_app == "synapse.app.user_dir": if config.update_user_directory: @@ -886,6 +892,9 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.update_user_directory = True + else: + # For other worker types we force this to off. + config.update_user_directory = False if config.worker_app == "synapse.app.federation_sender": if config.send_federation: @@ -899,6 +908,9 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.send_federation = True + else: + # For other worker types we force this to off. + config.send_federation = False synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts -- cgit 1.4.1 From 1fcf9c6f95fcfcacab95bb78849d79b8c7fa22e9 Mon Sep 17 00:00:00 2001 From: Naugrimm Date: Tue, 24 Mar 2020 12:59:04 +0100 Subject: Fix CAS redirect url (#6634) Build the same service URL when requesting the CAS ticket and when calling the proxyValidate URL. --- changelog.d/6634.bugfix | 1 + synapse/rest/client/v1/login.py | 27 ++++++++++++++++----------- 2 files changed, 17 insertions(+), 11 deletions(-) create mode 100644 changelog.d/6634.bugfix (limited to 'synapse') diff --git a/changelog.d/6634.bugfix b/changelog.d/6634.bugfix new file mode 100644 index 0000000000..ec48fdc0a0 --- /dev/null +++ b/changelog.d/6634.bugfix @@ -0,0 +1 @@ +Fix single-sign on with CAS systems: pass the same service URL when requesting the CAS ticket and when calling the `proxyValidate` URL. Contributed by @Naugrimm. 
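The fix below introduces a single helper so that the CAS redirect servlet and the ticket servlet compute the `service` parameter identically; CAS servers generally refuse to validate a ticket if the `service` passed to `proxyValidate` differs from the one used when the ticket was issued. A usage sketch, with placeholder URLs that are not taken from any real deployment:

```python
import urllib.parse


def build_service_param(cas_service_url, client_redirect_url):
    # Mirrors the helper added in the diff below.
    return "%s%s?redirectUrl=%s" % (
        cas_service_url,
        "/_matrix/client/r0/login/cas/ticket",
        urllib.parse.quote(client_redirect_url, safe=""),
    )


service = build_service_param(
    "https://synapse.example.com", "https://client.example.com/?loggedin=1"
)
# The same string must be used both when redirecting the user to the CAS
# /login page and later when calling /proxyValidate with the ticket.
print(service)
```
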
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 31551524f8..56d713462a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -72,6 +72,14 @@ def login_id_thirdparty_from_phone(identifier): return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} +def build_service_param(cas_service_url, client_redirect_url): + return "%s%s?redirectUrl=%s" % ( + cas_service_url, + "/_matrix/client/r0/login/cas/ticket", + urllib.parse.quote(client_redirect_url, safe=""), + ) + + class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -427,18 +435,15 @@ class BaseSSORedirectServlet(RestServlet): class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url.encode("ascii") - self.cas_service_url = hs.config.cas_service_url.encode("ascii") + self.cas_server_url = hs.config.cas_server_url + self.cas_service_url = hs.config.cas_service_url def get_sso_url(self, client_redirect_url): - client_redirect_url_param = urllib.parse.urlencode( - {b"redirectUrl": client_redirect_url} - ).encode("ascii") - hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" - service_param = urllib.parse.urlencode( - {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} - ).encode("ascii") - return b"%s/login?%s" % (self.cas_server_url, service_param) + args = urllib.parse.urlencode( + {"service": build_service_param(self.cas_service_url, client_redirect_url)} + ) + + return "%s/login?%s" % (self.cas_server_url, args) class CasTicketServlet(RestServlet): @@ -458,7 +463,7 @@ class CasTicketServlet(RestServlet): uri = self.cas_server_url + "/proxyValidate" args = { "ticket": parse_string(request, "ticket", required=True), - "service": self.cas_service_url, + "service": build_service_param(self.cas_service_url, client_redirect_url), } try: body = await self._http_client.get_raw(uri, args) -- cgit 1.4.1 From 39230d217104f3cd7aba9065dc478f935ce1e614 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 24 Mar 2020 14:45:33 +0000 Subject: Clean up some LoggingContext stuff (#7120) * Pull Sentinel out of LoggingContext ... and drop a few unnecessary references to it * Factor out LoggingContext.current_context move `current_context` and `set_context` out to top-level functions. Mostly this means that I can more easily trace what's actually referring to LoggingContext, but I think it's generally neater. * move copy-to-parent into `stop` this really just makes `start` and `stop` more symetric. It also means that it behaves correctly if you manually `set_log_context` rather than using the context manager. * Replace `LoggingContext.alive` with `finished` Turn `alive` into `finished` and make it a bit better defined. 
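A sketch of the call-site convention once `current_context` and `set_current_context` are module-level functions rather than `LoggingContext` class methods (assumes a post-commit synapse; it mirrors the updated `docs/log_contexts.md` example in the diff below):

```python
from synapse.logging.context import LoggingContext, set_current_context


def run_with_context(request_id):
    request_context = LoggingContext()
    # set_current_context returns whatever context was previously active,
    # so it can be restored afterwards.
    calling_context = set_current_context(request_context)
    try:
        request_context.request = request_id
        ...  # work done here is logged against request_context
    finally:
        set_current_context(calling_context)
```
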
--- changelog.d/7120.misc | 1 + docs/log_contexts.md | 5 +- synapse/crypto/keyring.py | 4 +- synapse/federation/federation_base.py | 4 +- synapse/handlers/sync.py | 4 +- synapse/http/request_metrics.py | 6 +- synapse/logging/_structured.py | 4 +- synapse/logging/context.py | 234 +++++++++++---------- synapse/logging/scopecontextmanager.py | 13 +- synapse/storage/data_stores/main/events_worker.py | 4 +- synapse/storage/database.py | 11 +- synapse/util/metrics.py | 4 +- synapse/util/patch_inline_callbacks.py | 36 ++-- tests/crypto/test_keyring.py | 7 +- .../federation/test_matrix_federation_agent.py | 6 +- tests/http/federation/test_srv_resolver.py | 6 +- tests/http/test_fedclient.py | 6 +- tests/rest/client/test_transactions.py | 16 +- tests/unittest.py | 12 +- tests/util/caches/test_descriptors.py | 22 +- tests/util/test_async_utils.py | 15 +- tests/util/test_linearizer.py | 6 +- tests/util/test_logcontext.py | 22 +- tests/utils.py | 6 +- 24 files changed, 232 insertions(+), 222 deletions(-) create mode 100644 changelog.d/7120.misc (limited to 'synapse') diff --git a/changelog.d/7120.misc b/changelog.d/7120.misc new file mode 100644 index 0000000000..731f4dcb52 --- /dev/null +++ b/changelog.d/7120.misc @@ -0,0 +1 @@ +Clean up some LoggingContext code. diff --git a/docs/log_contexts.md b/docs/log_contexts.md index 5331e8c88b..fe30ca2791 100644 --- a/docs/log_contexts.md +++ b/docs/log_contexts.md @@ -29,14 +29,13 @@ from synapse.logging import context # omitted from future snippets def handle_request(request_id): request_context = context.LoggingContext() - calling_context = context.LoggingContext.current_context() - context.LoggingContext.set_current_context(request_context) + calling_context = context.set_current_context(request_context) try: request_context.request = request_id do_request_handling() logger.debug("finished") finally: - context.LoggingContext.set_current_context(calling_context) + context.set_current_context(calling_context) def do_request_handling(): logger.debug("phew") # this will be logged against request_id diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 983f0ead8c..a9f4025bfe 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -43,8 +43,8 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.context import ( - LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, preserve_fn, run_in_background, @@ -236,7 +236,7 @@ class Keyring(object): """ try: - ctx = LoggingContext.current_context() + ctx = current_context() # map from server name to a set of outstanding request ids server_to_request_ids = {} diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index b0b0eba41e..4b115aac04 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -32,8 +32,8 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( - LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.types import JsonDict, get_domain_from_id @@ -78,7 +78,7 @@ class FederationBase(object): """ deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) - ctx = LoggingContext.current_context() + ctx = current_context() def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py 
index 669dbc8a48..5746fdea14 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -26,7 +26,7 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.events import EventBase -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter @@ -301,7 +301,7 @@ class SyncHandler(object): else: sync_type = "incremental_sync" - context = LoggingContext.current_context() + context = current_context() if context: context.tag = sync_type diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 58f9cc61c8..b58ae3d9db 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -19,7 +19,7 @@ import threading from prometheus_client.core import Counter, Histogram -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context from synapse.metrics import LaterGauge logger = logging.getLogger(__name__) @@ -148,7 +148,7 @@ LaterGauge( class RequestMetrics(object): def start(self, time_sec, name, method): self.start = time_sec - self.start_context = LoggingContext.current_context() + self.start_context = current_context() self.name = name self.method = method @@ -163,7 +163,7 @@ class RequestMetrics(object): with _in_flight_requests_lock: _in_flight_requests.discard(self) - context = LoggingContext.current_context() + context = current_context() tag = "" if context: diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index ffa7b20ca8..7372450b45 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -42,7 +42,7 @@ from synapse.logging._terse_json import ( TerseJSONToConsoleLogObserver, TerseJSONToTCPLogObserver, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context def stdlib_log_level_to_twisted(level: str) -> LogLevel: @@ -86,7 +86,7 @@ class LogContextObserver(object): ].startswith("Timing out client"): return - context = LoggingContext.current_context() + context = current_context() # Copy the context information to the log event. 
if context is not None: diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 860b99a4c6..a8eafb1c7c 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -175,7 +175,54 @@ class ContextResourceUsage(object): return res -LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"] +LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] + + +class _Sentinel(object): + """Sentinel to represent the root context""" + + __slots__ = ["previous_context", "finished", "request", "scope", "tag"] + + def __init__(self) -> None: + # Minimal set for compatibility with LoggingContext + self.previous_context = None + self.finished = False + self.request = None + self.scope = None + self.tag = None + + def __str__(self): + return "sentinel" + + def copy_to(self, record): + pass + + def copy_to_twisted_log_entry(self, record): + record["request"] = None + record["scope"] = None + + def start(self): + pass + + def stop(self): + pass + + def add_database_transaction(self, duration_sec): + pass + + def add_database_scheduled(self, sched_sec): + pass + + def record_event_fetch(self, event_count): + pass + + def __nonzero__(self): + return False + + __bool__ = __nonzero__ # python3 + + +SENTINEL_CONTEXT = _Sentinel() class LoggingContext(object): @@ -199,76 +246,33 @@ class LoggingContext(object): "_resource_usage", "usage_start", "main_thread", - "alive", + "finished", "request", "tag", "scope", ] - thread_local = threading.local() - - class Sentinel(object): - """Sentinel to represent the root context""" - - __slots__ = ["previous_context", "alive", "request", "scope", "tag"] - - def __init__(self) -> None: - # Minimal set for compatibility with LoggingContext - self.previous_context = None - self.alive = None - self.request = None - self.scope = None - self.tag = None - - def __str__(self): - return "sentinel" - - def copy_to(self, record): - pass - - def copy_to_twisted_log_entry(self, record): - record["request"] = None - record["scope"] = None - - def start(self): - pass - - def stop(self): - pass - - def add_database_transaction(self, duration_sec): - pass - - def add_database_scheduled(self, sched_sec): - pass - - def record_event_fetch(self, event_count): - pass - - def __nonzero__(self): - return False - - __bool__ = __nonzero__ # python3 - - sentinel = Sentinel() - def __init__(self, name=None, parent_context=None, request=None) -> None: - self.previous_context = LoggingContext.current_context() + self.previous_context = current_context() self.name = name # track the resources used by this context so far self._resource_usage = ContextResourceUsage() - # If alive has the thread resource usage when the logcontext last - # became active. + # The thread resource usage when the logcontext became active. None + # if the context is not currently active. self.usage_start = None self.main_thread = get_thread_id() self.request = None self.tag = "" - self.alive = True self.scope = None # type: Optional[_LogContextScope] + # keep track of whether we have hit the __exit__ block for this context + # (suggesting that the the thing that created the context thinks it should + # be finished, and that re-activating it would suggest an error). 
+ self.finished = False + self.parent_context = parent_context if self.parent_context is not None: @@ -283,44 +287,15 @@ class LoggingContext(object): return str(self.request) return "%s@%x" % (self.name, id(self)) - @classmethod - def current_context(cls) -> LoggingContextOrSentinel: - """Get the current logging context from thread local storage - - Returns: - LoggingContext: the current logging context - """ - return getattr(cls.thread_local, "current_context", cls.sentinel) - - @classmethod - def set_current_context( - cls, context: LoggingContextOrSentinel - ) -> LoggingContextOrSentinel: - """Set the current logging context in thread local storage - Args: - context(LoggingContext): The context to activate. - Returns: - The context that was previously active - """ - current = cls.current_context() - - if current is not context: - current.stop() - cls.thread_local.current_context = context - context.start() - return current - def __enter__(self) -> "LoggingContext": """Enters this logging context into thread local storage""" - old_context = self.set_current_context(self) + old_context = set_current_context(self) if self.previous_context != old_context: logger.warning( "Expected previous context %r, found %r", self.previous_context, old_context, ) - self.alive = True - return self def __exit__(self, type, value, traceback) -> None: @@ -329,24 +304,19 @@ class LoggingContext(object): Returns: None to avoid suppressing any exceptions that were thrown. """ - current = self.set_current_context(self.previous_context) + current = set_current_context(self.previous_context) if current is not self: - if current is self.sentinel: + if current is SENTINEL_CONTEXT: logger.warning("Expected logging context %s was lost", self) else: logger.warning( "Expected logging context %s but found %s", self, current ) - self.alive = False - - # if we have a parent, pass our CPU usage stats on - if self.parent_context is not None and hasattr( - self.parent_context, "_resource_usage" - ): - self.parent_context._resource_usage += self._resource_usage - # reset them in case we get entered again - self._resource_usage.reset() + # the fact that we are here suggests that the caller thinks that everything + # is done and dusted for this logcontext, and further activity will not get + # recorded against the correct metrics. + self.finished = True def copy_to(self, record) -> None: """Copy logging fields from this context to a log record or @@ -371,9 +341,14 @@ class LoggingContext(object): logger.warning("Started logcontext %s on different thread", self) return + if self.finished: + logger.warning("Re-starting finished log context %s", self) + # If we haven't already started record the thread resource usage so # far - if not self.usage_start: + if self.usage_start: + logger.warning("Re-starting already-active log context %s", self) + else: self.usage_start = get_thread_resource_usage() def stop(self) -> None: @@ -396,6 +371,15 @@ class LoggingContext(object): self.usage_start = None + # if we have a parent, pass our CPU usage stats on + if self.parent_context is not None and hasattr( + self.parent_context, "_resource_usage" + ): + self.parent_context._resource_usage += self._resource_usage + + # reset them in case we get entered again + self._resource_usage.reset() + def get_resource_usage(self) -> ContextResourceUsage: """Get resources used by this logcontext so far. 
@@ -409,7 +393,7 @@ class LoggingContext(object): # If we are on the correct thread and we're currently running then we # can include resource usage so far. is_main_thread = get_thread_id() == self.main_thread - if self.alive and self.usage_start and is_main_thread: + if self.usage_start and is_main_thread: utime_delta, stime_delta = self._get_cputime() res.ru_utime += utime_delta res.ru_stime += stime_delta @@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter): Returns: True to include the record in the log output. """ - context = LoggingContext.current_context() + context = current_context() for key, value in self.defaults.items(): setattr(record, key, value) @@ -512,27 +496,24 @@ class PreserveLoggingContext(object): __slots__ = ["current_context", "new_context", "has_parent"] - def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None: - if new_context is None: - self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel - else: - self.new_context = new_context + def __init__( + self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT + ) -> None: + self.new_context = new_context def __enter__(self) -> None: """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context(self.new_context) + self.current_context = set_current_context(self.new_context) if self.current_context: self.has_parent = self.current_context.previous_context is not None - if not self.current_context.alive: - logger.debug("Entering dead context: %s", self.current_context) def __exit__(self, type, value, traceback) -> None: """Restores the current logging context""" - context = LoggingContext.set_current_context(self.current_context) + context = set_current_context(self.current_context) if context != self.new_context: - if context is LoggingContext.sentinel: + if not context: logger.warning("Expected logging context %s was lost", self.new_context) else: logger.warning( @@ -541,9 +522,30 @@ class PreserveLoggingContext(object): context, ) - if self.current_context is not LoggingContext.sentinel: - if not self.current_context.alive: - logger.debug("Restoring dead context: %s", self.current_context) + +_thread_local = threading.local() +_thread_local.current_context = SENTINEL_CONTEXT + + +def current_context() -> LoggingContextOrSentinel: + """Get the current logging context from thread local storage""" + return getattr(_thread_local, "current_context", SENTINEL_CONTEXT) + + +def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: + """Set the current logging context in thread local storage + Args: + context(LoggingContext): The context to activate. + Returns: + The context that was previously active + """ + current = current_context() + + if current is not context: + current.stop() + _thread_local.current_context = context + context.start() + return current def nested_logging_context( @@ -572,7 +574,7 @@ def nested_logging_context( if parent_context is not None: context = parent_context # type: LoggingContextOrSentinel else: - context = LoggingContext.current_context() + context = current_context() return LoggingContext( parent_context=context, request=str(context.request) + "-" + suffix ) @@ -604,7 +606,7 @@ def run_in_background(f, *args, **kwargs): CRITICAL error about an unhandled error will be logged without much indication about where it came from. 
""" - current = LoggingContext.current_context() + current = current_context() try: res = f(*args, **kwargs) except: # noqa: E722 @@ -625,7 +627,7 @@ def run_in_background(f, *args, **kwargs): # The function may have reset the context before returning, so # we need to restore it now. - ctx = LoggingContext.set_current_context(current) + ctx = set_current_context(current) # The original context will be restored when the deferred # completes, but there is nothing waiting for it, so it will @@ -674,7 +676,7 @@ def make_deferred_yieldable(deferred): # ok, we can't be sure that a yield won't block, so let's reset the # logcontext, and add a callback to the deferred to restore it. - prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + prev_context = set_current_context(SENTINEL_CONTEXT) deferred.addBoth(_set_context_cb, prev_context) return deferred @@ -684,7 +686,7 @@ ResultT = TypeVar("ResultT") def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: """A callback function which just sets the logging context""" - LoggingContext.set_current_context(context) + set_current_context(context) return result @@ -752,7 +754,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): Deferred: A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ - logcontext = LoggingContext.current_context() + logcontext = current_context() def g(): with LoggingContext(parent_context=logcontext): diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index 4eed4f2338..dc3ab00cbb 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager import twisted -from synapse.logging.context import LoggingContext, nested_logging_context +from synapse.logging.context import current_context, nested_logging_context logger = logging.getLogger(__name__) @@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager): (Scope) : the Scope that is active, or None if not available. """ - ctx = LoggingContext.current_context() - if ctx is LoggingContext.sentinel: - return None - else: - return ctx.scope + ctx = current_context() + return ctx.scope def activate(self, span, finish_on_close): """ @@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager): """ enter_logcontext = False - ctx = LoggingContext.current_context() + ctx = current_context() - if ctx is LoggingContext.sentinel: + if not ctx: # We don't want this scope to affect. 
logger.error("Tried to activate scope outside of loggingcontext") return Scope(None, span) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ca237c6f12..3013f49d32 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -35,7 +35,7 @@ from synapse.api.room_versions import ( ) from synapse.events import make_event_from_dict from synapse.events.utils import prune_event -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database @@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore): missing_events_ids = [e for e in event_ids if e not in event_entry_map] if missing_events_ids: - log_ctx = LoggingContext.current_context() + log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) # Note that _get_events_from_db is also responsible for turning db rows diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e61595336c..715c0346dd 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import ( LoggingContext, LoggingContextOrSentinel, + current_context, make_deferred_yieldable, ) from synapse.metrics.background_process_metrics import run_as_background_process @@ -483,7 +484,7 @@ class Database(object): end = monotonic_time() duration = end - start - LoggingContext.current_context().add_database_transaction(duration) + current_context().add_database_transaction(duration) transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) @@ -510,7 +511,7 @@ class Database(object): after_callbacks = [] # type: List[_CallbackListEntry] exception_callbacks = [] # type: List[_CallbackListEntry] - if LoggingContext.current_context() == LoggingContext.sentinel: + if not current_context(): logger.warning("Starting db txn '%s' from sentinel context", desc) try: @@ -547,10 +548,8 @@ class Database(object): Returns: Deferred: The result of func """ - parent_context = ( - LoggingContext.current_context() - ) # type: Optional[LoggingContextOrSentinel] - if parent_context == LoggingContext.sentinel: + parent_context = current_context() # type: Optional[LoggingContextOrSentinel] + if not parent_context: logger.warning( "Starting db connection from sentinel context: metrics will be lost" ) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 7b18455469..ec61e14423 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -21,7 +21,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.metrics import InFlightGauge logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ class Measure(object): raise RuntimeError("Measure() objects cannot be re-used") self.start = self.clock.time() - parent_context = LoggingContext.current_context() + parent_context = current_context() self._logging_context = LoggingContext( "Measure[%s]" % (self.name,), parent_context ) diff --git a/synapse/util/patch_inline_callbacks.py 
b/synapse/util/patch_inline_callbacks.py index 3925927f9f..fdff195771 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -32,7 +32,7 @@ def do_patch(): Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit """ - from synapse.logging.context import LoggingContext + from synapse.logging.context import current_context global _already_patched @@ -43,35 +43,35 @@ def do_patch(): def new_inline_callbacks(f): @functools.wraps(f) def wrapped(*args, **kwargs): - start_context = LoggingContext.current_context() + start_context = current_context() changes = [] # type: List[str] orig = orig_inline_callbacks(_check_yield_points(f, changes)) try: res = orig(*args, **kwargs) except Exception: - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "%s changed context from %s to %s on exception" % ( f, start_context, - LoggingContext.current_context(), + current_context(), ) print(err, file=sys.stderr) raise Exception(err) raise if not isinstance(res, Deferred) or res.called: - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "Completed %s changed context from %s to %s" % ( f, start_context, - LoggingContext.current_context(), + current_context(), ) # print the error to stderr because otherwise all we # see in travis-ci is the 500 error @@ -79,23 +79,23 @@ def do_patch(): raise Exception(err) return res - if LoggingContext.current_context() != LoggingContext.sentinel: + if current_context(): err = ( "%s returned incomplete deferred in non-sentinel context " "%s (start was %s)" - ) % (f, LoggingContext.current_context(), start_context) + ) % (f, current_context(), start_context) print(err, file=sys.stderr) raise Exception(err) def check_ctx(r): - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "%s completion of %s changed context from %s to %s" % ( "Failure" if isinstance(r, Failure) else "Success", f, start_context, - LoggingContext.current_context(), + current_context(), ) print(err, file=sys.stderr) raise Exception(err) @@ -127,7 +127,7 @@ def _check_yield_points(f: Callable, changes: List[str]): function """ - from synapse.logging.context import LoggingContext + from synapse.logging.context import current_context @functools.wraps(f) def check_yield_points_inner(*args, **kwargs): @@ -136,7 +136,7 @@ def _check_yield_points(f: Callable, changes: List[str]): last_yield_line_no = gen.gi_frame.f_lineno result = None # type: Any while True: - expected_context = LoggingContext.current_context() + expected_context = current_context() try: isFailure = isinstance(result, Failure) @@ -145,7 +145,7 @@ def _check_yield_points(f: Callable, changes: List[str]): else: d = gen.send(result) except (StopIteration, defer._DefGen_Return) as e: - if LoggingContext.current_context() != expected_context: + if current_context() != expected_context: # This happens when the context is lost sometime *after* the # final yield and returning. E.g. we forgot to yield on a # function that returns a deferred. 
@@ -159,7 +159,7 @@ def _check_yield_points(f: Callable, changes: List[str]): % ( f.__qualname__, expected_context, - LoggingContext.current_context(), + current_context(), f.__code__.co_filename, last_yield_line_no, ) @@ -173,13 +173,13 @@ def _check_yield_points(f: Callable, changes: List[str]): # This happens if we yield on a deferred that doesn't follow # the log context rules without wrapping in a `make_deferred_yieldable`. # We raise here as this should never happen. - if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): err = ( "%s yielded with context %s rather than sentinel," " yielded on line %d in %s" % ( frame.f_code.co_name, - LoggingContext.current_context(), + current_context(), frame.f_lineno, frame.f_code.co_filename, ) @@ -191,7 +191,7 @@ def _check_yield_points(f: Callable, changes: List[str]): except Exception as e: result = Failure(e) - if LoggingContext.current_context() != expected_context: + if current_context() != expected_context: # This happens because the context is lost sometime *after* the # previous yield and *after* the current yield. E.g. the @@ -206,7 +206,7 @@ def _check_yield_points(f: Callable, changes: List[str]): % ( frame.f_code.co_name, expected_context, - LoggingContext.current_context(), + current_context(), last_yield_line_no, frame.f_lineno, frame.f_code.co_filename, diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 34d5895f18..70c8e72303 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -34,6 +34,7 @@ from synapse.crypto.keyring import ( from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.storage.keys import FetchKeyResult @@ -83,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) def check_context(self, _, expected): - self.assertEquals( - getattr(LoggingContext.current_context(), "request", None), expected - ) + self.assertEquals(getattr(current_context(), "request", None), expected) def test_verify_json_objects_for_server_awaits_previous_requests(self): key1 = signedjson.key.generate_signing_key(1) @@ -105,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): - self.assertEquals(LoggingContext.current_context().request, "11") + self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): yield persp_deferred return persp_resp diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index fdc1d918ff..562397cdda 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -38,7 +38,7 @@ from synapse.http.federation.well_known_resolver import ( WellKnownResolver, _cache_period_from_headers, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.util.caches.ttlcache import TTLCache from tests import unittest @@ -155,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - _check_logcontext(LoggingContext.sentinel) + _check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d @@ -1197,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase): def _check_logcontext(context): - current = LoggingContext.current_context() + current = 
current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index df034ab237..babc201643 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase): self.assertNoResult(resolve_d) # should have reset to the sentinel context - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) result = yield resolve_d # should have restored our context - self.assertIs(LoggingContext.current_context(), ctx) + self.assertIs(current_context(), ctx) return result diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 2b01f40a42..fff4f0cbf4 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import ( MatrixFederationHttpClient, MatrixFederationRequest, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests.server import FakeTransport from tests.unittest import HomeserverTestCase def check_logcontext(context): - current = LoggingContext.current_context() + current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) @@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - check_logcontext(LoggingContext.sentinel) + check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index a3d7e3c046..171632e195 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -2,7 +2,7 @@ from mock import Mock, call from twisted.internet import defer, reactor -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.util import Clock @@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def test(): with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) - self.assertIs(LoggingContext.current_context(), c1) + self.assertIs(current_context(), c1) self.assertEqual(res, "yay") # run the test twice in parallel d = defer.gatherResults([test(), test()]) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) yield d - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks def test_does_not_cache_exceptions(self): @@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) 
except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_does_not_cache_failures(self): @@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_cleans_up(self): diff --git a/tests/unittest.py b/tests/unittest.py index 8816a4d152..439174dbfc 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -38,7 +38,11 @@ from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite -from synapse.logging.context import LoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + current_context, + set_current_context, +) from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter @@ -97,10 +101,10 @@ class TestCase(unittest.TestCase): def setUp(orig): # if we're not starting in the sentinel logcontext, then to be honest # all future bets are off. 
- if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): self.fail( "Test starting with non-sentinel logging context %s" - % (LoggingContext.current_context(),) + % (current_context(),) ) old_level = logging.getLogger().level @@ -122,7 +126,7 @@ class TestCase(unittest.TestCase): # force a GC to workaround problems with deferreds leaking logcontexts when # they are GCed (see the logcontext docs) gc.collect() - LoggingContext.set_current_context(LoggingContext.sentinel) + set_current_context(SENTINEL_CONTEXT) return ret diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 39e360fe24..4d2b9e0d64 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -22,8 +22,10 @@ from twisted.internet import defer, reactor from synapse.api.errors import SynapseError from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.util.caches import descriptors @@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase): with LoggingContext() as c1: c1.name = "c1" r = yield obj.fn(1) - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) return r def check_result(r): @@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d1.addCallback(check_result) # and another d2 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d2.addCallback(check_result) # let the lookup complete @@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase): try: d = obj.fn(1) self.assertEqual( - LoggingContext.current_context(), LoggingContext.sentinel + current_context(), SENTINEL_CONTEXT, ) yield d self.fail("No exception thrown") except SynapseError: pass - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) # the cache should now be empty self.assertEqual(len(obj.fn.cache.cache), 0) @@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) return d1 @@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList("fn", "args1", inlineCallbacks=True) def list_fn(self, args1, arg2): - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" # we want this to behave like an asynchronous function yield run_on_reactor() - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" return self.mock(args1, arg2) with LoggingContext() as c1: @@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) 
self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index f60918069a..17fd86d02d 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -16,7 +16,12 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred from twisted.internet.task import Clock -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + PreserveLoggingContext, + current_context, +) from synapse.util.async_helpers import timeout_deferred from tests.unittest import TestCase @@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase): # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( - LoggingContext.current_context(), + current_context(), context_one, "errback %s run in unexpected logcontext %s" - % (deferred_name, LoggingContext.current_context()), + % (deferred_name, current_context()), ) return res @@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase): original_deferred.addErrback(errback, "orig") timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) self.assertNoResult(timing_out_d) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) timing_out_d.addErrback(errback, "timingout") self.clock.pump((1.0,)) @@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase): blocking_was_cancelled[0], "non-completing deferred was not cancelled" ) self.failureResultOf(timing_out_d, defer.TimeoutError) - self.assertIs(LoggingContext.current_context(), context_one) + self.assertIs(current_context(), context_one) diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 0ec8ef90ce..852ef23185 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -19,7 +19,7 @@ from six.moves import range from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.util import Clock from synapse.util.async_helpers import Linearizer @@ -54,11 +54,11 @@ class LinearizerTestCase(unittest.TestCase): def func(i, sleep=False): with LoggingContext("func(%s)" % i) as lc: with (yield linearizer.queue("")): - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) if sleep: yield Clock(reactor).sleep(0) - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) func(0, sleep=True) for i in range(1, 100): diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 281b32c4b8..95301c013c 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -2,8 +2,10 @@ import twisted.python.failure from twisted.internet import defer, reactor from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, nested_logging_context, run_in_background, @@ -15,7 +17,7 @@ from .. 
import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(LoggingContext.current_context().request, value) + self.assertEquals(current_context().request, value) def test_with_context(self): with LoggingContext() as context_one: @@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") def _test_run_in_background(self, function): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() callback_completed = [False] @@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase): # make sure that the context was reset before it got thrown back # into the reactor try: - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) d2.callback(None) except BaseException: d2.errback(twisted.python.failure.Failure()) @@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase): async def testfunc(): self._check_test_key("one") d = Clock(reactor).sleep(0) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) await d self._check_test_key("one") @@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase): reactor.callLater(0, d.callback, None) return d - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable_with_chained_deferreds(self): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(_chained_deferred_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -189,14 +191,14 @@ class LoggingContextTestCase(unittest.TestCase): reactor.callLater(0, d.callback, None) await d - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 diff --git a/tests/utils.py b/tests/utils.py index 513f358f4f..968d109f77 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,7 +35,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -493,10 +493,10 @@ class MockClock(object): return self.time() * 1000 def 
call_later(self, delay, callback, *args, **kwargs): - current_context = LoggingContext.current_context() + ctx = current_context() def wrapped_callback(): - LoggingContext.thread_local.current_context = current_context + set_current_context(ctx) callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] -- cgit 1.4.1 From 4cff617df1ba6f241fee6957cc44859f57edcc0e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Mar 2020 14:54:01 +0000 Subject: Move catchup of replication streams to worker. (#7024) This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date. --- changelog.d/7024.misc | 1 + docs/tcp_replication.md | 46 ++--- synapse/app/generic_worker.py | 3 + synapse/federation/sender/__init__.py | 9 + synapse/replication/http/__init__.py | 2 + synapse/replication/http/streams.py | 78 ++++++++ synapse/replication/slave/storage/_base.py | 14 +- synapse/replication/slave/storage/pushers.py | 3 + synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/commands.py | 34 +--- synapse/replication/tcp/protocol.py | 206 ++++++++-------------- synapse/replication/tcp/resource.py | 19 +- synapse/replication/tcp/streams/__init__.py | 8 +- synapse/replication/tcp/streams/_base.py | 160 +++++++++++------ synapse/replication/tcp/streams/events.py | 5 +- synapse/replication/tcp/streams/federation.py | 19 +- synapse/server.py | 5 + synapse/storage/data_stores/main/cache.py | 44 ++--- synapse/storage/data_stores/main/deviceinbox.py | 88 ++++----- synapse/storage/data_stores/main/events.py | 114 ------------ synapse/storage/data_stores/main/events_worker.py | 114 ++++++++++++ synapse/storage/data_stores/main/room.py | 40 ++--- tests/replication/tcp/streams/_base.py | 55 ++++-- tests/replication/tcp/streams/test_receipts.py | 52 +++++- 24 files changed, 635 insertions(+), 487 deletions(-) create mode 100644 changelog.d/7024.misc create mode 100644 synapse/replication/http/streams.py (limited to 'synapse') diff --git a/changelog.d/7024.misc b/changelog.d/7024.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7024.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index e3a4634b14..d4f7d9ec18 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and '<' worker to master flows): > SERVER example.com - < REPLICATE events 53 + < REPLICATE + > POSITION events 53 > RDATA events 54 ["$foo1:bar.com", ...] > RDATA events 55 ["$foo4:bar.com", ...] -The example shows the server accepting a new connection and sending its -identity with the `SERVER` command, followed by the client asking to -subscribe to the `events` stream from the token `53`. The server then -periodically sends `RDATA` commands which have the format -`RDATA `, where the format of `` is -defined by the individual streams. +The example shows the server accepting a new connection and sending its identity +with the `SERVER` command, followed by the client server to respond with the +position of all streams. The server then periodically sends `RDATA` commands +which have the format `RDATA `, where the format of +`` is defined by the individual streams. 
Error reporting happens by either the client or server sending an ERROR command, and usually the connection will be closed. @@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually connect to the server using a tool like netcat. A few things should be noted when manually using the protocol: -- When subscribing to a stream using `REPLICATE`, the special token - `NOW` can be used to get all future updates. The special stream name - `ALL` can be used with `NOW` to subscribe to all available streams. - The federation stream is only available if federation sending has been disabled on the main process. - The server will only time connections out that have sent a `PING` @@ -91,9 +88,7 @@ The client: - Sends a `NAME` command, allowing the server to associate a human friendly name with the connection. This is optional. - Sends a `PING` as above -- For each stream the client wishes to subscribe to it sends a - `REPLICATE` with the `stream_name` and token it wants to subscribe - from. +- Sends a `REPLICATE` to get the current position of all streams. - On receipt of a `SERVER` command, checks that the server name matches the expected server name. @@ -140,9 +135,7 @@ the wire: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -181,9 +174,9 @@ client (C): #### POSITION (S) - The position of the stream has been updated. Sent to the client - after all missing updates for a stream have been sent to the client - and they're now up to date. + On receipt of a POSITION command clients should check if they have missed any + updates, and if so then fetch them out of band. Sent in response to a + REPLICATE command (but can happen at any time). #### ERROR (S, C) @@ -199,20 +192,7 @@ client (C): #### REPLICATE (C) -Asks the server to replicate a given stream. The syntax is: - -``` - REPLICATE -``` - -Where `` may be either: - * a numeric stream_id to stream updates since (exclusive) - * `NOW` to stream all subsequent updates. - -The `` is the name of a replication stream to subscribe -to (see [here](../synapse/replication/tcp/streams/_base.py) for a list -of streams). It can also be `ALL` to subscribe to all known streams, -in which case the `` must be set to `NOW`. +Asks the server for the current position of all streams. #### USER_SYNC (C) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index bd1733573b..fba7ad9551 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -401,6 +401,9 @@ class GenericWorkerTyping(object): self._room_serials[row.room_id] = token self._room_typing[row.room_id] = row.user_ids + def get_current_token(self) -> int: + return self._latest_room_serial + class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 233cb33daf..a477578e44 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -499,4 +499,13 @@ class FederationSender(object): self._get_per_destination_queue(destination).attempt_new_transaction() def get_current_token(self) -> int: + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. 
return 0 + + async def get_replication_rows( + self, from_token, to_token, limit, federation_ack=None + ): + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. + return [] diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 28dbc6fcba..4613b2538c 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -21,6 +21,7 @@ from synapse.replication.http import ( membership, register, send_event, + streams, ) REPLICATION_PREFIX = "/_synapse/replication" @@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource): login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) + streams.register_servlets(hs, self) diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py new file mode 100644 index 0000000000..ffd4c61993 --- /dev/null +++ b/synapse/replication/http/streams.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import parse_integer +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationGetStreamUpdates(ReplicationEndpoint): + """Fetches stream updates from a server. Used for streams not persisted to + the database, e.g. typing notifications. + + The API looks like: + + GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100 + + 200 OK + + { + updates: [ ... ], + upto_token: 10, + limited: False, + } + + """ + + NAME = "get_repl_stream_updates" + PATH_ARGS = ("stream_name",) + METHOD = "GET" + + def __init__(self, hs): + super().__init__(hs) + + # We pull the streams from the replication steamer (if we try and make + # them ourselves we end up in an import loop). 
+ self.streams = hs.get_replication_streamer().get_streams() + + @staticmethod + def _serialize_payload(stream_name, from_token, upto_token, limit): + return {"from_token": from_token, "upto_token": upto_token, "limit": limit} + + async def _handle_request(self, request, stream_name): + stream = self.streams.get(stream_name) + if stream is None: + raise SynapseError(400, "Unknown stream") + + from_token = parse_integer(request, "from_token", required=True) + upto_token = parse_integer(request, "upto_token", required=True) + limit = parse_integer(request, "limit", required=True) + + updates, upto_token, limited = await stream.get_updates_since( + from_token, upto_token, limit + ) + + return ( + 200, + {"updates": updates, "upto_token": upto_token, "limited": limited}, + ) + + +def register_servlets(hs, http_server): + ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f45cbd37a0..751c799d94 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,8 +18,10 @@ from typing import Dict, Optional import six -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.data_stores.main.cache import ( + CURRENT_STATE_CACHE_NAME, + CacheInvalidationWorkerStore, +) from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine @@ -35,7 +37,7 @@ def __func__(inp): return inp.__func__ -class BaseSlavedStore(SQLBaseStore): +class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): @@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore): pos["caches"] = self._cache_id_gen.get_current_token() return pos + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 + def process_replication_rows(self, stream_name, token, rows): if stream_name == "caches": if self._cache_id_gen: diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index f22c2d44a3..bce8a3d115 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): result["pushers"] = self._pushers_id_gen.get_current_token() return result + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + def process_replication_rows(self, stream_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 02ab5b66ea..7e7ad0f798 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name + self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) @@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - 
self.client_name, self.server_name, self._clock, self.handler + self.hs, self.client_name, self.server_name, self._clock, self.handler, ) def clientConnectionLost(self, connector, reason): diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 451671412d..5a6b734094 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -136,8 +136,8 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. - Sent to the client after all missing updates for a stream have been sent - to the client and they're now up to date. + On receipt of a POSITION command clients should check if they have missed + any updates, and if so then fetch them out of band. """ NAME = "POSITION" @@ -179,42 +179,24 @@ class NameCommand(Command): class ReplicateCommand(Command): - """Sent by the client to subscribe to the stream. + """Sent by the client to subscribe to streams. Format:: - REPLICATE - - Where may be either: - * a numeric stream_id to stream updates from - * "NOW" to stream all subsequent updates. - - The can be "ALL" to subscribe to all known streams, in which - case the must be set to "NOW", i.e.:: - - REPLICATE ALL NOW + REPLICATE """ NAME = "REPLICATE" - def __init__(self, stream_name, token): - self.stream_name = stream_name - self.token = token + def __init__(self): + pass @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - if token in ("NOW", "now"): - token = "NOW" - else: - token = int(token) - return cls(stream_name, token) + return cls() def to_line(self): - return " ".join((self.stream_name, str(self.token))) - - def get_logcontext_id(self): - return "REPLICATE-" + self.stream_name + return "" class UserSyncCommand(Command): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index bc1482a9bb..f81d2e2442 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire:: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -53,17 +51,15 @@ import fcntl import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set, Tuple +from typing import Any, DefaultDict, Dict, List, Set -from six import iteritems, iterkeys +from six import iteritems from prometheus_client import Counter -from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( @@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string +MYPY = False +if MYPY: + from synapse.server import HomeServer + + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -411,16 
+412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.streamer = streamer - # The streams the client has subscribed to and is up to date with - self.replication_streams = set() # type: Set[str] - - # The streams the client is currently subscribing to. - self.connecting_streams = set() # type: Set[str] - - # Map from stream name to list of updates to send once we've finished - # subscribing the client to the stream. - self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] - def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) @@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): ) async def on_REPLICATE(self, cmd): - stream_name = cmd.stream_name - token = cmd.token - - if stream_name == "ALL": - # Subscribe to all streams we're publishing to. - deferreds = [ - run_in_background(self.subscribe_to_stream, stream, token) - for stream in iterkeys(self.streamer.streams_by_name) - ] - - await make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - else: - await self.subscribe_to_stream(stream_name, token) + # Subscribe to all streams we're publishing to. + for stream_name in self.streamer.streams_by_name: + current_token = self.streamer.get_stream_token(stream_name) + self.send_command(PositionCommand(stream_name, current_token)) async def on_FEDERATION_ACK(self, cmd): self.streamer.federation_ack(cmd.token) @@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - async def subscribe_to_stream(self, stream_name, token): - """Subscribe the remote to a stream. - - This invloves checking if they've missed anything and sending those - updates down if they have. During that time new updates for the stream - are queued and sent once we've sent down any missed updates. - """ - self.replication_streams.discard(stream_name) - self.connecting_streams.add(stream_name) - - try: - # Get missing updates - updates, current_token = await self.streamer.get_stream_updates( - stream_name, token - ) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) - - # We send a POSITION command to ensure that they have an up to - # date token (especially useful if we didn't send any updates - # above) - self.send_command(PositionCommand(stream_name, current_token)) - - # Now we can send any updates that came in while we were subscribing - pending_rdata = self.pending_rdata.pop(stream_name, []) - updates = [] - for token, update in pending_rdata: - # If the token is null, it is part of a batch update. Batches - # are multiple updates that share a single token. To denote - # this, the token is set to None for all tokens in the batch - # except for the last. If we find a None token, we keep looking - # through tokens until we find one that is not None and then - # process all previous updates in the batch as if they had the - # final token. 
- if token is None: - # Store this update as part of a batch - updates.append(update) - continue - - if token <= current_token: - # This update or batch of updates is older than - # current_token, dismiss it - updates = [] - continue - - updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear stored updates - updates = [] - - # They're now fully subscribed - self.replication_streams.add(stream_name) - except Exception as e: - logger.exception("[%s] Failed to handle REPLICATE command", self.id()) - self.send_error("failed to handle replicate: %r", e) - finally: - self.connecting_streams.discard(stream_name) - def stream_update(self, stream_name, token, data): """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not """ - if stream_name in self.replication_streams: - # The client is subscribed to the stream - self.send_command(RdataCommand(stream_name, token, data)) - elif stream_name in self.connecting_streams: - # The client is being subscribed to the stream - logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) - self.pending_rdata.setdefault(stream_name, []).append((token, data)) - else: - # The client isn't subscribed - logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + self.send_command(RdataCommand(stream_name, token, data)) def send_sync(self, data): self.send_command(SyncCommand(data)) @@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, + hs: "HomeServer", client_name: str, server_name: str, clock: Clock, @@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.handler = handler + self.streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. - self.streams_connecting = set() # type: Set[str] + self.streams_connecting = set(STREAMS_MAP) # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. 
- self.pending_batches = {} # type: Dict[str, Any] + self.pending_batches = {} # type: Dict[str, List[Any]] def connectionMade(self): self.send_command(NameCommand(self.client_name)) BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): - self.replicate(stream_name, token) + self.replicate() # Tell the server if we have any users currently syncing (should only # happen on synchrotrons) @@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) - # This will happen if we don't actually subscribe to any streams - if not self.streams_connecting: - self.handler.finished_connecting() - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): ) raise - if cmd.token is None: + if cmd.token is None or stream_name in self.streams_connecting: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token self.pending_batches.setdefault(stream_name, []).append(row) @@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): rows.append(row) await self.handler.on_rdata(stream_name, cmd.token, rows) - async def on_POSITION(self, cmd): - # When we get a `POSITION` command it means we've finished getting - # missing updates for the given stream, and are now up to date. + async def on_POSITION(self, cmd: PositionCommand): + stream = self.streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # Find where we previously streamed up to. + current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + if updates: + await self.handler.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self.handler.on_position(cmd.stream_name, cmd.token) + self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - await self.handler.on_position(cmd.stream_name, cmd.token) + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + # Handle any RDATA that came in while we were catching up. 
+ rows = self.pending_batches.pop(cmd.stream_name, []) + if rows: + await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) @@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): self.handler.on_remote_server_up(cmd.data) - def replicate(self, stream_name, token): + def replicate(self): """Send the subscription request to the server """ - if stream_name not in STREAMS_MAP: - raise Exception("Invalid stream name %r" % (stream_name,)) - - logger.info( - "[%s] Subscribing to replication stream: %r from %r", - self.id(), - stream_name, - token, - ) - - self.streams_connecting.add(stream_name) + logger.info("[%s] Subscribing to replication streams", self.id()) - self.send_command(ReplicateCommand(stream_name, token)) + self.send_command(ReplicateCommand()) def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 6e2ebaf614..4374e99e32 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import Any, List +from typing import Any, Dict, List from six import itervalues @@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.metrics import Measure, measure_func from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP +from .streams import STREAMS_MAP, Stream from .streams.federation import FederationStream stream_updates_counter = Counter( @@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = ReplicationStreamer(hs) + self.streamer = hs.get_replication_streamer() self.clock = hs.get_clock() self.server_name = hs.config.server_name @@ -133,6 +133,11 @@ class ReplicationStreamer(object): for conn in self.connections: conn.send_error("server shutting down") + def get_streams(self) -> Dict[str, Stream]: + """Get a mapp from stream name to stream instance. + """ + return self.streams_by_name + def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. @@ -190,7 +195,8 @@ class ReplicationStreamer(object): stream.current_token(), ) try: - updates, current_token = await stream.get_updates() + updates, current_token, limited = await stream.get_updates() + self.pending_updates |= limited except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -226,8 +232,7 @@ class ReplicationStreamer(object): self.pending_updates = False self.is_looping = False - @measure_func("repl.get_stream_updates") - async def get_stream_updates(self, stream_name, token): + def get_stream_token(self, stream_name): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. 
""" @@ -235,7 +240,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return await stream.get_updates_since(token) + return stream.current_token() @measure_func("repl.federation_ack") def federation_ack(self, token): diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 29199f5b46..37bcd3de66 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -24,6 +24,9 @@ Each stream is defined by the following information: current_token: The function that returns the current token for the stream update_function: The function that returns a list of updates between two tokens """ + +from typing import Dict, Type + from synapse.replication.tcp.streams._base import ( AccountDataStream, BackfillStream, @@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import ( PushersStream, PushRulesStream, ReceiptsStream, + Stream, TagAccountDataStream, ToDeviceStream, TypingStream, @@ -63,10 +67,12 @@ STREAMS_MAP = { GroupServerStream, UserSignatureStream, ) -} +} # type: Dict[str, Type[Stream]] + __all__ = [ "STREAMS_MAP", + "Stream", "BackfillStream", "PresenceStream", "TypingStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 32d9514883..c14dff6c64 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from collections import namedtuple -from typing import Any, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr +from synapse.replication.http.streams import ReplicationGetStreamUpdates from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -29,6 +29,15 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 500000 +# Some type aliases to make things a bit easier. + +# A stream position token +Token = int + +# A pair of position in stream and args used to create an instance of `ROW_TYPE`. +StreamRow = Tuple[Token, tuple] + + class Stream(object): """Base class for the streams. @@ -56,6 +65,7 @@ class Stream(object): return cls.ROW_TYPE(*row) def __init__(self, hs): + # The token from which we last asked for updates self.last_token = self.current_token() @@ -65,61 +75,46 @@ class Stream(object): """ self.last_token = self.current_token() - async def get_updates(self): + async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before). Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. 
""" - updates, current_token = await self.get_updates_since(self.last_token) + current_token = self.current_token() + updates, current_token, limited = await self.get_updates_since( + self.last_token, current_token + ) self.last_token = current_token - return updates, current_token + return updates, current_token, limited async def get_updates_since( - self, from_token: int - ) -> Tuple[List[Tuple[int, JsonDict]], int]: + self, from_token: Token, upto_token: Token, limit: int = 100 + ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Like get_updates except allows specifying from when we should stream updates Returns: - Resolves to a pair `(updates, new_last_token)`, where `updates` is - a list of `(token, row)` entries and `new_last_token` is the new - position in stream. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - if from_token in ("NOW", "now"): - return [], self.current_token() - - current_token = self.current_token() - from_token = int(from_token) - if from_token == current_token: - return [], current_token + if from_token == upto_token: + return [], upto_token, False - rows = await self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 + updates, upto_token, limited = await self.update_function( + from_token, upto_token, limit=limit, ) - - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - - updates = [(row[0], row[1:]) for row in rows] - - # check we didn't get more rows than the limit. - # doing it like this allows the update_function to be a generator. - if len(updates) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) - - # The update function didn't hit the limit, so we must have got all - # the updates to `current_token`, and can return that as our new - # stream position. - return updates, current_token + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -141,6 +136,48 @@ class Stream(object): raise NotImplementedError() +def db_query_to_update_function( + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] +) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True + + return updates, upto_token, limited + + return update_function + + +def make_http_update_function( + hs, stream_name: str +) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + """Makes a suitable function for use as an `update_function` that queries + the master process for updates. 
+ """ + + client = ReplicationGetStreamUpdates.make_client(hs) + + async def update_function( + from_token: int, upto_token: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + return await client( + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, + limit=limit, + ) + + return update_function + + class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -164,7 +201,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_backfill_token # type: ignore - self.update_function = store.get_all_new_backfill_event_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore super(BackfillStream, self).__init__(hs) @@ -190,8 +227,15 @@ class PresenceStream(Stream): store = hs.get_datastore() presence_handler = hs.get_presence_handler() + self._is_worker = hs.config.worker_app is not None + self.current_token = store.get_current_presence_token # type: ignore - self.update_function = presence_handler.get_all_presence_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore + else: + # Query master process + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(PresenceStream, self).__init__(hs) @@ -208,7 +252,12 @@ class TypingStream(Stream): typing_handler = hs.get_typing_handler() self.current_token = typing_handler.get_current_token # type: ignore - self.update_function = typing_handler.get_all_typing_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore + else: + # Query master process + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(TypingStream, self).__init__(hs) @@ -232,7 +281,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_receipt_stream_id # type: ignore - self.update_function = store.get_all_updated_receipts # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -256,7 +305,13 @@ class PushRulesStream(Stream): async def update_function(self, from_token, to_token, limit): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], (row[2],)) for row in rows], to_token, limited class PushersStream(Stream): @@ -275,7 +330,7 @@ class PushersStream(Stream): store = hs.get_datastore() self.current_token = store.get_pushers_stream_token # type: ignore - self.update_function = store.get_all_updated_pushers_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore super(PushersStream, self).__init__(hs) @@ -307,7 +362,7 @@ class CachesStream(Stream): store = hs.get_datastore() self.current_token = store.get_cache_stream_token # type: ignore - self.update_function = store.get_all_updated_caches # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore super(CachesStream, self).__init__(hs) @@ -333,7 +388,7 
@@ class PublicRoomsStream(Stream): store = hs.get_datastore() self.current_token = store.get_current_public_room_stream_id # type: ignore - self.update_function = store.get_all_new_public_rooms # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -354,7 +409,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -372,7 +427,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() self.current_token = store.get_to_device_stream_token # type: ignore - self.update_function = store.get_all_new_device_messages # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -392,7 +447,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() self.current_token = store.get_max_account_data_stream_id # type: ignore - self.update_function = store.get_all_updated_tags # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -412,10 +467,11 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() self.current_token = self.store.get_max_account_data_stream_id # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(AccountDataStream, self).__init__(hs) - async def update_function(self, from_token, to_token, limit): + async def _update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) @@ -442,7 +498,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() self.current_token = store.get_group_stream_token # type: ignore - self.update_function = store.get_all_groups_changes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore super(GroupServerStream, self).__init__(hs) @@ -460,6 +516,6 @@ class UserSignatureStream(Stream): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index b3afabb8cd..c6a595629f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import Tuple, Type import attr -from ._base import Stream +from ._base import Stream, db_query_to_update_function """Handling of the 'events' replication stream @@ -117,10 +117,11 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() self.current_token = self._store.get_current_events_token # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(EventsStream, self).__init__(hs) - async def update_function(self, 
from_token, current_token, limit=None): + async def _update_function(self, from_token, current_token, limit=None): event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index f5f9336430..48c1d45718 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,9 @@ # limitations under the License. from collections import namedtuple -from ._base import Stream +from twisted.internet import defer + +from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function class FederationStream(Stream): @@ -33,11 +35,18 @@ class FederationStream(Stream): NAME = "federation" ROW_TYPE = FederationStreamRow + _QUERY_MASTER = True def __init__(self, hs): - federation_sender = hs.get_federation_sender() - - self.current_token = federation_sender.get_current_token # type: ignore - self.update_function = federation_sender.get_replication_rows # type: ignore + # Not all synapse instances will have a federation sender instance, + # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, + # so we stub the stream out when that is the case. + if hs.config.worker_app is None or hs.should_send_federation(): + federation_sender = hs.get_federation_sender() + self.current_token = federation_sender.get_current_token # type: ignore + self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore + else: + self.current_token = lambda: 0 # type: ignore + self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore super(FederationStream, self).__init__(hs) diff --git a/synapse/server.py b/synapse/server.py index 1b980371de..9426eb1672 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -85,6 +85,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -199,6 +200,7 @@ class HomeServer(object): "saml_handler", "event_client_serializer", "storage", + "replication_streamer", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -536,6 +538,9 @@ class HomeServer(object): def build_storage(self) -> Storage: return Storage(self, self.datastores) + def build_replication_streamer(self) -> ReplicationStreamer: + return ReplicationStreamer(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index d4c44dcc75..4dc5da3fe8 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -32,7 +32,29 @@ logger = logging.getLogger(__name__) CURRENT_STATE_CACHE_NAME = "cs_cache_fake" -class CacheInvalidationStore(SQLBaseStore): +class CacheInvalidationWorkerStore(SQLBaseStore): + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache 
invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + +class CacheInvalidationStore(CacheInvalidationWorkerStore): async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore): }, ) - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_caches", get_all_updated_caches_txn - ) - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 0613b49f4a..9a1178fb39 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" 
+ " GROUP BY destination" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() + + return rows + + return self.db.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" @@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((user_id, device_id, stream_id, message_json)) txn.executemany(sql, rows) - - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ - Args: - last_pos(int): - current_pos(int): - limit(int): - Returns: - A deferred list of rows from the device inbox - """ - if last_pos == current_pos: - return defer.succeed([]) - - def get_all_new_device_messages_txn(txn): - # We limit like this as we might have multiple rows per stream_id, and - # we want to make sure we always get all entries for any stream_id - # we return. - upper_pos = min(current_pos, last_pos + limit) - sql = ( - "SELECT max(stream_id), user_id" - " FROM device_inbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY user_id" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() - - sql = ( - "SELECT max(stream_id), destination" - " FROM device_federation_outbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY destination" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) - - # Order by ascending stream ordering - rows.sort() - - return rows - - return self.db.runInteraction( - "get_all_new_device_messages", get_all_new_device_messages_txn - ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d593ef47b8..e71c23541d 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1267,104 +1267,6 @@ class EventsStore( ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - - def get_current_events_token(self): - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - - def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_forward_event_rows(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" 
- " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows - ) - - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_backfill_event_rows(txn): - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows - ) - @cached(num_args=5, max_entries=10) def get_all_new_events( self, @@ -1850,22 +1752,6 @@ class EventsStore( return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): - def get_all_updated_current_state_deltas_txn(txn): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? 
- """ - txn.execute(sql, (from_token, to_token, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_current_state_deltas", - get_all_updated_current_state_deltas_txn, - ) - def insert_labels_for_event_txn( self, txn, event_id, labels, room_id, topological_ordering ): diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 3013f49d32..16ea8948b1 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore): complexity_v1 = round(state_events / 500, 2) return {"v1": complexity_v1} + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? 
> event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index e6c10c6316..aaebe427d3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore): return total_media_quarantined + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = """ + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (prev_id, current_id, limit)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) + + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() - - if prev_id == current_id: - return defer.succeed([]) - - return self.db.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - @defer.inlineCallbacks def block_room(self, room_id, user_id): """Marks the room as blocked. Can be called multiple times. diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index e96ad4ca4e..a755fe2879 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from mock import Mock from synapse.replication.tcp.commands import ReplicateCommand @@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # build a replication server server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = server_factory.streamer - server = server_factory.buildProtocol(None) + self.server = server_factory.buildProtocol(None) - # build a replication client, with a dummy handler - handler_factory = Mock() - self.test_handler = TestReplicationClientHandler() - self.test_handler.factory = handler_factory + self.test_handler = Mock(wraps=TestReplicationClientHandler()) self.client = ClientReplicationStreamProtocol( - "client", "test", clock, self.test_handler + hs, "client", "test", clock, self.test_handler, ) - # wire them together - self.client.makeConnection(FakeTransport(server, reactor)) - server.makeConnection(FakeTransport(self.client, reactor)) + self._client_transport = None + self._server_transport = None + + def reconnect(self): + if self._client_transport: + self.client.close() + + if self._server_transport: + self.server.close() + + self._client_transport = FakeTransport(self.server, self.reactor) + self.client.makeConnection(self._client_transport) + + self._server_transport = FakeTransport(self.client, self.reactor) + self.server.makeConnection(self._server_transport) + + def disconnect(self): + if self._client_transport: + self._client_transport = None + self.client.close() + + if self._server_transport: + self._server_transport = None + self.server.close() def replicate(self): """Tell the master side of replication that something has happened, and then @@ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self, stream, token="NOW"): + def replicate_stream(self): """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand(stream, token)) + self.client.send_command(ReplicateCommand()) class TestReplicationClientHandler(object): """Drop-in for ReplicationClientHandler which just collects RDATA rows""" def __init__(self): - self.received_rdata_rows = [] + self.streams = set() + self._received_rdata_rows = [] def get_streams_to_replicate(self): - return {} + positions = {s: 0 for s in self.streams} + for stream, token, _ in self._received_rdata_rows: + if stream in self.streams: + positions[stream] = max(token, positions.get(stream, 0)) + return positions def get_currently_syncing_users(self): return [] @@ -73,6 +97,9 @@ class TestReplicationClientHandler(object): def finished_connecting(self): pass + async def on_position(self, stream_name, token): + """Called when we get new position data.""" + async def on_rdata(self, stream_name, token, rows): for r in rows: - self.received_rdata_rows.append((stream_name, token, r)) + self._received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index fa2493cad6..0ec0825a0e 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase USER_ID = "@feeling:blue" -ROOM_ID = "!room:blue" -EVENT_ID = "$event:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): + self.reconnect() + # make the client subscribe 
to the receipts stream - self.replicate_stream("receipts", "NOW") + self.replicate_stream() + self.test_handler.streams.add("receipts") # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) self.replicate() # there should be one RDATA command - rdata_rows = self.test_handler.received_rdata_rows + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - self.assertEqual(rdata_rows[0][0], "receipts") - row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow - self.assertEqual(ROOM_ID, row.room_id) + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) - self.assertEqual(EVENT_ID, row.event_id) + self.assertEqual("$event:blue", row.event_id) self.assertEqual({"a": 1}, row.data) + + # Now let's disconnect and insert some data. + self.disconnect() + + self.test_handler.on_rdata.reset_mock() + + self.get_success( + self.hs.get_datastore().insert_receipt( + "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + ) + ) + self.replicate() + + # Nothing should have happened as we are disconnected + self.test_handler.on_rdata.assert_not_called() + + self.reconnect() + self.pump(0.1) + + # We should now have caught up and get the missing data + self.test_handler.on_rdata.assert_called_once() + stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") + self.assertEqual(token, 3) + self.assertEqual(1, len(rdata_rows)) + + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room2:blue", row.room_id) + self.assertEqual("m.read", row.receipt_type) + self.assertEqual(USER_ID, row.user_id) + self.assertEqual("$event2:foo", row.event_id) + self.assertEqual({"a": 2}, row.data) -- cgit 1.4.1 From 6ca5e56fd12bbccb6b3ab43ed7c0281e4822274a Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Wed, 25 Mar 2020 12:49:34 -0500 Subject: Remove unused captcha_bypass_secret option (#7137) Signed-off-by: Aaron Raimist --- changelog.d/7137.removal | 1 + docs/sample_config.yaml | 4 ---- synapse/config/captcha.py | 5 ----- 3 files changed, 1 insertion(+), 9 deletions(-) create mode 100644 changelog.d/7137.removal (limited to 'synapse') diff --git a/changelog.d/7137.removal b/changelog.d/7137.removal new file mode 100644 index 0000000000..75266a06bb --- /dev/null +++ b/changelog.d/7137.removal @@ -0,0 +1 @@ +Remove nonfunctional `captcha_bypass_secret` option from `homeserver.yaml`. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 276e43b732..2ef83646b3 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -872,10 +872,6 @@ media_store_path: "DATADIR/media_store" # #enable_registration_captcha: false -# A secret key used to bypass the captcha test entirely. -# -#captcha_bypass_secret: "YOUR_SECRET_HERE" - # The API endpoint to use for verifying m.login.recaptcha responses. 
# #recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify" diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index f0171bb5b2..56c87fa296 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -24,7 +24,6 @@ class CaptchaConfig(Config): self.enable_registration_captcha = config.get( "enable_registration_captcha", False ) - self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.recaptcha_siteverify_api = config.get( "recaptcha_siteverify_api", "https://www.recaptcha.net/recaptcha/api/siteverify", @@ -49,10 +48,6 @@ class CaptchaConfig(Config): # #enable_registration_captcha: false - # A secret key used to bypass the captcha test entirely. - # - #captcha_bypass_secret: "YOUR_SECRET_HERE" - # The API endpoint to use for verifying m.login.recaptcha responses. # #recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify" -- cgit 1.4.1 From 1c1242acba9694a3a4b1eb3b14ec0bac11ee4ff8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Mar 2020 07:39:34 -0400 Subject: Validate that the session is not modified during UI-Auth (#7068) --- changelog.d/7068.bugfix | 1 + synapse/handlers/auth.py | 37 +++++++++++++++-- synapse/rest/client/v2_alpha/account.py | 11 ++++-- synapse/rest/client/v2_alpha/devices.py | 4 +- synapse/rest/client/v2_alpha/keys.py | 2 +- synapse/rest/client/v2_alpha/register.py | 5 ++- tests/rest/client/v2_alpha/test_auth.py | 68 +++++++++++++++++++++++++++++++- tests/test_terms_auth.py | 3 +- 8 files changed, 117 insertions(+), 14 deletions(-) create mode 100644 changelog.d/7068.bugfix (limited to 'synapse') diff --git a/changelog.d/7068.bugfix b/changelog.d/7068.bugfix new file mode 100644 index 0000000000..d1693a7f22 --- /dev/null +++ b/changelog.d/7068.bugfix @@ -0,0 +1 @@ +Ensure that a user inteactive authentication session is tied to a single request. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7860f9625e..2ce1425dfa 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -125,7 +125,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def validate_user_via_ui_auth( - self, requester: Requester, request_body: Dict[str, Any], clientip: str + self, + requester: Requester, + request: SynapseRequest, + request_body: Dict[str, Any], + clientip: str, ): """ Checks that the user is who they claim to be, via a UI auth. @@ -137,6 +141,8 @@ class AuthHandler(BaseHandler): Args: requester: The user, as given by the access token + request: The request sent by the client. + request_body: The body of the request sent by the client clientip: The IP address of the client. @@ -172,7 +178,9 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_login_types] try: - result, params, _ = yield self.check_auth(flows, request_body, clientip) + result, params, _ = yield self.check_auth( + flows, request, request_body, clientip + ) except LoginError: # Update the ratelimite to say we failed (`can_do_action` doesn't raise). self._failed_uia_attempts_ratelimiter.can_do_action( @@ -211,7 +219,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def check_auth( - self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str + self, + flows: List[List[str]], + request: SynapseRequest, + clientdict: Dict[str, Any], + clientip: str, ): """ Takes a dictionary sent by the client in the login / registration @@ -231,6 +243,8 @@ class AuthHandler(BaseHandler): strings representing auth-types. 
At least one full flow must be completed in order for auth to be successful. + request: The request sent by the client. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. @@ -270,13 +284,27 @@ class AuthHandler(BaseHandler): # email auth link on there). It's probably too open to abuse # because it lets unauthenticated clients store arbitrary objects # on a homeserver. - # Revisit: Assumimg the REST APIs do sensible validation, the data + # Revisit: Assuming the REST APIs do sensible validation, the data # isn't arbintrary. session["clientdict"] = clientdict self._save_session(session) elif "clientdict" in session: clientdict = session["clientdict"] + # Ensure that the queried operation does not vary between stages of + # the UI authentication session. This is done by generating a stable + # comparator based on the URI, method, and body (minus the auth dict) + # and storing it during the initial query. Subsequent queries ensure + # that this comparator has not changed. + comparator = (request.uri, request.method, clientdict) + if "ui_auth" not in session: + session["ui_auth"] = comparator + elif session["ui_auth"] != comparator: + raise SynapseError( + 403, + "Requested operation has changed during the UI authentication session.", + ) + if not authdict: raise InteractiveAuthIncompleteError( self._auth_dict_for_flows(flows, session) @@ -322,6 +350,7 @@ class AuthHandler(BaseHandler): creds, list(clientdict), ) + return creds, clientdict, session["id"] ret = self._auth_dict_for_flows(flows, session) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 631cc74cb4..b1249b664c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -234,13 +234,16 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) params = await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) user_id = requester.user.to_string() else: requester = None result, params, _ = await self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request) + [[LoginType.EMAIL_IDENTITY]], + request, + body, + self.hs.get_ip_from_request(request), ) if LoginType.EMAIL_IDENTITY in result: @@ -308,7 +311,7 @@ class DeactivateAccountRestServlet(RestServlet): return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") @@ -656,7 +659,7 @@ class ThreepidAddRestServlet(RestServlet): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) validation_session = await self.identity_handler.validate_threepid_session( diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 94ff73f384..119d979052 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -81,7 +81,7 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, 
["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) await self.device_handler.delete_devices( @@ -127,7 +127,7 @@ class DeviceRestServlet(RestServlet): raise await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) await self.device_handler.delete_device(requester.user.to_string(), device_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f7ed4daf90..5eb7ef35a4 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -263,7 +263,7 @@ class SigningKeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, request, body, self.hs.get_ip_from_request(request), ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a09189b1b4..6963d79310 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -499,7 +499,10 @@ class RegisterRestServlet(RestServlet): ) auth_result, params, session_id = await self.auth_handler.check_auth( - self._registration_flows, body, self.hs.get_ip_from_request(request) + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), ) # Check that we're not trying to register a denied 3pid. diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index b6df1396ad..624bf5ada2 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) self.render(request) - # Now we should have fufilled a complete auth flow, including + # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. request, channel = self.make_request( @@ -115,3 +115,69 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") + + def test_cannot_change_operation(self): + """ + The initial requested operation cannot be modified during the user interactive authentication session. + """ + + # Make the initial request to register. (Later on a different password + # will be used.) 
+ request, channel = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + self.render(request) + + # Returns a 401 as per the spec + self.assertEqual(request.code, 401) + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + request, channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) + self.render(request) + self.assertEqual(request.code, 200) + + request, channel = self.make_request( + "POST", + "auth/m.login.recaptcha/fallback/web?session=" + + session + + "&g-recaptcha-response=a", + ) + self.render(request) + self.assertEqual(request.code, 200) + + # The recaptcha handler is called with the response given + attempts = self.recaptcha_checker.recaptcha_attempts + self.assertEqual(len(attempts), 1) + self.assertEqual(attempts[0][0]["response"], "a") + + # also complete the dummy auth + request, channel = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) + self.render(request) + + # Now we should have fulfilled a complete auth flow, including + # the recaptcha fallback step. Make the initial request again, but + # with a different password. This causes the request to fail since the + # operaiton was modified during the ui auth session. + request, channel = self.make_request( + "POST", + "register", + { + "username": "user", + "type": "m.login.password", + "password": "foo", # Note this doesn't match the original request. + "auth": {"session": session}, + }, + ) + self.render(request) + self.assertEqual(channel.code, 403) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 5ec5d2b358..a3f98a1412 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -53,7 +53,8 @@ class TermsTestCase(unittest.HomeserverTestCase): def test_ui_auth(self): # Do a UI auth request - request, channel = self.make_request(b"POST", self.url, b"{}") + request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request, channel = self.make_request(b"POST", self.url, request_data) self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) -- cgit 1.4.1 From e8e2ddb60ae11db488f159901d918cb159695912 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 26 Mar 2020 17:51:13 +0100 Subject: Allow server admins to define and enforce a password policy (MSC2000). 
(#7118) --- changelog.d/7118.feature | 1 + docs/sample_config.yaml | 35 ++++ synapse/api/errors.py | 21 +++ synapse/config/password.py | 39 +++++ synapse/handlers/password_policy.py | 93 +++++++++++ synapse/handlers/set_password.py | 2 + synapse/rest/__init__.py | 2 + synapse/rest/client/v2_alpha/password_policy.py | 58 +++++++ synapse/rest/client/v2_alpha/register.py | 2 + synapse/server.py | 5 + tests/rest/client/v2_alpha/test_password_policy.py | 179 +++++++++++++++++++++ 11 files changed, 437 insertions(+) create mode 100644 changelog.d/7118.feature create mode 100644 synapse/handlers/password_policy.py create mode 100644 synapse/rest/client/v2_alpha/password_policy.py create mode 100644 tests/rest/client/v2_alpha/test_password_policy.py (limited to 'synapse') diff --git a/changelog.d/7118.feature b/changelog.d/7118.feature new file mode 100644 index 0000000000..5cbfd98160 --- /dev/null +++ b/changelog.d/7118.feature @@ -0,0 +1 @@ +Allow server admins to define and enforce a password policy (MSC2000). \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2ef83646b3..1a1d061759 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1482,6 +1482,41 @@ password_config: # #pepper: "EVEN_MORE_SECRET" + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true + # Configuration for sending emails from Synapse. # diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 616942b057..11da016ac5 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -64,6 +64,13 @@ class Codes(object): INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" + PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT" + PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT" + PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE" + PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE" + PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL" + PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY" + WEAK_PASSWORD = "M_WEAK_PASSWORD" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" BAD_ALIAS = "M_BAD_ALIAS" @@ -439,6 +446,20 @@ class IncompatibleRoomVersionError(SynapseError): return cs_error(self.msg, self.errcode, room_version=self._room_version) +class PasswordRefusedError(SynapseError): + """A password has been refused, either during password reset/change or registration. 
+ """ + + def __init__( + self, + msg="This password doesn't comply with the server's policy", + errcode=Codes.WEAK_PASSWORD, + ): + super(PasswordRefusedError, self).__init__( + code=400, msg=msg, errcode=errcode, + ) + + class RequestSendFailed(RuntimeError): """Sending a HTTP request over federation failed due to not being able to talk to the remote server for some reason. diff --git a/synapse/config/password.py b/synapse/config/password.py index 2a634ac751..9c0ea8c30a 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -31,6 +31,10 @@ class PasswordConfig(Config): self.password_localdb_enabled = password_config.get("localdb_enabled", True) self.password_pepper = password_config.get("pepper", "") + # Password policy + self.password_policy = password_config.get("policy") or {} + self.password_policy_enabled = self.password_policy.get("enabled", False) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ password_config: @@ -48,4 +52,39 @@ class PasswordConfig(Config): # DO NOT CHANGE THIS AFTER INITIAL SETUP! # #pepper: "EVEN_MORE_SECRET" + + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true """ diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py new file mode 100644 index 0000000000..d06b110269 --- /dev/null +++ b/synapse/handlers/password_policy.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from synapse.api.errors import Codes, PasswordRefusedError + +logger = logging.getLogger(__name__) + + +class PasswordPolicyHandler(object): + def __init__(self, hs): + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + # Regexps for the spec'd policy parameters. + self.regexp_digit = re.compile("[0-9]") + self.regexp_symbol = re.compile("[^a-zA-Z0-9]") + self.regexp_uppercase = re.compile("[A-Z]") + self.regexp_lowercase = re.compile("[a-z]") + + def validate_password(self, password): + """Checks whether a given password complies with the server's policy. 
+ + Args: + password (str): The password to check against the server's policy. + + Raises: + PasswordRefusedError: The password doesn't comply with the server's policy. + """ + + if not self.enabled: + return + + minimum_accepted_length = self.policy.get("minimum_length", 0) + if len(password) < minimum_accepted_length: + raise PasswordRefusedError( + msg=( + "The password must be at least %d characters long" + % minimum_accepted_length + ), + errcode=Codes.PASSWORD_TOO_SHORT, + ) + + if ( + self.policy.get("require_digit", False) + and self.regexp_digit.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one digit", + errcode=Codes.PASSWORD_NO_DIGIT, + ) + + if ( + self.policy.get("require_symbol", False) + and self.regexp_symbol.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one symbol", + errcode=Codes.PASSWORD_NO_SYMBOL, + ) + + if ( + self.policy.get("require_uppercase", False) + and self.regexp_uppercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one uppercase letter", + errcode=Codes.PASSWORD_NO_UPPERCASE, + ) + + if ( + self.policy.get("require_lowercase", False) + and self.regexp_lowercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one lowercase letter", + errcode=Codes.PASSWORD_NO_LOWERCASE, + ) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 12657ca698..7d1263caf2 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -32,6 +32,7 @@ class SetPasswordHandler(BaseHandler): super(SetPasswordHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() + self._password_policy_handler = hs.get_password_policy_handler() @defer.inlineCallbacks def set_password( @@ -44,6 +45,7 @@ class SetPasswordHandler(BaseHandler): if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) + self._password_policy_handler.validate_password(new_password) password_hash = yield self._auth_handler.hash(new_password) try: diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 4a1fc2ec2b..46e458e95b 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import ( keys, notifications, openid, + password_policy, read_marker, receipts, register, @@ -118,6 +119,7 @@ class ClientRestResource(JsonResource): capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) + password_policy.register_servlets(hs, client_resource) # moving to /_synapse/admin synapse.rest.admin.register_servlets_for_client_rest_resource( diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py new file mode 100644 index 0000000000..968403cca4 --- /dev/null +++ b/synapse/rest/client/v2_alpha/password_policy.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PasswordPolicyServlet(RestServlet): + PATTERNS = client_patterns("/password_policy$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(PasswordPolicyServlet, self).__init__() + + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + def on_GET(self, request): + if not self.enabled or not self.policy: + return (200, {}) + + policy = {} + + for param in [ + "minimum_length", + "require_digit", + "require_symbol", + "require_lowercase", + "require_uppercase", + ]: + if param in self.policy: + policy["m.%s" % param] = self.policy[param] + + return (200, policy) + + +def register_servlets(hs, http_server): + PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6963d79310..66fc8ec179 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -373,6 +373,7 @@ class RegisterRestServlet(RestServlet): self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() + self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_flows = _calculate_registration_flows( @@ -420,6 +421,7 @@ class RegisterRestServlet(RestServlet): or len(body["password"]) > 512 ): raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(body["password"]) desired_username = None if "username" in body: diff --git a/synapse/server.py b/synapse/server.py index 9426eb1672..d0d80e8ac5 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -66,6 +66,7 @@ from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerH from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler +from synapse.handlers.password_policy import PasswordPolicyHandler from synapse.handlers.presence import PresenceHandler from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler @@ -199,6 +200,7 @@ class HomeServer(object): "account_validity_handler", "saml_handler", "event_client_serializer", + "password_policy_handler", "storage", "replication_streamer", ] @@ -535,6 +537,9 @@ class HomeServer(object): def build_event_client_serializer(self): return EventClientSerializer(self) + def build_password_policy_handler(self): + return PasswordPolicyHandler(self) + def build_storage(self) -> Storage: return Storage(self, self.datastores) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py new file mode 100644 index 0000000000..c57072f50c --- /dev/null +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ 
-0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + For each test below that checks whether a password triggers the right error code, + that test provides a password good enough to pass the previous tests, but not the + one it is currently testing (nor any test that comes afterward). + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/password_policy" + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": 
"l0ngerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. + """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } + ) + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) -- cgit 1.4.1 From 060e7dce09ae2197f29811769b13db30ed340211 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Thu, 26 Mar 2020 19:02:35 +0200 Subject: Allow RedirectResponse in SAML response handler Allow custom SAML handlers to redirect after processing an auth response. Fixes #7149 Signed-off-by: Jason Robinson --- changelog.d/7151.bugfix | 1 + synapse/handlers/saml_handler.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/7151.bugfix (limited to 'synapse') diff --git a/changelog.d/7151.bugfix b/changelog.d/7151.bugfix new file mode 100644 index 0000000000..69cde9351d --- /dev/null +++ b/changelog.d/7151.bugfix @@ -0,0 +1 @@ +Allow custom SAML handlers to redirect after processing an auth response. 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 72c109981b..dc04b53f43 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -26,6 +26,7 @@ from synapse.config import ConfigError from synapse.http.server import finish_request from synapse.http.servlet import parse_string from synapse.module_api import ModuleApi +from synapse.module_api.errors import RedirectException from synapse.types import ( UserID, map_username_to_mxid_localpart, @@ -119,6 +120,9 @@ class SamlHandler: try: user_id = await self._map_saml_response_to_user(resp_bytes, relay_state) + except RedirectException: + # Raise the exception as per the wishes of the SAML module response + raise except Exception as e: # If decoding the response or mapping it to a user failed, then log the # error and tell the user that something went wrong. -- cgit 1.4.1 From 825fb5d0a5699fb5b5eef9a8c2170d0c76158001 Mon Sep 17 00:00:00 2001 From: Nektarios Katakis Date: Thu, 26 Mar 2020 17:13:14 +0000 Subject: Don't default to an invalid sqlite config if no database configuration is provided (#6573) --- changelog.d/6573.bugfix | 1 + synapse/config/database.py | 69 +++++++++++++++++++++++++++++++--------------- 2 files changed, 48 insertions(+), 22 deletions(-) create mode 100644 changelog.d/6573.bugfix (limited to 'synapse') diff --git a/changelog.d/6573.bugfix b/changelog.d/6573.bugfix new file mode 100644 index 0000000000..1bb8014db7 --- /dev/null +++ b/changelog.d/6573.bugfix @@ -0,0 +1 @@ +Don't attempt to use an invalid sqlite config if no database configuration is provided. Contributed by @nekatak. diff --git a/synapse/config/database.py b/synapse/config/database.py index b8ab2f86ac..c27fef157b 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -20,6 +20,11 @@ from synapse.config._base import Config, ConfigError logger = logging.getLogger(__name__) +NON_SQLITE_DATABASE_PATH_WARNING = """\ +Ignoring 'database_path' setting: not using a sqlite3 database. 
+-------------------------------------------------------------------------------- +""" + DEFAULT_CONFIG = """\ ## Database ## @@ -105,6 +110,11 @@ class DatabaseConnectionConfig: class DatabaseConfig(Config): section = "database" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.databases = [] + def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) @@ -125,12 +135,13 @@ class DatabaseConfig(Config): multi_database_config = config.get("databases") database_config = config.get("database") + database_path = config.get("database_path") if multi_database_config and database_config: raise ConfigError("Can't specify both 'database' and 'datbases' in config") if multi_database_config: - if config.get("database_path"): + if database_path: raise ConfigError("Can't specify 'database_path' with 'databases'") self.databases = [ @@ -138,13 +149,17 @@ class DatabaseConfig(Config): for name, db_conf in multi_database_config.items() ] - else: - if database_config is None: - database_config = {"name": "sqlite3", "args": {}} - + if database_config: self.databases = [DatabaseConnectionConfig("master", database_config)] - self.set_databasepath(config.get("database_path")) + if database_path: + if self.databases and self.databases[0].name != "sqlite3": + logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) + return + + database_config = {"name": "sqlite3", "args": {}} + self.databases = [DatabaseConnectionConfig("master", database_config)] + self.set_databasepath(database_path) def generate_config_section(self, data_dir_path, **kwargs): return DEFAULT_CONFIG % { @@ -152,27 +167,37 @@ class DatabaseConfig(Config): } def read_arguments(self, args): - self.set_databasepath(args.database_path) + """ + Cases for the cli input: + - If no databases are configured and no database_path is set, raise. + - No databases and only database_path available ==> sqlite3 db. + - If there are multiple databases and a database_path raise an error. + - If the database set in the config file is sqlite then + overwrite with the command line argument. + """ - def set_databasepath(self, database_path): - if database_path is None: + if args.database_path is None: + if not self.databases: + raise ConfigError("No database config provided") return - if database_path != ":memory:": - database_path = self.abspath(database_path) + if len(self.databases) == 0: + database_config = {"name": "sqlite3", "args": {}} + self.databases = [DatabaseConnectionConfig("master", database_config)] + self.set_databasepath(args.database_path) + return + + if self.get_single_database().name == "sqlite3": + self.set_databasepath(args.database_path) + else: + logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) - # We only support setting a database path if we have a single sqlite3 - # database. - if len(self.databases) != 1: - raise ConfigError("Cannot specify 'database_path' with multiple databases") + def set_databasepath(self, database_path): - database = self.get_single_database() - if database.config["name"] != "sqlite3": - # We don't raise here as we haven't done so before for this case. 
- logger.warn("Ignoring 'database_path' for non-sqlite3 database") - return + if database_path != ":memory:": + database_path = self.abspath(database_path) - database.config["args"]["database"] = database_path + self.databases[0].config["args"]["database"] = database_path @staticmethod def add_arguments(parser): @@ -187,7 +212,7 @@ class DatabaseConfig(Config): def get_single_database(self) -> DatabaseConnectionConfig: """Returns the database if there is only one, useful for e.g. tests """ - if len(self.databases) != 1: + if not self.databases: raise Exception("More than one database exists") return self.databases[0] -- cgit 1.4.1 From c2ab0b30664439d09436a39e5448728c67b0414c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 26 Mar 2020 18:58:58 +0100 Subject: Whitelist the login fallback by default for SSO --- synapse/config/sso.py | 17 ++++++++++++++++- tests/rest/client/v1/test_login.py | 13 ++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 95762689bc..5ae9db83d0 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -39,6 +39,17 @@ class SSOConfig(Config): self.sso_client_whitelist = sso_config.get("client_whitelist") or [] + # Attempt to also whitelist the server's login fallback, since that fallback sets + # the redirect URL to itself (so it can process the login token then return + # gracefully to the client). This would make it pointless to ask the user for + # confirmation, since the URL the confirmation page would be showing wouldn't be + # the client's. + # public_baseurl is an optional setting, so we only add the fallback's URL to the + # list if it's provided (because we can't figure out what that URL is otherwise). + if self.public_baseurl: + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) + def generate_config_section(self, **kwargs): return """\ # Additional settings to use with single-sign on systems such as SAML2 and CAS. @@ -54,7 +65,11 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # By default, this list is empty. + # If public_baseurl is set, then the login fallback page (used by clients + # that don't have full support for SSO) is always included in this list. + # + # By default, this list is empty, except if public_baseurl is set (in which + # case the login fallback page is the only element in the list). 
# #client_whitelist: # - https://riot.im/develop diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index da2c9bfa1e..ed02ff1dcc 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -350,7 +350,18 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): def test_cas_redirect_whitelisted(self): """Tests that the SSO login flow serves a redirect to a whitelisted url """ - redirect_url = "https://legit-site.com/" + self._test_redirect("https://legit-site.com/") + + @override_config( + { + "public_baseurl": "https://example.com" + } + ) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" % (urllib.parse.quote(redirect_url)) -- cgit 1.4.1 From fa4f12102d52b75d252d9209b45251d2b1591fdf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Mar 2020 15:05:26 -0400 Subject: Refactor the CAS code (move the logic out of the REST layer to a handler) (#7136) --- changelog.d/7136.misc | 1 + synapse/handlers/cas_handler.py | 204 ++++++++++++++++++++++++++++++++++++++++ synapse/rest/client/v1/login.py | 171 ++++----------------------------- synapse/server.py | 5 + tox.ini | 1 + 5 files changed, 227 insertions(+), 155 deletions(-) create mode 100644 changelog.d/7136.misc create mode 100644 synapse/handlers/cas_handler.py (limited to 'synapse') diff --git a/changelog.d/7136.misc b/changelog.d/7136.misc new file mode 100644 index 0000000000..3f666d25fd --- /dev/null +++ b/changelog.d/7136.misc @@ -0,0 +1 @@ +Refactored the CAS authentication logic to a separate class. diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py new file mode 100644 index 0000000000..f8dc274b78 --- /dev/null +++ b/synapse/handlers/cas_handler.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import xml.etree.ElementTree as ET +from typing import AnyStr, Dict, Optional, Tuple + +from six.moves import urllib + +from twisted.web.client import PartialDownloadError + +from synapse.api.errors import Codes, LoginError +from synapse.http.site import SynapseRequest +from synapse.types import UserID, map_username_to_mxid_localpart + +logger = logging.getLogger(__name__) + + +class CasHandler: + """ + Utility class for to handle the response from a CAS SSO service. 
+ + Args: + hs (synapse.server.HomeServer) + """ + + def __init__(self, hs): + self._hostname = hs.hostname + self._auth_handler = hs.get_auth_handler() + self._registration_handler = hs.get_registration_handler() + + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url + self._cas_displayname_attribute = hs.config.cas_displayname_attribute + self._cas_required_attributes = hs.config.cas_required_attributes + + self._http_client = hs.get_proxied_http_client() + + def _build_service_param(self, client_redirect_url: AnyStr) -> str: + return "%s%s?%s" % ( + self._cas_service_url, + "/_matrix/client/r0/login/cas/ticket", + urllib.parse.urlencode({"redirectUrl": client_redirect_url}), + ) + + async def _handle_cas_response( + self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str + ) -> None: + """ + Retrieves the user and display name from the CAS response and continues with the authentication. + + Args: + request: The original client request. + cas_response_body: The response from the CAS server. + client_redirect_url: The URl to redirect the client to when + everything is done. + """ + user, attributes = self._parse_cas_response(cas_response_body) + displayname = attributes.pop(self._cas_displayname_attribute, None) + + for required_attribute, required_value in self._cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if required_value is not None: + actual_value = attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + await self._on_successful_auth(user, request, client_redirect_url, displayname) + + def _parse_cas_response( + self, cas_response_body: str + ) -> Tuple[str, Dict[str, Optional[str]]]: + """ + Retrieve the user and other parameters from the CAS response. + + Args: + cas_response_body: The response from the CAS query. + + Returns: + A tuple of the user and a mapping of other attributes. + """ + user = None + attributes = {} + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = root[0].tag.endswith("authenticationSuccess") + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + except Exception: + logger.exception("Error parsing CAS response") + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError( + 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED + ) + return user, attributes + + async def _on_successful_auth( + self, + username: str, + request: SynapseRequest, + client_redirect_url: str, + user_display_name: Optional[str] = None, + ) -> None: + """Called once the user has successfully authenticated with the SSO. 
+ + Registers the user if necessary, and then returns a redirect (with + a login token) to the client. + + Args: + username: the remote user id. We'll map this onto + something sane for a MXID localpath. + + request: the incoming request from the browser. We'll + respond to it with a redirect. + + client_redirect_url: the redirect_url the client gave us when + it first started the process. + + user_display_name: if set, and we have to register a new user, + we will set their displayname to this. + """ + localpart = map_username_to_mxid_localpart(username) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=user_display_name + ) + + self._auth_handler.complete_sso_login( + registered_user_id, request, client_redirect_url + ) + + def handle_redirect_request(self, client_redirect_url: bytes) -> bytes: + """ + Generates a URL to the CAS server where the client should be redirected. + + Args: + client_redirect_url: The final URL the client should go to after the + user has negotiated SSO. + + Returns: + The URL to redirect to. + """ + args = urllib.parse.urlencode( + {"service": self._build_service_param(client_redirect_url)} + ) + + return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii") + + async def handle_ticket_request( + self, request: SynapseRequest, client_redirect_url: str, ticket: str + ) -> None: + """ + Validates a CAS ticket sent by the client for login/registration. + + On a successful request, writes a redirect to the request. + """ + uri = self._cas_server_url + "/proxyValidate" + args = { + "ticket": ticket, + "service": self._build_service_param(client_redirect_url), + } + try: + body = await self._http_client.get_raw(uri, args) + except PartialDownloadError as pde: + # Twisted raises this error if the connection is closed, + # even if that's being used old-http style to signal end-of-data + body = pde.response + + await self._handle_cas_response(request, body, client_redirect_url) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 56d713462a..59593cbf6e 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,11 +14,6 @@ # limitations under the License. 
import logging -import xml.etree.ElementTree as ET - -from six.moves import urllib - -from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -28,9 +23,10 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn logger = logging.getLogger(__name__) @@ -72,14 +68,6 @@ def login_id_thirdparty_from_phone(identifier): return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} -def build_service_param(cas_service_url, client_redirect_url): - return "%s%s?redirectUrl=%s" % ( - cas_service_url, - "/_matrix/client/r0/login/cas/ticket", - urllib.parse.quote(client_redirect_url, safe=""), - ) - - class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -409,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) - def on_GET(self, request): + def on_GET(self, request: SynapseRequest): args = request.args if b"redirectUrl" not in args: return 400, "Redirect URL not specified for SSO auth" @@ -418,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet): request.redirect(sso_url) finish_request(request) - def get_sso_url(self, client_redirect_url): + def get_sso_url(self, client_redirect_url: bytes) -> bytes: """Get the URL to redirect to, to perform SSO auth Args: - client_redirect_url (bytes): the URL that we should redirect the + client_redirect_url: the URL that we should redirect the client to when everything is done Returns: - bytes: URL to redirect to + URL to redirect to """ # to be implemented by subclasses raise NotImplementedError() @@ -434,16 +422,10 @@ class BaseSSORedirectServlet(RestServlet): class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): - super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url - self.cas_service_url = hs.config.cas_service_url + self._cas_handler = hs.get_cas_handler() - def get_sso_url(self, client_redirect_url): - args = urllib.parse.urlencode( - {"service": build_service_param(self.cas_service_url, client_redirect_url)} - ) - - return "%s/login?%s" % (self.cas_server_url, args) + def get_sso_url(self, client_redirect_url: bytes) -> bytes: + return self._cas_handler.handle_redirect_request(client_redirect_url) class CasTicketServlet(RestServlet): @@ -451,81 +433,15 @@ class CasTicketServlet(RestServlet): def __init__(self, hs): super(CasTicketServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url - self.cas_service_url = hs.config.cas_service_url - self.cas_displayname_attribute = hs.config.cas_displayname_attribute - self.cas_required_attributes = hs.config.cas_required_attributes - self._sso_auth_handler = SSOAuthHandler(hs) - self._http_client = hs.get_proxied_http_client() - - async def on_GET(self, request): - client_redirect_url = parse_string(request, "redirectUrl", required=True) - uri = self.cas_server_url + "/proxyValidate" - args = { - "ticket": parse_string(request, "ticket", required=True), - "service": build_service_param(self.cas_service_url, client_redirect_url), - } - try: - body = await 
self._http_client.get_raw(uri, args) - except PartialDownloadError as pde: - # Twisted raises this error if the connection is closed, - # even if that's being used old-http style to signal end-of-data - body = pde.response - result = await self.handle_cas_response(request, body, client_redirect_url) - return result + self._cas_handler = hs.get_cas_handler() - def handle_cas_response(self, request, cas_response_body, client_redirect_url): - user, attributes = self.parse_cas_response(cas_response_body) - displayname = attributes.pop(self.cas_displayname_attribute, None) - - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - return self._sso_auth_handler.on_successful_auth( - user, request, client_redirect_url, displayname + async def on_GET(self, request: SynapseRequest) -> None: + client_redirect_url = parse_string(request, "redirectUrl", required=True) + ticket = parse_string(request, "ticket", required=True) + await self._cas_handler.handle_ticket_request( + request, client_redirect_url, ticket ) - def parse_cas_response(self, cas_response_body): - user = None - attributes = {} - try: - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise Exception("root of CAS response is not serviceResponse") - success = root[0].tag.endswith("authenticationSuccess") - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - for attribute in child: - # ElementTree library expands the namespace in - # attribute tags to the full URL of the namespace. - # We don't care about namespace here and it will always - # be encased in curly braces, so we remove them. 
- tag = attribute.tag - if "}" in tag: - tag = tag.split("}")[1] - attributes[tag] = attribute.text - if user is None: - raise Exception("CAS response does not contain user") - except Exception: - logger.exception("Error parsing CAS response") - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not success: - raise LoginError( - 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED - ) - return user, attributes - class SAMLRedirectServlet(BaseSSORedirectServlet): PATTERNS = client_patterns("/login/sso/redirect", v1=True) @@ -533,65 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._saml_handler = hs.get_saml_handler() - def get_sso_url(self, client_redirect_url): + def get_sso_url(self, client_redirect_url: bytes) -> bytes: return self._saml_handler.handle_redirect_request(client_redirect_url) -class SSOAuthHandler(object): - """ - Utility class for Resources and Servlets which handle the response from a SSO - service - - Args: - hs (synapse.server.HomeServer) - """ - - def __init__(self, hs): - self._hostname = hs.hostname - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() - self._macaroon_gen = hs.get_macaroon_generator() - - # cast to tuple for use with str.startswith - self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) - - async def on_successful_auth( - self, username, request, client_redirect_url, user_display_name=None - ): - """Called once the user has successfully authenticated with the SSO. - - Registers the user if necessary, and then returns a redirect (with - a login token) to the client. - - Args: - username (unicode|bytes): the remote user id. We'll map this onto - something sane for a MXID localpath. - - request (SynapseRequest): the incoming request from the browser. We'll - respond to it with a redirect. - - client_redirect_url (unicode): the redirect_url the client gave us when - it first started the process. - - user_display_name (unicode|None): if set, and we have to register a new user, - we will set their displayname to this. - - Returns: - Deferred[none]: Completes once we have handled the request. 
- """ - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) - if not registered_user_id: - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name - ) - - self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url - ) - - def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.cas_enabled: diff --git a/synapse/server.py b/synapse/server.py index d0d80e8ac5..c7ca2bda0d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -56,6 +56,7 @@ from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.acme import AcmeHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator +from synapse.handlers.cas_handler import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.devicemessage import DeviceMessageHandler @@ -198,6 +199,7 @@ class HomeServer(object): "sendmail", "registration_handler", "account_validity_handler", + "cas_handler", "saml_handler", "event_client_serializer", "password_policy_handler", @@ -529,6 +531,9 @@ class HomeServer(object): def build_account_validity_handler(self): return AccountValidityHandler(self) + def build_cas_handler(self): + return CasHandler(self) + def build_saml_handler(self): from synapse.handlers.saml_handler import SamlHandler diff --git a/tox.ini b/tox.ini index 8e3f09e638..a79fc93b57 100644 --- a/tox.ini +++ b/tox.ini @@ -186,6 +186,7 @@ commands = mypy \ synapse/federation/sender \ synapse/federation/transport \ synapse/handlers/auth.py \ + synapse/handlers/cas_handler.py \ synapse/handlers/directory.py \ synapse/handlers/presence.py \ synapse/handlers/sync.py \ -- cgit 1.4.1 From 09cc058a4c3eae58ee6e08c1925bffd9cf7d52c6 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 27 Mar 2020 12:26:47 +0000 Subject: Always send the user updates to their own device list This will allow clients to notify users about new devices even if the user isn't in any rooms (yet). --- synapse/handlers/device.py | 19 ++++++++++++++++--- synapse/handlers/sync.py | 7 ++++--- 2 files changed, 20 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index a514c30714..54931c355b 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -122,11 +122,22 @@ class DeviceWorkerHandler(BaseHandler): # First we check if any devices have changed for users that we share # rooms with. 
- users_who_share_room = yield self.store.get_users_who_share_room_with_user( + tracked_users = yield self.store.get_users_who_share_room_with_user( user_id ) + # always tell the user about their own devices + tracked_users.add(user_id) + + logger.info( + "tracked users ids: %r", tracked_users, + ) + changed = yield self.store.get_users_whose_devices_changed( - from_token.device_list_key, users_who_share_room + from_token.device_list_key, tracked_users + ) + + logger.info( + "changed users IDs: %r", changed, ) # Then work out if any users have since joined @@ -456,7 +467,9 @@ class DeviceHandler(DeviceWorkerHandler): room_ids = yield self.store.get_rooms_for_user(user_id) - yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids) + # specify the user ID too since the user should always get their own device list + # updates, even if they aren't in any rooms. + yield self.notifier.on_new_event("device_list_key", position, users=[user_id], rooms=room_ids) if hosts: logger.info( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5746fdea14..fd68a31b09 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1139,13 +1139,14 @@ class SyncHandler(object): # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = await self.store.get_users_who_share_room_with_user( + users_we_track = await self.store.get_users_who_share_room_with_user( user_id ) + users_we_track.add(user_id) # Step 1a, check for changes in devices of users we share a room with users_that_have_changed = await self.store.get_users_whose_devices_changed( - since_token.device_list_key, users_who_share_room + since_token.device_list_key, users_we_track ) # Step 1b, check for newly joined rooms @@ -1168,7 +1169,7 @@ class SyncHandler(object): newly_left_users.update(left_users) # Remove any users that we still share a room with. - newly_left_users -= users_who_share_room + newly_left_users -= users_we_track return DeviceLists(changed=users_that_have_changed, left=newly_left_users) else: -- cgit 1.4.1 From a07e03ce908f3de4a5ca80d0d7db6da2a6ed98b5 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 27 Mar 2020 12:35:32 +0000 Subject: black --- synapse/handlers/device.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 54931c355b..80e7374c6a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -122,9 +122,7 @@ class DeviceWorkerHandler(BaseHandler): # First we check if any devices have changed for users that we share # rooms with. - tracked_users = yield self.store.get_users_who_share_room_with_user( - user_id - ) + tracked_users = yield self.store.get_users_who_share_room_with_user(user_id) # always tell the user about their own devices tracked_users.add(user_id) @@ -469,7 +467,9 @@ class DeviceHandler(DeviceWorkerHandler): # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. 
- yield self.notifier.on_new_event("device_list_key", position, users=[user_id], rooms=room_ids) + yield self.notifier.on_new_event( + "device_list_key", position, users=[user_id], rooms=room_ids + ) if hosts: logger.info( -- cgit 1.4.1 From 16ee97988abb1951e7009620f69904f9bb7c9215 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 27 Mar 2020 12:39:54 +0000 Subject: Fix undefined variable & remove debug logging --- synapse/handlers/device.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 80e7374c6a..07ce553995 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -126,18 +126,10 @@ class DeviceWorkerHandler(BaseHandler): # always tell the user about their own devices tracked_users.add(user_id) - logger.info( - "tracked users ids: %r", tracked_users, - ) - changed = yield self.store.get_users_whose_devices_changed( from_token.device_list_key, tracked_users ) - logger.info( - "changed users IDs: %r", changed, - ) - # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) @@ -225,8 +217,8 @@ class DeviceWorkerHandler(BaseHandler): if possibly_changed or possibly_left: # Take the intersection of the users whose devices may have changed # and those that actually still share a room with the user - possibly_joined = possibly_changed & users_who_share_room - possibly_left = (possibly_changed | possibly_left) - users_who_share_room + possibly_joined = possibly_changed & tracked_users + possibly_left = (possibly_changed | possibly_left) - tracked_users else: possibly_joined = [] possibly_left = [] -- cgit 1.4.1 From fbf0782c63bd2aba3c504dabd04abdf10d269a22 Mon Sep 17 00:00:00 2001 From: David Vo Date: Sat, 28 Mar 2020 00:20:00 +1100 Subject: Only import sqlite3 when type checking (#7155) Fixes: #7127 Signed-off-by: David Vo --- changelog.d/7155.bugfix | 1 + synapse/storage/engines/sqlite.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7155.bugfix (limited to 'synapse') diff --git a/changelog.d/7155.bugfix b/changelog.d/7155.bugfix new file mode 100644 index 0000000000..0bf51e7aba --- /dev/null +++ b/changelog.d/7155.bugfix @@ -0,0 +1 @@ +Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo. diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 2bfeefd54e..3bc2e8b986 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,14 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import sqlite3 import struct import threading +import typing from synapse.storage.engines import BaseDatabaseEngine +if typing.TYPE_CHECKING: + import sqlite3 # noqa: F401 -class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): + +class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): def __init__(self, database_module, database_config): super().__init__(database_module, database_config) -- cgit 1.4.1 From 12aa5a7fa761a729364d324405a033cf78da26de Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 27 Mar 2020 13:30:22 +0000 Subject: Ensure is_verified on /_matrix/client/r0/room_keys/keys is a boolean (#7150) --- changelog.d/7150.bugfix | 1 + synapse/rest/client/v2_alpha/room_keys.py | 2 +- synapse/storage/data_stores/main/e2e_room_keys.py | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7150.bugfix (limited to 'synapse') diff --git a/changelog.d/7150.bugfix b/changelog.d/7150.bugfix new file mode 100644 index 0000000000..1feb294799 --- /dev/null +++ b/changelog.d/7150.bugfix @@ -0,0 +1 @@ +Ensure `is_verified` is a boolean in responses to `GET /_matrix/client/r0/room_keys/keys`. Also warn the user if they forgot the `version` query param. \ No newline at end of file diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 38952a1d27..59529707df 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -188,7 +188,7 @@ class RoomKeysServlet(RestServlet): """ requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - version = parse_string(request, "version") + version = parse_string(request, "version", required=True) room_keys = await self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 84594cf0a9..23f4570c4b 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -146,7 +146,8 @@ class EndToEndRoomKeyStore(SQLBaseStore): room_entry["sessions"][row["session_id"]] = { "first_message_index": row["first_message_index"], "forwarded_count": row["forwarded_count"], - "is_verified": row["is_verified"], + # is_verified must be returned to the client as a boolean + "is_verified": bool(row["is_verified"]), "session_data": json.loads(row["session_data"]), } -- cgit 1.4.1 From 63aea691a761a9b6a2058b54792fc5859e12cfba Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 27 Mar 2020 15:09:12 +0100 Subject: Update the wording of the config comment --- docs/sample_config.yaml | 6 +++--- synapse/config/sso.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 556d4419f5..07e922dc27 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1393,10 +1393,10 @@ sso: # hostname: "https://my.client/". # # If public_baseurl is set, then the login fallback page (used by clients - # that don't have full support for SSO) is always included in this list. + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. # - # By default, this list is empty, except if public_baseurl is set (in which - # case the login fallback page is the only element in the list). 
+ # By default, this list is empty. # #client_whitelist: # - https://riot.im/develop diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 5ae9db83d0..ec3dca9efc 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -66,10 +66,10 @@ class SSOConfig(Config): # hostname: "https://my.client/". # # If public_baseurl is set, then the login fallback page (used by clients - # that don't have full support for SSO) is always included in this list. + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. # - # By default, this list is empty, except if public_baseurl is set (in which - # case the login fallback page is the only element in the list). + # By default, this list is empty. # #client_whitelist: # - https://riot.im/develop -- cgit 1.4.1 From 90246344e340bce3417fb330da6be9338a701c5c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 27 Mar 2020 15:44:13 +0100 Subject: Improve the UX of the login fallback when using SSO (#7152) * Don't show the login forms if we're currently logging in with a password or a token. * Submit directly the SSO login form, showing only a spinner to the user, in order to eliminate from the clunkiness of SSO through this fallback. --- changelog.d/7152.feature | 1 + synapse/static/client/login/index.html | 2 +- synapse/static/client/login/js/login.js | 51 +++++++++++++++++++-------------- 3 files changed, 32 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7152.feature (limited to 'synapse') diff --git a/changelog.d/7152.feature b/changelog.d/7152.feature new file mode 100644 index 0000000000..fafa79c7e7 --- /dev/null +++ b/changelog.d/7152.feature @@ -0,0 +1 @@ +Improve the support for SSO authentication on the login fallback page. diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html index bcb6bc6bb7..712b0e3980 100644 --- a/synapse/static/client/login/index.html +++ b/synapse/static/client/login/index.html @@ -9,7 +9,7 @@

[index.html hunk body lost to extraction: the HTML markup was stripped, leaving only the heading text "Log in with one of the following methods"; judging by the login.js change below, that heading text now lives in the new title_pre_auth / set_title logic rather than being hard-coded in the page]

diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index 276c271bbe..debe464371 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -1,37 +1,41 @@ window.matrixLogin = { endpoint: location.origin + "/_matrix/client/r0/login", serverAcceptsPassword: false, - serverAcceptsCas: false, serverAcceptsSso: false, }; +var title_pre_auth = "Log in with one of the following methods"; +var title_post_auth = "Logging in..."; + var submitPassword = function(user, pwd) { console.log("Logging in with password..."); + set_title(title_post_auth); var data = { type: "m.login.password", user: user, password: pwd, }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); matrixLogin.onLogin(response); }).error(errorFunc); }; var submitToken = function(loginToken) { console.log("Logging in with login token..."); + set_title(title_post_auth); var data = { type: "m.login.token", token: loginToken }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); matrixLogin.onLogin(response); }).error(errorFunc); }; var errorFunc = function(err) { - show_login(); + // We want to show the error to the user rather than redirecting immediately to the + // SSO portal (if SSO is the only login option), so we inhibit the redirect. + show_login(true); if (err.responseJSON && err.responseJSON.error) { setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")"); @@ -45,26 +49,33 @@ var setFeedbackString = function(text) { $("#feedback").text(text); }; -var show_login = function() { - $("#loading").hide(); - +var show_login = function(inhibit_redirect) { var this_page = window.location.origin + window.location.pathname; $("#sso_redirect_url").val(this_page); - if (matrixLogin.serverAcceptsPassword) { - $("#password_flow").show(); + // If inhibit_redirect is false, and SSO is the only supported login method, we can + // redirect straight to the SSO page + if (matrixLogin.serverAcceptsSso) { + if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) { + $("#sso_form").submit(); + return; + } + + // Otherwise, show the SSO form + $("#sso_form").show(); } - if (matrixLogin.serverAcceptsSso) { - $("#sso_flow").show(); - } else if (matrixLogin.serverAcceptsCas) { - $("#sso_form").attr("action", "/_matrix/client/r0/login/cas/redirect"); - $("#sso_flow").show(); + if (matrixLogin.serverAcceptsPassword) { + $("#password_flow").show(); } - if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas && !matrixLogin.serverAcceptsSso) { + if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) { $("#no_login_types").show(); } + + set_title(title_pre_auth); + + $("#loading").hide(); }; var show_spinner = function() { @@ -74,17 +85,15 @@ var show_spinner = function() { $("#loading").show(); }; +var set_title = function(title) { + $("#title").text(title); +}; var fetch_info = function(cb) { $.get(matrixLogin.endpoint, function(response) { var serverAcceptsPassword = false; - var serverAcceptsCas = false; for (var i=0; i Date: Fri, 27 Mar 2020 20:15:23 +0100 Subject: Add options to prevent users from changing their profile. 
(#7096) --- changelog.d/7096.feature | 1 + docs/sample_config.yaml | 23 +++ synapse/config/registration.py | 27 +++ synapse/handlers/profile.py | 16 ++ synapse/rest/client/v2_alpha/account.py | 16 ++ tests/handlers/test_profile.py | 65 ++++++- tests/rest/client/v2_alpha/test_account.py | 302 +++++++++++++++++++++++++++++ 7 files changed, 449 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7096.feature (limited to 'synapse') diff --git a/changelog.d/7096.feature b/changelog.d/7096.feature new file mode 100644 index 0000000000..00f47b2a14 --- /dev/null +++ b/changelog.d/7096.feature @@ -0,0 +1 @@ +Add options to prevent users from changing their profile or associated 3PIDs. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1a1d061759..545226f753 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1086,6 +1086,29 @@ account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process +# Whether users are allowed to change their displayname after it has +# been initially set. Useful when provisioning users based on the +# contents of a third-party directory. +# +# Does not apply to server administrators. Defaults to 'true' +# +#enable_set_displayname: false + +# Whether users are allowed to change their avatar after it has been +# initially set. Useful when provisioning users based on the contents +# of a third-party directory. +# +# Does not apply to server administrators. Defaults to 'true' +# +#enable_set_avatar_url: false + +# Whether users can change the 3PIDs associated with their accounts +# (email address and msisdn). +# +# Defaults to 'true' +# +#enable_3pid_changes: false + # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 9bb3beedbc..e7ea3a01cb 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,6 +129,10 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.enable_set_displayname = config.get("enable_set_displayname", True) + self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) + self.enable_3pid_changes = config.get("enable_3pid_changes", True) + self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -330,6 +334,29 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process + # Whether users are allowed to change their displayname after it has + # been initially set. Useful when provisioning users based on the + # contents of a third-party directory. + # + # Does not apply to server administrators. Defaults to 'true' + # + #enable_set_displayname: false + + # Whether users are allowed to change their avatar after it has been + # initially set. Useful when provisioning users based on the contents + # of a third-party directory. + # + # Does not apply to server administrators. Defaults to 'true' + # + #enable_set_avatar_url: false + + # Whether users can change the 3PIDs associated with their accounts + # (email address and msisdn). 
+ # + # Defaults to 'true' + # + #enable_3pid_changes: false + # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 50ce0c585b..6aa1c0f5e0 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,6 +157,15 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") + if not by_admin and not self.hs.config.enable_set_displayname: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.display_name: + raise SynapseError( + 400, + "Changing display name is disabled on this server", + Codes.FORBIDDEN, + ) + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -218,6 +227,13 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") + if not by_admin and not self.hs.config.enable_set_avatar_url: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.avatar_url: + raise SynapseError( + 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN + ) + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index b1249b664c..f80b5e40ea 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -605,6 +605,11 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -649,6 +654,11 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -744,10 +754,16 @@ class ThreepidDeleteRestServlet(RestServlet): def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d60c124eec..be665262c6 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,6 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield 
self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() + self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -90,6 +91,33 @@ class ProfileTestCase(unittest.TestCase): "Frank Jr.", ) + # Set displayname again + yield self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank" + ) + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + @defer.inlineCallbacks + def test_set_my_name_if_disabled(self): + self.hs.config.enable_set_displayname = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + # Setting displayname a second time is forbidden + d = self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) + + yield self.assertFailure(d, SynapseError) + @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( @@ -147,3 +175,38 @@ class ProfileTestCase(unittest.TestCase): (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) + + # Set avatar again + yield self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/me.png", + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + @defer.inlineCallbacks + def test_set_my_avatar_if_disabled(self): + self.hs.config.enable_set_avatar_url = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + # Set avatar a second time is forbidden + d = self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) + + yield self.assertFailure(d, SynapseError) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index c3facc00eb..45a9d445f8 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,6 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -325,3 +326,304 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. 
+ self.email_attempts = [] + + def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_email(self): + """Test adding an email to profile + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_add_email_if_disabled(self): + """Test adding email to profile when doing so is disallowed + """ + self.hs.config.enable_3pid_changes = False + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test deleting an email from profile + """ + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = 
self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test deleting an email from profile when disallowed + """ + self.hs.config.enable_3pid_changes = False + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. 
+ """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) -- cgit 1.4.1 From fb69690761762092c8e44d509d4f72408c4c67e0 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 27 Mar 2020 20:16:43 +0100 Subject: Admin API to join users to a room. (#7051) --- changelog.d/7051.feature | 1 + docs/admin_api/room_membership.md | 34 +++++ synapse/rest/admin/__init__.py | 7 +- synapse/rest/admin/rooms.py | 79 ++++++++++- tests/rest/admin/test_room.py | 288 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 405 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7051.feature create mode 100644 docs/admin_api/room_membership.md create mode 100644 tests/rest/admin/test_room.py (limited to 'synapse') diff --git a/changelog.d/7051.feature b/changelog.d/7051.feature new file mode 100644 index 0000000000..3e36a3f65e --- /dev/null +++ b/changelog.d/7051.feature @@ -0,0 +1 @@ +Admin API `POST /_synapse/admin/v1/join/` to join users to a room like `auto_join_rooms` for creation of users. \ No newline at end of file diff --git a/docs/admin_api/room_membership.md b/docs/admin_api/room_membership.md new file mode 100644 index 0000000000..16736d3d37 --- /dev/null +++ b/docs/admin_api/room_membership.md @@ -0,0 +1,34 @@ +# Edit Room Membership API + +This API allows an administrator to join an user account with a given `user_id` +to a room with a given `room_id_or_alias`. You can only modify the membership of +local users. The server administrator must be in the room and have permission to +invite users. 
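As an illustrative aside (not part of the patch), the endpoint described here can be driven from a small admin script; the hostname, token, and IDs below are placeholders, and the raw HTTP form appears under Usage further down:

```python
import requests


def admin_join(base_url: str, admin_token: str, room_id_or_alias: str, user_id: str) -> str:
    """Join a local user to a room via the admin endpoint; returns the room ID."""
    resp = requests.post(
        f"{base_url}/_synapse/admin/v1/join/{room_id_or_alias}",
        headers={"Authorization": f"Bearer {admin_token}"},
        json={"user_id": user_id},
    )
    resp.raise_for_status()
    return resp.json()["room_id"]


# e.g. admin_join("https://server.com", "<admin token>",
#                 "!636q39766251:server.com", "@user:server.com")
```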
+ +## Parameters + +The following parameters are available: + +* `user_id` - Fully qualified user: for example, `@user:server.com`. +* `room_id_or_alias` - The room identifier or alias to join: for example, + `!636q39766251:server.com`. + +## Usage + +``` +POST /_synapse/admin/v1/join/ + +{ + "user_id": "@user:server.com" +} +``` + +Including an `access_token` of a server admin. + +Response: + +``` +{ + "room_id": "!636q39766251:server.com" +} +``` diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 42cc2b062a..ed70d448a1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -29,7 +29,11 @@ from synapse.rest.admin._base import ( from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet -from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet +from synapse.rest.admin.rooms import ( + JoinRoomAliasServlet, + ListRoomRestServlet, + ShutdownRoomRestServlet, +) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, @@ -189,6 +193,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f9b8c0a4f0..659b8a10ee 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +from typing import List, Optional -from synapse.api.constants import Membership -from synapse.api.errors import Codes, SynapseError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -29,7 +30,7 @@ from synapse.rest.admin._base import ( historical_admin_path_patterns, ) from synapse.storage.data_stores.main.room import RoomSortOrder -from synapse.types import create_requester +from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -237,3 +238,75 @@ class ListRoomRestServlet(RestServlet): response["prev_batch"] = 0 return 200, response + + +class JoinRoomAliasServlet(RestServlet): + + PATTERNS = admin_patterns("/join/(?P[^/]*)") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.admin_handler = hs.get_handlers().admin_handler + self.state_handler = hs.get_state_handler() + + async def on_POST(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + content = parse_json_object_from_request(request) + + assert_params_in_dict(content, ["user_id"]) + target_user = UserID.from_string(content["user_id"]) + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "This endpoint can only be used with local users") + + if not await self.admin_handler.get_user(target_user): + raise NotFoundError("User not found") + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + try: + remote_room_hosts = [ + x.decode("ascii") for x in request.args[b"server_name"] + ] # type: Optional[List[str]] + except Exception: + remote_room_hosts = None + elif RoomAlias.is_valid(room_identifier): + handler = self.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + fake_requester = create_requester(target_user) + + # send invite if room has "JoinRules.INVITE" + room_state = await self.state_handler.get_current_state(room_id) + join_rules_event = room_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + await self.room_member_handler.update_membership( + requester=requester, + target=fake_requester.user, + room_id=room_id, + action="invite", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + await self.room_member_handler.update_membership( + requester=fake_requester, + target=fake_requester.user, + room_id=room_id, + action="join", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + return 200, {"room_id": room_id} diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py new file mode 100644 index 0000000000..672cc3eac5 --- /dev/null +++ b/tests/rest/admin/test_room.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, room + +from tests import unittest + +"""Tests admin REST events for /rooms paths.""" + + +class JoinAliasRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.second_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If a parameter is missing, return an error + """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_local_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + body = json.dumps({"user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_remote_user(self): + """ + Check that only local user can join rooms. + """ + body = json.dumps({"user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "This endpoint can only be used with local users", + channel.json_body["error"], + ) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. 
+ """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/!unknown:test" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No known servers", channel.json_body["error"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/invalidroom" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom was not legal room ID or room alias", + channel.json_body["error"], + ) + + def test_join_public_room(self): + """ + Test joining a local user to a public room with "JoinRules.PUBLIC" + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_not_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE" + when server admin is not member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_join_private_room_if_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + self.helper.invite( + room=private_room_id, + src=self.creator, + targ=self.admin_user, + tok=self.creator_tok, + ) + self.helper.join( + room=private_room_id, user=self.admin_user, tok=self.admin_user_tok + ) + + # Validate if server admin is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + # Join user to room. 
+ + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_owner(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is owner of this room. + """ + private_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) -- cgit 1.4.1 From 84f7eaed16a8169f1d70d047c9354c8232b9fb9f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 27 Mar 2020 15:44:13 +0100 Subject: Improve the UX of the login fallback when using SSO (#7152) * Don't show the login forms if we're currently logging in with a password or a token. * Submit directly the SSO login form, showing only a spinner to the user, in order to eliminate from the clunkiness of SSO through this fallback. --- changelog.d/7152.feature | 1 + synapse/static/client/login/index.html | 2 +- synapse/static/client/login/js/login.js | 51 +++++++++++++++++++-------------- 3 files changed, 32 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7152.feature (limited to 'synapse') diff --git a/changelog.d/7152.feature b/changelog.d/7152.feature new file mode 100644 index 0000000000..fafa79c7e7 --- /dev/null +++ b/changelog.d/7152.feature @@ -0,0 +1 @@ +Improve the support for SSO authentication on the login fallback page. diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html index bcb6bc6bb7..712b0e3980 100644 --- a/synapse/static/client/login/index.html +++ b/synapse/static/client/login/index.html @@ -9,7 +9,7 @@

[index.html hunk body lost to extraction: the HTML markup was stripped, leaving only the heading text "Log in with one of the following methods"; judging by the login.js change below, that heading text now lives in the new title_pre_auth / set_title logic rather than being hard-coded in the page]

diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index 276c271bbe..debe464371 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -1,37 +1,41 @@ window.matrixLogin = { endpoint: location.origin + "/_matrix/client/r0/login", serverAcceptsPassword: false, - serverAcceptsCas: false, serverAcceptsSso: false, }; +var title_pre_auth = "Log in with one of the following methods"; +var title_post_auth = "Logging in..."; + var submitPassword = function(user, pwd) { console.log("Logging in with password..."); + set_title(title_post_auth); var data = { type: "m.login.password", user: user, password: pwd, }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); matrixLogin.onLogin(response); }).error(errorFunc); }; var submitToken = function(loginToken) { console.log("Logging in with login token..."); + set_title(title_post_auth); var data = { type: "m.login.token", token: loginToken }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); matrixLogin.onLogin(response); }).error(errorFunc); }; var errorFunc = function(err) { - show_login(); + // We want to show the error to the user rather than redirecting immediately to the + // SSO portal (if SSO is the only login option), so we inhibit the redirect. + show_login(true); if (err.responseJSON && err.responseJSON.error) { setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")"); @@ -45,26 +49,33 @@ var setFeedbackString = function(text) { $("#feedback").text(text); }; -var show_login = function() { - $("#loading").hide(); - +var show_login = function(inhibit_redirect) { var this_page = window.location.origin + window.location.pathname; $("#sso_redirect_url").val(this_page); - if (matrixLogin.serverAcceptsPassword) { - $("#password_flow").show(); + // If inhibit_redirect is false, and SSO is the only supported login method, we can + // redirect straight to the SSO page + if (matrixLogin.serverAcceptsSso) { + if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) { + $("#sso_form").submit(); + return; + } + + // Otherwise, show the SSO form + $("#sso_form").show(); } - if (matrixLogin.serverAcceptsSso) { - $("#sso_flow").show(); - } else if (matrixLogin.serverAcceptsCas) { - $("#sso_form").attr("action", "/_matrix/client/r0/login/cas/redirect"); - $("#sso_flow").show(); + if (matrixLogin.serverAcceptsPassword) { + $("#password_flow").show(); } - if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas && !matrixLogin.serverAcceptsSso) { + if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) { $("#no_login_types").show(); } + + set_title(title_pre_auth); + + $("#loading").hide(); }; var show_spinner = function() { @@ -74,17 +85,15 @@ var show_spinner = function() { $("#loading").show(); }; +var set_title = function(title) { + $("#title").text(title); +}; var fetch_info = function(cb) { $.get(matrixLogin.endpoint, function(response) { var serverAcceptsPassword = false; - var serverAcceptsCas = false; for (var i=0; i Date: Fri, 27 Mar 2020 20:24:52 +0000 Subject: Always whitelist the login fallback for SSO (#7153) That fallback sets the redirect URL to itself (so it can process the login token then return gracefully to the client). This would make it pointless to ask the user for confirmation, since the URL the confirmation page would be showing wouldn't be the client's. 
--- changelog.d/7153.feature | 1 + docs/sample_config.yaml | 4 ++++ synapse/config/sso.py | 15 +++++++++++++++ tests/rest/client/v1/test_login.py | 9 ++++++++- 4 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7153.feature (limited to 'synapse') diff --git a/changelog.d/7153.feature b/changelog.d/7153.feature new file mode 100644 index 0000000000..414ebe1f69 --- /dev/null +++ b/changelog.d/7153.feature @@ -0,0 +1 @@ +Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 545226f753..743949945a 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1444,6 +1444,10 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # # By default, this list is empty. # #client_whitelist: diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 95762689bc..ec3dca9efc 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -39,6 +39,17 @@ class SSOConfig(Config): self.sso_client_whitelist = sso_config.get("client_whitelist") or [] + # Attempt to also whitelist the server's login fallback, since that fallback sets + # the redirect URL to itself (so it can process the login token then return + # gracefully to the client). This would make it pointless to ask the user for + # confirmation, since the URL the confirmation page would be showing wouldn't be + # the client's. + # public_baseurl is an optional setting, so we only add the fallback's URL to the + # list if it's provided (because we can't figure out what that URL is otherwise). + if self.public_baseurl: + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) + def generate_config_section(self, **kwargs): return """\ # Additional settings to use with single-sign on systems such as SAML2 and CAS. @@ -54,6 +65,10 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # # By default, this list is empty. 
# #client_whitelist: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index da2c9bfa1e..aed8853d6e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -350,7 +350,14 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): def test_cas_redirect_whitelisted(self): """Tests that the SSO login flow serves a redirect to a whitelisted url """ - redirect_url = "https://legit-site.com/" + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" % (urllib.parse.quote(redirect_url)) -- cgit 1.4.1 From 9fc588e6dc03bb64f569b4e27a786abd78c36218 Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 30 Mar 2020 10:11:26 +0100 Subject: Just add own user ID to the list we track device changes for --- synapse/handlers/device.py | 8 +++++--- synapse/handlers/sync.py | 10 ++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 07ce553995..51bfad01c8 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -122,7 +122,9 @@ class DeviceWorkerHandler(BaseHandler): # First we check if any devices have changed for users that we share # rooms with. - tracked_users = yield self.store.get_users_who_share_room_with_user(user_id) + users_who_share_room = yield self.store.get_users_who_share_room_with_user(user_id) + + tracked_users = set(users_who_share_room) # always tell the user about their own devices tracked_users.add(user_id) @@ -217,8 +219,8 @@ class DeviceWorkerHandler(BaseHandler): if possibly_changed or possibly_left: # Take the intersection of the users whose devices may have changed # and those that actually still share a room with the user - possibly_joined = possibly_changed & tracked_users - possibly_left = (possibly_changed | possibly_left) - tracked_users + possibly_joined = possibly_changed & users_who_share_room + possibly_left = (possibly_changed | possibly_left) - users_who_share_room else: possibly_joined = [] possibly_left = [] diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index fd68a31b09..725f41c4d9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1139,14 +1139,16 @@ class SyncHandler(object): # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_we_track = await self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) - users_we_track.add(user_id) + + tracked_users = set(users_who_share_room) + tracked_users.add(user_id) # Step 1a, check for changes in devices of users we share a room with users_that_have_changed = await self.store.get_users_whose_devices_changed( - since_token.device_list_key, users_we_track + since_token.device_list_key, tracked_users ) # Step 1b, check for newly joined rooms @@ -1169,7 +1171,7 @@ class SyncHandler(object): newly_left_users.update(left_users) # Remove any users that we still share a room with. 
- newly_left_users -= users_we_track + newly_left_users -= users_who_share_room return DeviceLists(changed=users_that_have_changed, left=newly_left_users) else: -- cgit 1.4.1 From 740647752576228e469918bafbb97ff556ff5ebe Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 30 Mar 2020 10:18:33 +0100 Subject: black --- synapse/handlers/device.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 51bfad01c8..d256b86ce5 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -122,7 +122,9 @@ class DeviceWorkerHandler(BaseHandler): # First we check if any devices have changed for users that we share # rooms with. - users_who_share_room = yield self.store.get_users_who_share_room_with_user(user_id) + users_who_share_room = yield self.store.get_users_who_share_room_with_user( + user_id + ) tracked_users = set(users_who_share_room) # always tell the user about their own devices -- cgit 1.4.1 From 104844c1e1a99df9c4a7e022e715517578533db7 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 30 Mar 2020 14:00:11 +0100 Subject: Add explanatory comment --- synapse/handlers/device.py | 3 ++- synapse/handlers/sync.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d256b86ce5..993499f446 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -127,7 +127,8 @@ class DeviceWorkerHandler(BaseHandler): ) tracked_users = set(users_who_share_room) - # always tell the user about their own devices + + # Always tell the user about their own devices tracked_users.add(user_id) changed = yield self.store.get_users_whose_devices_changed( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 725f41c4d9..1f1cde2feb 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1144,6 +1144,8 @@ class SyncHandler(object): ) tracked_users = set(users_who_share_room) + + # Always tell the user about their own devices tracked_users.add(user_id) # Step 1a, check for changes in devices of users we share a room with -- cgit 1.4.1 From 4f21c33be301b8ea6369039c3ad8baa51878e4d5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Mar 2020 16:37:24 +0100 Subject: Remove usage of "conn_id" for presence. (#7128) * Remove `conn_id` usage for UserSyncCommand. Each tcp replication connection is assigned a "conn_id", which is used to give an ID to a remotely connected worker. In a redis world, there will no longer be a one to one mapping between connection and instance, so instead we need to replace such usages with an ID generated by the remote instances and included in the replicaiton commands. This really only effects UserSyncCommand. * Add CLEAR_USER_SYNCS command that is sent on shutdown. This should help with the case where a synchrotron gets restarted gracefully, rather than rely on 5 minute timeout. 
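The "send it at shutdown" part leans on Twisted's system event triggers, which the worker change below wires up. A minimal self-contained sketch of that mechanism (the handler body and log line are invented for illustration):

```python
from twisted.internet import defer, reactor


def on_shutdown():
    # A "before shutdown" trigger may return a Deferred; the reactor waits for
    # it, giving the worker time to tell the master to drop its sync state.
    print("sending CLEAR_USER_SYNC for this instance")
    return defer.succeed(None)


reactor.addSystemEventTrigger("before", "shutdown", on_shutdown)
reactor.callLater(0.1, reactor.stop)
reactor.run()
```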
--- changelog.d/7128.misc | 1 + docs/tcp_replication.md | 6 ++++++ synapse/app/generic_worker.py | 20 ++++++++++++++++---- synapse/replication/tcp/client.py | 6 ++++-- synapse/replication/tcp/commands.py | 36 ++++++++++++++++++++++++++++++++---- synapse/replication/tcp/protocol.py | 9 +++++++-- synapse/replication/tcp/resource.py | 17 +++++++---------- synapse/server.py | 11 +++++++++++ synapse/server.pyi | 2 ++ 9 files changed, 86 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7128.misc (limited to 'synapse') diff --git a/changelog.d/7128.misc b/changelog.d/7128.misc new file mode 100644 index 0000000000..5703f6d2ec --- /dev/null +++ b/changelog.d/7128.misc @@ -0,0 +1 @@ +Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index d4f7d9ec18..3be8e50c4c 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -198,6 +198,12 @@ Asks the server for the current position of all streams. A user has started or stopped syncing +#### CLEAR_USER_SYNC (C) + + The server should clear all associated user sync data from the worker. + + This is used when a worker is shutting down. + #### FEDERATION_ACK (C) Acknowledge receipt of some federation data diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index fba7ad9551..1ee266f7c5 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -65,6 +65,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, @@ -124,7 +125,6 @@ from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole -from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.generic_worker") @@ -233,6 +233,7 @@ class GenericWorkerPresence(object): self.user_to_num_current_syncs = {} self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.instance_id = hs.get_instance_id() active_presence = self.store.take_presence_startup_info() self.user_to_current_state = {state.user_id: state for state in active_presence} @@ -245,13 +246,24 @@ class GenericWorkerPresence(object): self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) - self.process_id = random_string(16) - logger.info("Presence process_id is %r", self.process_id) + hs.get_reactor().addSystemEventTrigger( + "before", + "shutdown", + run_as_background_process, + "generic_presence.on_shutdown", + self._on_shutdown, + ) + + def _on_shutdown(self): + if self.hs.config.use_presence: + self.hs.get_tcp_replication().send_command( + ClearUserSyncsCommand(self.instance_id) + ) def send_user_sync(self, user_id, is_syncing, last_sync_ms): if self.hs.config.use_presence: self.hs.get_tcp_replication().send_user_sync( - user_id, is_syncing, last_sync_ms + self.instance_id, user_id, is_syncing, last_sync_ms ) def mark_as_coming_online(self, user_id): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7e7ad0f798..e86d9805f1 100644 --- 
a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -189,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): """ self.send_command(FederationAckCommand(token)) - def send_user_sync(self, user_id, is_syncing, last_sync_ms): + def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """Poke the master that a user has started/stopped syncing. """ - self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms)) + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) def send_remove_pusher(self, app_id, push_key, user_id): """Poke the master to remove a pusher for a user diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5a6b734094..e4eec643f7 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -207,30 +207,32 @@ class UserSyncCommand(Command): Format:: - USER_SYNC + USER_SYNC Where is either "start" or "stop" """ NAME = "USER_SYNC" - def __init__(self, user_id, is_syncing, last_sync_ms): + def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls, line): - user_id, state, last_sync_ms = line.split(" ", 2) + instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, state == "start", int(last_sync_ms)) def to_line(self): return " ".join( ( + self.instance_id, self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), @@ -238,6 +240,30 @@ class UserSyncCommand(Command): ) +class ClearUserSyncsCommand(Command): + """Sent by the client to inform the server that it should drop all + information about syncing users sent by the client. + + Mainly used when client is about to shut down. + + Format:: + + CLEAR_USER_SYNC + """ + + NAME = "CLEAR_USER_SYNC" + + def __init__(self, instance_id): + self.instance_id = instance_id + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self): + return self.instance_id + + class FederationAckCommand(Command): """Sent by the client when it has processed up to a given point in the federation stream. This allows the master to drop in-memory caches of the @@ -398,6 +424,7 @@ _COMMANDS = ( InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, + ClearUserSyncsCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. 
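A quick usage sketch of the commands defined above (the worker ID, user, and timestamp are invented, and the command NAME prefix is added elsewhere by the protocol layer when the line goes on the wire):

```python
# Assumes the patched module is importable.
from synapse.replication.tcp.commands import ClearUserSyncsCommand, UserSyncCommand

cmd = UserSyncCommand("wrkr1", "@alice:example.com", True, 1585500000000)
line = cmd.to_line()  # "wrkr1 @alice:example.com start 1585500000000"
assert UserSyncCommand.from_line(line).instance_id == "wrkr1"

# On graceful shutdown the worker simply identifies itself:
assert ClearUserSyncsCommand("wrkr1").to_line() == "wrkr1"
```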
@@ -420,6 +447,7 @@ VALID_CLIENT_COMMANDS = ( ReplicateCommand.NAME, PingCommand.NAME, UserSyncCommand.NAME, + ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, InvalidateCacheCommand.NAME, diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index f81d2e2442..dae246825f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -423,9 +423,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_USER_SYNC(self, cmd): await self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) + async def on_CLEAR_USER_SYNC(self, cmd): + await self.streamer.on_clear_user_syncs(cmd.instance_id) + async def on_REPLICATE(self, cmd): # Subscribe to all streams we're publishing to. for stream_name in self.streamer.streams_by_name: @@ -551,6 +554,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): ): BaseReplicationStreamProtocol.__init__(self, clock) + self.instance_id = hs.get_instance_id() + self.client_name = client_name self.server_name = server_name self.handler = handler @@ -580,7 +585,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): currently_syncing = self.handler.get_currently_syncing_users() now = self.clock.time_msec() for user_id in currently_syncing: - self.send_command(UserSyncCommand(user_id, True, now)) + self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) # We've now finished connecting to so inform the client handler self.handler.update_connection(self) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 4374e99e32..8b6067e20d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -251,14 +251,19 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() await self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms + instance_id, user_id, is_syncing, last_sync_ms ) + async def on_clear_user_syncs(self, instance_id): + """A replication client wants us to drop all their UserSync data. + """ + await self.presence_handler.update_external_syncs_clear(instance_id) + @measure_func("repl.on_remove_pusher") async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher @@ -321,14 +326,6 @@ class ReplicationStreamer(object): except ValueError: pass - # We need to tell the presence handler that the connection has been - # lost so that it can handle any ongoing syncs on that connection. 
- run_as_background_process( - "update_external_syncs_clear", - self.presence_handler.update_external_syncs_clear, - connection.conn_id, - ) - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to diff --git a/synapse/server.py b/synapse/server.py index c7ca2bda0d..cd86475d6b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -103,6 +103,7 @@ from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor +from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) @@ -230,6 +231,8 @@ class HomeServer(object): self._listening_services = [] self.start_time = None + self.instance_id = random_string(5) + self.clock = Clock(reactor) self.distributor = Distributor() self.ratelimiter = Ratelimiter() @@ -242,6 +245,14 @@ class HomeServer(object): for depname in kwargs: setattr(self, depname, kwargs[depname]) + def get_instance_id(self): + """A unique ID for this synapse process instance. + + This is used to distinguish running instances in worker-based + deployments. + """ + return self.instance_id + def setup(self): logger.info("Setting up.") self.start_time = int(self.get_clock().time()) diff --git a/synapse/server.pyi b/synapse/server.pyi index 3844f0e12f..9d1dfa71e7 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -114,3 +114,5 @@ class HomeServer(object): pass def is_mine_id(self, domain_id: str) -> bool: pass + def get_instance_id(self) -> str: + pass -- cgit 1.4.1 From d9f29f8daef2f49464382b0e80ee93ff38681e99 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 30 Mar 2020 17:38:21 +0100 Subject: Fix a small typo in the `metrics_flags` config option. (#7171) --- changelog.d/7171.doc | 1 + docs/sample_config.yaml | 2 +- synapse/config/metrics.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7171.doc (limited to 'synapse') diff --git a/changelog.d/7171.doc b/changelog.d/7171.doc new file mode 100644 index 0000000000..25a3bd8ac6 --- /dev/null +++ b/changelog.d/7171.doc @@ -0,0 +1 @@ +Fix a small typo in the `metrics_flags` config option. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 743949945a..6a770508f9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1144,7 +1144,7 @@ account_threepid_delegates: # enabled by default, either for performance reasons or limited use. # metrics_flags: - # Publish synapse_federation_known_servers, a g auge of the number of + # Publish synapse_federation_known_servers, a gauge of the number of # servers this homeserver knows about, including itself. May cause # performance problems on large homeservers. # diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 22538153e1..6f517a71d0 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -86,7 +86,7 @@ class MetricsConfig(Config): # enabled by default, either for performance reasons or limited use. # metrics_flags: - # Publish synapse_federation_known_servers, a g auge of the number of + # Publish synapse_federation_known_servers, a gauge of the number of # servers this homeserver knows about, including itself. May cause # performance problems on large homeservers. 
# -- cgit 1.4.1 From 7042840b3201644ee71ea3e446576aa347b6d2a3 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 30 Mar 2020 17:53:25 +0100 Subject: Transfer alias mappings when joining an upgraded room (#6946) --- changelog.d/6946.bugfix | 1 + synapse/handlers/room_member.py | 3 +++ synapse/storage/data_stores/main/directory.py | 26 +++++++++++++++++++++++--- 3 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 changelog.d/6946.bugfix (limited to 'synapse') diff --git a/changelog.d/6946.bugfix b/changelog.d/6946.bugfix new file mode 100644 index 0000000000..a238c83a18 --- /dev/null +++ b/changelog.d/6946.bugfix @@ -0,0 +1 @@ +Transfer alias mappings on room upgrade. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4260426369..c3ee8db4f0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -519,6 +519,9 @@ class RoomMemberHandler(object): yield self.store.set_room_is_public(old_room_id, False) yield self.store.set_room_is_public(room_id, True) + # Transfer alias mappings in the room directory + yield self.store.update_aliases_for_room(old_room_id, room_id) + # Check if any groups we own contain the predecessor room local_group_ids = yield self.store.get_local_groups_for_room(old_room_id) for group_id in local_group_ids: diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index c9e7de7d12..e1d1bc3e05 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -14,6 +14,7 @@ # limitations under the License. from collections import namedtuple +from typing import Optional from twisted.internet import defer @@ -159,10 +160,29 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room(self, old_room_id, new_room_id, creator): + def update_aliases_for_room( + self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, + ): + """Repoint all of the aliases for a given room, to a different room. + + Args: + old_room_id: + new_room_id: + creator: The user to record as the creator of the new mapping. + If None, the creator will be left unchanged. + """ + def _update_aliases_for_room_txn(txn): - sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" - txn.execute(sql, (new_room_id, creator, old_room_id)) + update_creator_sql = "" + sql_params = (new_room_id, old_room_id) + if creator: + update_creator_sql = ", creator = ?" + sql_params = (new_room_id, creator, old_room_id) + + sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" 
% ( + update_creator_sql, + ) + txn.execute(sql, sql_params) self._invalidate_cache_and_stream( txn, self.get_aliases_for_room, (old_room_id,) ) -- cgit 1.4.1 From 7966a1cde9d4b598faa06620424844f2b35c94af Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 30 Mar 2020 19:06:52 +0100 Subject: Rewrite prune_old_outbound_device_pokes for efficiency (#7159) make sure we clear out all but one update for the user --- changelog.d/7159.bugfix | 1 + synapse/handlers/federation.py | 25 +------- synapse/storage/data_stores/main/devices.py | 71 ++++++++++++++++++---- synapse/util/stringutils.py | 21 ++++++- tests/federation/test_federation_sender.py | 92 +++++++++++++++++++++++++++++ 5 files changed, 173 insertions(+), 37 deletions(-) create mode 100644 changelog.d/7159.bugfix (limited to 'synapse') diff --git a/changelog.d/7159.bugfix b/changelog.d/7159.bugfix new file mode 100644 index 0000000000..1b341b127b --- /dev/null +++ b/changelog.d/7159.bugfix @@ -0,0 +1 @@ +Fix excessive CPU usage by `prune_old_outbound_device_pokes` job. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 38ab6a8fc3..c7aa7acf3b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,6 +49,7 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.handlers._base import BaseHandler from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -69,10 +70,9 @@ from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server -from ._base import BaseHandler - logger = logging.getLogger(__name__) @@ -93,27 +93,6 @@ class _NewEventInfo: auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) -def shortstr(iterable, maxitems=5): - """If iterable has maxitems or fewer, return the stringification of a list - containing those items. - - Otherwise, return the stringification of a a list with the first maxitems items, - followed by "...". - - Args: - iterable (Iterable): iterable to truncate - maxitems (int): number of items to return before truncating - - Returns: - unicode - """ - - items = list(itertools.islice(iterable, maxitems + 1)) - if len(items) <= maxitems: - return str(items) - return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" - - class FederationHandler(BaseHandler): """Handles events that originated from federation. 
Responsible for: diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 2d47cfd131..3140e1b722 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -41,6 +41,7 @@ from synapse.util.caches.descriptors import ( cachedList, ) from synapse.util.iterutils import batch_iter +from synapse.util.stringutils import shortstr logger = logging.getLogger(__name__) @@ -1092,18 +1093,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) - def _prune_old_outbound_device_pokes(self): + def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. We keep one entry per - (destination, user_id) tuple to ensure that the prev_ids remain correct - if the server does come back. + that we don't fill up due to dead servers. + + Normally, we try to send device updates as a delta since a previous known point: + this is done by setting the prev_id in the m.device_list_update EDU. However, + for that to work, we have to have a complete record of each change to + each device, which can add up to quite a lot of data. + + An alternative mechanism is that, if the remote server sees that it has missed + an entry in the stream_id sequence for a given user, it will request a full + list of that user's devices. Hence, we can reduce the amount of data we have to + store (and transmit in some future transaction), by clearing almost everything + for a given destination out of the database, and having the remote server + resync. + + All we need to do is make sure we keep at least one row for each + (user, destination) pair, to remind us to send a m.device_list_update EDU for + that user when the destination comes back. It doesn't matter which device + we keep. """ - yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + yesterday = self._clock.time_msec() - prune_age def _prune_txn(txn): + # look for (user, destination) pairs which have an update older than + # the cutoff. + # + # For each pair, we also need to know the most recent stream_id, and + # an arbitrary device_id at that stream_id. select_sql = """ - SELECT destination, user_id, max(stream_id) as stream_id - FROM device_lists_outbound_pokes + SELECT + dlop1.destination, + dlop1.user_id, + MAX(dlop1.stream_id) AS stream_id, + (SELECT MIN(dlop2.device_id) AS device_id FROM + device_lists_outbound_pokes dlop2 + WHERE dlop2.destination = dlop1.destination AND + dlop2.user_id=dlop1.user_id AND + dlop2.stream_id=MAX(dlop1.stream_id) + ) + FROM device_lists_outbound_pokes dlop1 GROUP BY destination, user_id HAVING min(ts) < ? AND count(*) > 1 """ @@ -1114,14 +1144,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not rows: return + logger.info( + "Pruning old outbound device list updates for %i users/destinations: %s", + len(rows), + shortstr((row[0], row[1]) for row in rows), + ) + + # we want to keep the update with the highest stream_id for each user. + # + # there might be more than one update (with different device_ids) with the + # same stream_id, so we also delete all but one rows with the max stream id. delete_sql = """ DELETE FROM device_lists_outbound_pokes - WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + WHERE destination = ? AND user_id = ? AND ( + stream_id < ? OR + (stream_id = ? AND device_id != ?) 
+ ) """ - - txn.executemany( - delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows) - ) + count = 0 + for (destination, user_id, stream_id, device_id) in rows: + txn.execute( + delete_sql, (destination, user_id, stream_id, stream_id, device_id) + ) + count += txn.rowcount # Since we've deleted unsent deltas, we need to remove the entry # of last successful sent so that the prev_ids are correctly set. @@ -1131,7 +1176,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): """ txn.executemany(sql, ((row[0], row[1]) for row in rows)) - logger.info("Pruned %d device list outbound pokes", txn.rowcount) + logger.info("Pruned %d device list outbound pokes", count) return run_as_background_process( "prune_old_outbound_device_pokes", diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 2c0dcb5208..6899bcb788 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -13,10 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import itertools import random import re import string +from collections import Iterable import six from six import PY2, PY3 @@ -126,3 +127,21 @@ def assert_valid_client_secret(client_secret): raise SynapseError( 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM ) + + +def shortstr(iterable: Iterable, maxitems: int = 5) -> str: + """If iterable has maxitems or fewer, return the stringification of a list + containing those items. + + Otherwise, return the stringification of a a list with the first maxitems items, + followed by "...". + + Args: + iterable: iterable to truncate + maxitems: number of items to return before truncating + """ + + items = list(itertools.islice(iterable, maxitems + 1)) + if len(items) <= maxitems: + return str(items) + return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 7763b12159..a5fe5c6880 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -370,6 +370,98 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2", "D3"}, devices) + def test_prune_outbound_device_pokes1(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server has never been reachable. + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # there should be a single update for this user. 
+ self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + + def test_prune_outbound_device_pokes2(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server was reachable, but then goes + offline. + """ + + # create first device + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + + # expect the update EDU + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # now the server goes offline + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 3) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # ... and we should get a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + def check_device_update_edu( self, edu: JsonDict, -- cgit 1.4.1 From 62a7289133840b4f4a55844b4f24ec664c3d917b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 31 Mar 2020 13:09:16 +0100 Subject: Fix a bug which could cause incorrect 'cyclic dependency' error. (#7178) If there was an exception setting up one of the attributes of the Homeserver god object, then future attempts to fetch that attribute would raise a confusing "Cyclic dependency" error. Let's make sure that we clear the `building` flag so that we just get the original exception. Ref: #7169 --- changelog.d/7178.bugfix | 1 + synapse/server.py | 22 ++++++++++------------ 2 files changed, 11 insertions(+), 12 deletions(-) create mode 100644 changelog.d/7178.bugfix (limited to 'synapse') diff --git a/changelog.d/7178.bugfix b/changelog.d/7178.bugfix new file mode 100644 index 0000000000..35ea645d75 --- /dev/null +++ b/changelog.d/7178.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause incorrect 'cyclic dependency' error. 
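A minimal standalone sketch of the pattern this change fixes — a hypothetical dependency container, not the Synapse code itself: the "building" marker has to be cleared in a finally block, otherwise one failed build leaves the marker behind and every later lookup raises the confusing "Cyclic dependency" error instead of the original exception.

    class DependencyContainer:
        """Toy lazy-dependency container, guarding against build cycles."""

        def __init__(self):
            self._building = {}
            self._builders = {"store": self._build_store}

        def _build_store(self):
            raise RuntimeError("original failure")

        def get(self, name):
            if name in self.__dict__:
                return self.__dict__[name]
            if name in self._building:
                raise ValueError("Cyclic dependency while building %s" % (name,))
            self._building[name] = 1
            try:
                dep = self._builders[name]()
                self.__dict__[name] = dep
            finally:
                # Always clear the marker, even if the builder blew up, so the
                # next attempt re-raises the real error rather than a bogus cycle.
                del self._building[name]
            return dep

With the marker cleared in finally, calling get("store") twice raises the underlying RuntimeError both times; without it, the second call would report a spurious cycle.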
diff --git a/synapse/server.py b/synapse/server.py index cd86475d6b..9228e1c892 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -583,24 +583,22 @@ def _make_dependency_method(depname): try: builder = getattr(hs, "build_%s" % (depname)) except AttributeError: - builder = None + raise NotImplementedError( + "%s has no %s nor a builder for it" % (type(hs).__name__, depname) + ) - if builder: - # Prevent cyclic dependencies from deadlocking - if depname in hs._building: - raise ValueError("Cyclic dependency while building %s" % (depname,)) - hs._building[depname] = 1 + # Prevent cyclic dependencies from deadlocking + if depname in hs._building: + raise ValueError("Cyclic dependency while building %s" % (depname,)) + hs._building[depname] = 1 + try: dep = builder() setattr(hs, depname, dep) - + finally: del hs._building[depname] - return dep - - raise NotImplementedError( - "%s has no %s nor a builder for it" % (type(hs).__name__, depname) - ) + return dep setattr(HomeServer, "get_%s" % (depname), _get) -- cgit 1.4.1 From 0a7b0882c1d1f52bde46d6f367f265bc330e8bd0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 31 Mar 2020 09:33:02 -0400 Subject: Fix use of async/await in media code (#7184) --- changelog.d/7184.misc | 1 + synapse/storage/data_stores/main/media_repository.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7184.misc (limited to 'synapse') diff --git a/changelog.d/7184.misc b/changelog.d/7184.misc new file mode 100644 index 0000000000..fac5bc0403 --- /dev/null +++ b/changelog.d/7184.misc @@ -0,0 +1 @@ +Convert some of synapse.rest.media to async/await. diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 80ca36dedf..cf195f8aa6 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -340,7 +340,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_expired_url_cache", _get_expired_url_cache_txn ) - def delete_url_cache(self, media_ids): + async def delete_url_cache(self, media_ids): if len(media_ids) == 0: return @@ -349,7 +349,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) + return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) def get_url_cache_media_before(self, before_ts): sql = ( -- cgit 1.4.1 From b994e86e359fd095f82feabbf38fb18a5d10e0ae Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 31 Mar 2020 14:51:22 +0100 Subject: Only setdefault for signatures if device has key_json (#7177) --- changelog.d/7177.bugfix | 1 + synapse/storage/data_stores/main/devices.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) create mode 100644 changelog.d/7177.bugfix (limited to 'synapse') diff --git a/changelog.d/7177.bugfix b/changelog.d/7177.bugfix new file mode 100644 index 0000000000..329a96cb0b --- /dev/null +++ b/changelog.d/7177.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause outbound federation traffic to stop working if a client uploaded an incorrect e2e device signature. 
\ No newline at end of file diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 3140e1b722..20995e1b78 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -286,14 +286,16 @@ class DeviceWorkerStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) else: result["deleted"] = True @@ -494,14 +496,16 @@ class DeviceWorkerStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) results.append(result) -- cgit 1.4.1 From fe1580bfd91151c2c375d3c403ed911828f3899e Mon Sep 17 00:00:00 2001 From: Karlinde Date: Tue, 31 Mar 2020 16:08:56 +0200 Subject: Fill in the 'default' field for user-defined push rules (#6639) Signed-off-by: Karl Linderhed --- changelog.d/6639.bugfix | 1 + synapse/storage/data_stores/main/push_rule.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/6639.bugfix (limited to 'synapse') diff --git a/changelog.d/6639.bugfix b/changelog.d/6639.bugfix new file mode 100644 index 0000000000..c7593a6e84 --- /dev/null +++ b/changelog.d/6639.bugfix @@ -0,0 +1 @@ +Fix missing field `default` when fetching user-defined push rules. diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 62ac88d9f2..46f9bda773 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -41,6 +41,7 @@ def _load_rules(rawrules, enabled_map): rule = dict(rawrule) rule["conditions"] = json.loads(rawrule["conditions"]) rule["actions"] = json.loads(rawrule["actions"]) + rule["default"] = False ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy -- cgit 1.4.1 From 60adcbed919afd5c85442775eca822fec43d816d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 31 Mar 2020 15:18:41 +0100 Subject: Fix "'NoneType' has no attribute start|stop" logcontext errors (#7181) Fixes #7179. --- changelog.d/7181.misc | 1 + synapse/http/site.py | 13 ++++++------- synapse/logging/context.py | 5 +++++ 3 files changed, 12 insertions(+), 7 deletions(-) create mode 100644 changelog.d/7181.misc (limited to 'synapse') diff --git a/changelog.d/7181.misc b/changelog.d/7181.misc new file mode 100644 index 0000000000..731f4dcb52 --- /dev/null +++ b/changelog.d/7181.misc @@ -0,0 +1 @@ +Clean up some LoggingContext code. 
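A minimal sketch of the failure mode being guarded against here (hypothetical class and names, not the real SynapseRequest): the connection can drop before render() has ever run, so the request's logcontext attribute is still None when the connection-lost callback fires and has to be checked before use.

    import logging

    logger = logging.getLogger(__name__)

    class TrackedRequest:
        """Toy request whose logcontext only exists once rendering has started."""

        def __init__(self, client):
            self.client = client
            self.logcontext = None  # set by render(), which may never be called

        def render(self, logcontext):
            self.logcontext = logcontext

        def connection_lost(self, reason):
            if self.logcontext is None:
                # Client went away before the headers were read: there is no
                # context to finish, so log it and bail out early.
                logger.info(
                    "Connection from %s lost before request headers were read",
                    self.client,
                )
                return
            # ... otherwise record usage against self.logcontext as usual ...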
diff --git a/synapse/http/site.py b/synapse/http/site.py index e092193c9c..32feb0d968 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -193,6 +193,12 @@ class SynapseRequest(Request): self.finish_time = time.time() Request.connectionLost(self, reason) + if self.logcontext is None: + logger.info( + "Connection from %s lost before request headers were read", self.client + ) + return + # we only get here if the connection to the client drops before we send # the response. # @@ -236,13 +242,6 @@ class SynapseRequest(Request): def _finished_processing(self): """Log the completion of this request and update the metrics """ - - if self.logcontext is None: - # this can happen if the connection closed before we read the - # headers (so render was never called). In that case we'll already - # have logged a warning, so just bail out. - return - usage = self.logcontext.get_resource_usage() if self._processing_finished_time is None: diff --git a/synapse/logging/context.py b/synapse/logging/context.py index a8eafb1c7c..3254d6a8df 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -539,6 +539,11 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe Returns: The context that was previously active """ + # everything blows up if we allow current_context to be set to None, so sanity-check + # that now. + if context is None: + raise TypeError("'context' argument may not be None") + current = current_context() if current is not context: -- cgit 1.4.1 From cfe8c8ab8e412b6320e5963ced0670fbc7b00d1b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Mar 2020 17:24:06 +0100 Subject: Remove unused `start_background_update` This was only used in a unit test, so let's just inline it in the test. --- synapse/storage/background_updates.py | 21 --------------------- tests/storage/test_background_update.py | 14 +++++++++----- 2 files changed, 9 insertions(+), 26 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index eb1a7e5002..d4e26eab6c 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -400,27 +400,6 @@ class BackgroundUpdater(object): self.register_background_update_handler(update_name, updater) - def start_background_update(self, update_name, progress): - """Starts a background update running. - - Args: - update_name: The update to set running. - progress: The initial state of the progress of the update. - - Returns: - A deferred that completes once the task has been added to the - queue. - """ - # Clear the background update queue so that we will pick up the new - # task on the next iteration of do_background_update. - self._background_update_queue = [] - progress_json = json.dumps(progress) - - return self.db.simple_insert( - "background_updates", - {"update_name": update_name, "progress_json": progress_json}, - ) - def _end_background_update(self, update_name): """Removes a completed background update task from the queue. 
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index ae14fb407d..aca41eb215 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -25,12 +25,20 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # the target runtime for each bg update target_background_update_duration_ms = 50000 + store = self.hs.get_datastore() + self.get_success( + store.db.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.hs.get_datastore().db.runInteraction( + yield store.db.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", @@ -39,10 +47,6 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): return count self.update_handler.side_effect = update - - self.get_success( - self.updates.start_background_update("test_update", {"my_key": 1}) - ) self.update_handler.reset_mock() res = self.get_success( self.updates.do_next_background_update( -- cgit 1.4.1 From 26d17b9bdc0de51d5f1a7526e8ab70e7f7796e4d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Mar 2020 17:25:10 +0100 Subject: Make `has_completed_background_updates` async (Almost) everywhere that uses it is happy with an awaitable. --- synapse/storage/background_updates.py | 7 +++---- tests/storage/test_background_update.py | 4 +++- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index d4e26eab6c..79494bcdf5 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -119,12 +119,11 @@ class BackgroundUpdater(object): self._all_done = True return None - @defer.inlineCallbacks - def has_completed_background_updates(self): + async def has_completed_background_updates(self) -> bool: """Check if all the background updates have completed Returns: - Deferred[bool]: True if all background updates have completed + True if all background updates have completed """ # if we've previously determined that there is nothing left to do, that # is easy @@ -138,7 +137,7 @@ class BackgroundUpdater(object): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. 
- updates = yield self.db.simple_select_onecol( + updates = await self.db.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index aca41eb215..e578de8acf 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -11,7 +11,9 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater # the base test class should have run the real bg updates for us - self.assertTrue(self.updates.has_completed_background_updates()) + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) self.update_handler = Mock() self.updates.register_background_update_handler( -- cgit 1.4.1 From b4c22342320d7de86c02dfb36415a38c62bec88d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Mar 2020 17:31:32 +0100 Subject: Make do_next_background_update return a bool returning a None or an int that we don't use is confusing. --- synapse/storage/background_updates.py | 12 +++++------- tests/storage/test_background_update.py | 6 +++--- 2 files changed, 8 insertions(+), 10 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 79494bcdf5..4a59132bf3 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -111,7 +111,7 @@ class BackgroundUpdater(object): except Exception: logger.exception("Error doing update") else: - if result is None: + if result: logger.info( "No more background updates to do." " Unscheduling background update task." @@ -169,9 +169,7 @@ class BackgroundUpdater(object): return not update_exists - async def do_next_background_update( - self, desired_duration_ms: float - ) -> Optional[int]: + async def do_next_background_update(self, desired_duration_ms: float) -> bool: """Does some amount of work on the next queued background update Returns once some amount of work is done. @@ -180,7 +178,7 @@ class BackgroundUpdater(object): desired_duration_ms(float): How long we want to spend updating. 
Returns: - None if there is no more work to do, otherwise an int + True if there is no more work to do, otherwise False """ if not self._background_update_queue: updates = await self.db.simple_select_list( @@ -195,14 +193,14 @@ class BackgroundUpdater(object): if not self._background_update_queue: # no work left to do - return None + return True # pop from the front, and add back to the back update_name = self._background_update_queue.pop(0) self._background_update_queue.append(update_name) res = await self._do_background_update(update_name, desired_duration_ms) - return res + return False async def _do_background_update( self, update_name: str, desired_duration_ms: float diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index e578de8acf..940b166129 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -56,7 +56,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ), by=0.1, ) - self.assertIsNotNone(res) + self.assertFalse(res) # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( @@ -79,7 +79,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNotNone(result) + self.assertFalse(result) self.update_handler.assert_called_once() # third step: we don't expect to be called any more @@ -87,5 +87,5 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): result = self.get_success( self.updates.do_next_background_update(target_background_update_duration_ms) ) - self.assertIsNone(result) + self.assertTrue(result) self.assertFalse(self.update_handler.called) -- cgit 1.4.1 From 7b608cf4683c0df2dbb55aacd472c407a0f6b1fa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Mar 2020 17:43:19 +0100 Subject: Only run one background update at a time --- synapse/storage/background_updates.py | 74 ++++++++++++++-------- synapse/storage/prepare_database.py | 2 +- .../delta/58/00background_update_ordering.sql | 19 ++++++ 3 files changed, 68 insertions(+), 27 deletions(-) create mode 100644 synapse/storage/schema/delta/58/00background_update_ordering.sql (limited to 'synapse') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 4a59132bf3..0e430356cd 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -90,8 +90,10 @@ class BackgroundUpdater(object): self._clock = hs.get_clock() self.db = database + # if a background update is currently running, its name. + self._current_background_update = None # type: Optional[str] + self._background_update_performance = {} - self._background_update_queue = [] self._background_update_handlers = {} self._all_done = False @@ -131,7 +133,7 @@ class BackgroundUpdater(object): return True # obviously, if we have things in our queue, we're not done. - if self._background_update_queue: + if self._current_background_update: return False # otherwise, check if there are updates to be run. This is important, @@ -152,11 +154,10 @@ class BackgroundUpdater(object): async def has_completed_background_update(self, update_name) -> bool: """Check if the given background update has finished running. 
""" - if self._all_done: return True - if update_name in self._background_update_queue: + if update_name == self._current_background_update: return False update_exists = await self.db.simple_select_one_onecol( @@ -180,31 +181,49 @@ class BackgroundUpdater(object): Returns: True if there is no more work to do, otherwise False """ - if not self._background_update_queue: - updates = await self.db.simple_select_list( - "background_updates", - keyvalues=None, - retcols=("update_name", "depends_on"), + + def get_background_updates_txn(txn): + txn.execute( + """ + SELECT update_name, depends_on FROM background_updates + ORDER BY ordering, update_name + """ ) - in_flight = {update["update_name"] for update in updates} - for update in updates: - if update["depends_on"] not in in_flight: - self._background_update_queue.append(update["update_name"]) + return self.db.cursor_to_dict(txn) - if not self._background_update_queue: - # no work left to do - return True + if not self._current_background_update: + all_pending_updates = await self.db.runInteraction( + "background_updates", get_background_updates_txn, + ) + if not all_pending_updates: + # no work left to do + return True + + # find the first update which isn't dependent on another one in the queue. + pending = {update["update_name"] for update in all_pending_updates} + for upd in all_pending_updates: + depends_on = upd["depends_on"] + if not depends_on or depends_on not in pending: + break + logger.info( + "Not starting on bg update %s until %s is done", + upd["update_name"], + depends_on, + ) + else: + # if we get to the end of that for loop, there is a problem + raise Exception( + "Unable to find a background update which doesn't depend on " + "another: dependency cycle?" + ) - # pop from the front, and add back to the back - update_name = self._background_update_queue.pop(0) - self._background_update_queue.append(update_name) + self._current_background_update = upd["update_name"] - res = await self._do_background_update(update_name, desired_duration_ms) + await self._do_background_update(desired_duration_ms) return False - async def _do_background_update( - self, update_name: str, desired_duration_ms: float - ) -> int: + async def _do_background_update(self, desired_duration_ms: float) -> int: + update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) update_handler = self._background_update_handlers[update_name] @@ -405,9 +424,12 @@ class BackgroundUpdater(object): Returns: A deferred that completes once the task is removed. """ - self._background_update_queue = [ - name for name in self._background_update_queue if name != update_name - ] + if update_name != self._current_background_update: + raise Exception( + "Cannot end background update %s which isn't currently running" + % update_name + ) + self._current_background_update = None return self.db.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 6cb7d4b922..1712932f31 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. 
-SCHEMA_VERSION = 57 +SCHEMA_VERSION = 58 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/58/00background_update_ordering.sql b/synapse/storage/schema/delta/58/00background_update_ordering.sql new file mode 100644 index 0000000000..02dae587cc --- /dev/null +++ b/synapse/storage/schema/delta/58/00background_update_ordering.sql @@ -0,0 +1,19 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* add an "ordering" column to background_updates, which can be used to sort them + to achieve some level of consistency. */ + +ALTER TABLE background_updates ADD COLUMN ordering INT NOT NULL DEFAULT 0; -- cgit 1.4.1 From dfa07822542da96b93ef9d871d43bf1a36dc4664 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Apr 2020 10:40:46 +0100 Subject: Remove connections per replication stream metric. (#7195) This broke in a recent PR (#7024) and is no longer useful due to all replication clients implicitly subscribing to all streams, so let's just remove it. --- changelog.d/7195.misc | 1 + synapse/replication/tcp/resource.py | 16 ---------------- 2 files changed, 1 insertion(+), 16 deletions(-) create mode 100644 changelog.d/7195.misc (limited to 'synapse') diff --git a/changelog.d/7195.misc b/changelog.d/7195.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7195.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 8b6067e20d..30021ee309 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -99,22 +99,6 @@ class ReplicationStreamer(object): self.streams_by_name = {stream.NAME: stream for stream in self.streams} - LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", - "", - ["stream_name"], - lambda: { - (stream_name,): len( - [ - conn - for conn in self.connections - if stream_name in conn.replication_streams - ] - ) - for stream_name in self.streams_by_name - }, - ) - self.federation_sender = None if not hs.config.send_federation: self.federation_sender = hs.get_federation_sender() -- cgit 1.4.1 From 468dcc767bf379ba2b4ed4b6d1c6537473175eab Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Apr 2020 08:27:05 -0400 Subject: Allow admins to create aliases when they are not in the room (#7191) --- changelog.d/7191.feature | 1 + synapse/handlers/directory.py | 6 +++- tests/handlers/test_directory.py | 62 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7191.feature (limited to 'synapse') diff --git a/changelog.d/7191.feature b/changelog.d/7191.feature new file mode 100644 index 0000000000..83d5685bb2 --- /dev/null +++ b/changelog.d/7191.feature @@ -0,0 +1 @@ +Admin users are no longer required to be in a room to create an alias for it. 
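The new permission rule, sketched as a free-standing function (a hypothetical helper, not the handler code): server admins skip the room-membership requirement that ordinary users are still subject to.

    def may_create_alias(is_admin, require_membership, check_membership, joined_rooms, room_id):
        """Return True if the requester may create an alias pointing at room_id."""
        if is_admin:
            # Admins may create an alias for a room they have never joined.
            return True
        if require_membership and check_membership:
            return room_id in joined_rooms
        return True

For example, may_create_alias(False, True, True, set(), "!abc:example.com") is False, while the same call with is_admin=True succeeds.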
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 1d842c369b..53e5f585d9 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -127,7 +127,11 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) else: - if self.require_membership and check_membership: + # Server admins are not subject to the same constraints as normal + # users when creating an alias (e.g. being in the room). + is_admin = yield self.auth.is_server_admin(requester.user) + + if (self.require_membership and check_membership) and not is_admin: rooms_for_user = yield self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5e40adba52..00bb776271 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -102,6 +102,68 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) +class TestCreateAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_handlers().directory_handler + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def test_create_alias_joined_room(self): + """A user can create an alias for a room they're in.""" + self.get_success( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, self.room_id, + ) + ) + + def test_create_alias_other_room(self): + """A user cannot create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.get_failure( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, other_room_id, + ), + synapse.api.errors.SynapseError, + ) + + def test_create_alias_admin(self): + """An admin can create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.test_user, tok=self.test_user_tok + ) + + self.get_success( + self.handler.create_association( + create_requester(self.admin_user), self.room_alias, other_room_id, + ) + ) + + class TestDeleteAlias(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, -- cgit 1.4.1 From b9930d24a05e47c36845d8607b12a45eea889be0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Apr 2020 08:48:00 -0400 Subject: Support SAML in the user interactive authentication workflow. 
(#7102) --- CHANGES.md | 8 ++ changelog.d/7102.feature | 1 + synapse/api/constants.py | 1 + synapse/handlers/auth.py | 116 +++++++++++++++++++++++++++- synapse/handlers/saml_handler.py | 51 +++++++++--- synapse/res/templates/sso_auth_confirm.html | 14 ++++ synapse/rest/client/v2_alpha/account.py | 19 ++++- synapse/rest/client/v2_alpha/auth.py | 42 +++++----- synapse/rest/client/v2_alpha/devices.py | 12 ++- synapse/rest/client/v2_alpha/keys.py | 6 +- synapse/rest/client/v2_alpha/register.py | 1 + 11 files changed, 227 insertions(+), 44 deletions(-) create mode 100644 changelog.d/7102.feature create mode 100644 synapse/res/templates/sso_auth_confirm.html (limited to 'synapse') diff --git a/CHANGES.md b/CHANGES.md index f794c585b7..b997af1630 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,11 @@ +Next version +============ + +* A new template (`sso_auth_confirm.html`) was added to Synapse. If your Synapse + is configured to use SSO and a custom `sso_redirect_confirm_template_dir` + configuration then this template will need to be duplicated into that + directory. + Synapse 1.12.0 (2020-03-23) =========================== diff --git a/changelog.d/7102.feature b/changelog.d/7102.feature new file mode 100644 index 0000000000..01057aa396 --- /dev/null +++ b/changelog.d/7102.feature @@ -0,0 +1 @@ +Support SSO in the user interactive authentication workflow. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index cc8577552b..fda2c2e5bb 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -61,6 +61,7 @@ class LoginType(object): MSISDN = "m.login.msisdn" RECAPTCHA = "m.login.recaptcha" TERMS = "m.login.terms" + SSO = "org.matrix.login.sso" DUMMY = "m.login.dummy" # Only for C/S API v1 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2ce1425dfa..7c09d15a72 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -53,6 +53,31 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +SUCCESS_TEMPLATE = """ + + +Success! + + + + + +
+<body>
+    <div>
+        <p>Thank you</p>
+        <p>You may now close this window and return to the application</p>
+    </div>
+ + +""" + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -91,6 +116,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled + self._saml2_enabled = hs.config.saml2_enabled # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first @@ -106,6 +132,13 @@ class AuthHandler(BaseHandler): if t not in login_types: login_types.append(t) self._supported_login_types = login_types + # Login types and UI Auth types have a heavy overlap, but are not + # necessarily identical. Login types have SSO (and other login types) + # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. + ui_auth_types = login_types.copy() + if self._saml2_enabled: + ui_auth_types.append(LoginType.SSO) + self._supported_ui_auth_types = ui_auth_types # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. @@ -113,10 +146,21 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() - # Load the SSO redirect confirmation page HTML template + # Load the SSO HTML templates. + + # The following template is shown to the user during a client login via SSO, + # after the SSO completes and before redirecting them back to their client. + # It notifies the user they are about to give access to their matrix account + # to the client. self._sso_redirect_confirm_template = load_jinja2_templates( hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], )[0] + # The following template is shown during user interactive authentication + # in the fallback auth scenario. It notifies the user that they are + # authenticating for an operation to occur on their account. + self._sso_auth_confirm_template = load_jinja2_templates( + hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"], + )[0] self._server_name = hs.config.server_name @@ -130,6 +174,7 @@ class AuthHandler(BaseHandler): request: SynapseRequest, request_body: Dict[str, Any], clientip: str, + description: str, ): """ Checks that the user is who they claim to be, via a UI auth. @@ -147,6 +192,9 @@ class AuthHandler(BaseHandler): clientip: The IP address of the client. + description: A human readable string to be displayed to the user that + describes the operation happening on their account. + Returns: defer.Deferred[dict]: the parameters for this request (which may have been given only in a previous call). @@ -175,11 +223,11 @@ class AuthHandler(BaseHandler): ) # build a list of supported flows - flows = [[login_type] for login_type in self._supported_login_types] + flows = [[login_type] for login_type in self._supported_ui_auth_types] try: result, params, _ = yield self.check_auth( - flows, request, request_body, clientip + flows, request, request_body, clientip, description ) except LoginError: # Update the ratelimite to say we failed (`can_do_action` doesn't raise). 
@@ -193,7 +241,7 @@ class AuthHandler(BaseHandler): raise # find the completed login type - for login_type in self._supported_login_types: + for login_type in self._supported_ui_auth_types: if login_type not in result: continue @@ -224,6 +272,7 @@ class AuthHandler(BaseHandler): request: SynapseRequest, clientdict: Dict[str, Any], clientip: str, + description: str, ): """ Takes a dictionary sent by the client in the login / registration @@ -250,6 +299,9 @@ class AuthHandler(BaseHandler): clientip: The IP address of the client. + description: A human readable string to be displayed to the user that + describes the operation happening on their account. + Returns: defer.Deferred[dict, dict, str]: a deferred tuple of (creds, params, session_id). @@ -299,12 +351,18 @@ class AuthHandler(BaseHandler): comparator = (request.uri, request.method, clientdict) if "ui_auth" not in session: session["ui_auth"] = comparator + self._save_session(session) elif session["ui_auth"] != comparator: raise SynapseError( 403, "Requested operation has changed during the UI authentication session.", ) + # Add a human readable description to the session. + if "description" not in session: + session["description"] = description + self._save_session(session) + if not authdict: raise InteractiveAuthIncompleteError( self._auth_dict_for_flows(flows, session) @@ -991,6 +1049,56 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) + def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + """ + Get the HTML for the SSO redirect confirmation page. + + Args: + redirect_url: The URL to redirect to the SSO provider. + session_id: The user interactive authentication session ID. + + Returns: + The HTML to render. + """ + session = self._get_session_info(session_id) + # Get the human readable operation of what is occurring, falling back to + # a generic message if it isn't available for some reason. + description = session.get("description", "modify your account") + return self._sso_auth_confirm_template.render( + description=description, redirect_url=redirect_url, + ) + + def complete_sso_ui_auth( + self, registered_user_id: str, session_id: str, request: SynapseRequest, + ): + """Having figured out a mxid for this user, complete the HTTP request + + Args: + registered_user_id: The registered user ID to complete SSO login for. + request: The request to complete. + client_redirect_url: The URL to which to redirect the user at the end of the + process. + """ + # Mark the stage of the authentication as successful. + sess = self._get_session_info(session_id) + if "creds" not in sess: + sess["creds"] = {} + creds = sess["creds"] + + # Save the user who authenticated with SSO, this will be used to ensure + # that the account be modified is also the person who logged in. + creds[LoginType.SSO] = registered_user_id + self._save_session(sess) + + # Render the HTML and return. + html_bytes = SUCCESS_TEMPLATE.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + def complete_sso_login( self, registered_user_id: str, diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index dc04b53f43..4741c82f61 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -14,7 +14,7 @@ # limitations under the License. 
import logging import re -from typing import Tuple +from typing import Optional, Tuple import attr import saml2 @@ -44,11 +44,15 @@ class Saml2SessionData: # time the session was created, in milliseconds creation_time = attr.ib() + # The user interactive authentication session ID associated with this SAML + # session (or None if this SAML session is for an initial login). + ui_auth_session_id = attr.ib(type=Optional[str], default=None) class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) + self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -77,12 +81,14 @@ class SamlHandler: self._error_html_content = hs.config.saml2_error_html_content - def handle_redirect_request(self, client_redirect_url): + def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None): """Handle an incoming request to /login/sso/redirect Args: client_redirect_url (bytes): the URL that we should redirect the client to when everything is done + ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or + None if this is a login). Returns: bytes: URL to redirect to @@ -92,7 +98,9 @@ class SamlHandler: ) now = self._clock.time_msec() - self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now) + self._outstanding_requests_dict[reqid] = Saml2SessionData( + creation_time=now, ui_auth_session_id=ui_auth_session_id, + ) for key, value in info["headers"]: if key == "Location": @@ -119,7 +127,9 @@ class SamlHandler: self.expire_sessions() try: - user_id = await self._map_saml_response_to_user(resp_bytes, relay_state) + user_id, current_session = await self._map_saml_response_to_user( + resp_bytes, relay_state + ) except RedirectException: # Raise the exception as per the wishes of the SAML module response raise @@ -137,9 +147,28 @@ class SamlHandler: finish_request(request) return - self._auth_handler.complete_sso_login(user_id, request, relay_state) + # Complete the interactive auth session or the login. + if current_session and current_session.ui_auth_session_id: + self._auth_handler.complete_sso_ui_auth( + user_id, current_session.ui_auth_session_id, request + ) + + else: + self._auth_handler.complete_sso_login(user_id, request, relay_state) + + async def _map_saml_response_to_user( + self, resp_bytes: str, client_redirect_url: str + ) -> Tuple[str, Optional[Saml2SessionData]]: + """ + Given a sample response, retrieve the cached session and user for it. - async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): + Args: + resp_bytes: The SAML response. + client_redirect_url: The redirect URL passed in by the client. + + Returns: + Tuple of the user ID and SAML session associated with this response. 
+ """ try: saml2_auth = self._saml_client.parse_authn_request_response( resp_bytes, @@ -167,7 +196,9 @@ class SamlHandler: logger.info("SAML2 mapped attributes: %s", saml2_auth.ava) - self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) + current_session = self._outstanding_requests_dict.pop( + saml2_auth.in_response_to, None + ) remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url @@ -188,7 +219,7 @@ class SamlHandler: ) if registered_user_id is not None: logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id + return registered_user_id, current_session # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid @@ -213,7 +244,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id + return registered_user_id, current_session # Map saml response to user attributes using the configured mapping provider for i in range(1000): @@ -260,7 +291,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id + return registered_user_id, current_session def expire_sessions(self): expire_before = self._clock.time_msec() - self._saml2_session_lifetime diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html new file mode 100644 index 0000000000..0d9de9d465 --- /dev/null +++ b/synapse/res/templates/sso_auth_confirm.html @@ -0,0 +1,14 @@ + + + Authentication + + +
+        <div>
+            <p>
+                A client is trying to {{ description | e }}. To confirm this action,
+                re-authenticate with single sign-on.
+                If you did not expect this, your account may be compromised!
+            </p>
+        </div>
+ + diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index f80b5e40ea..31435b1e1c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -234,7 +234,11 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) params = await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", ) user_id = requester.user.to_string() else: @@ -244,6 +248,7 @@ class PasswordRestServlet(RestServlet): request, body, self.hs.get_ip_from_request(request), + "modify your account password", ) if LoginType.EMAIL_IDENTITY in result: @@ -311,7 +316,11 @@ class DeactivateAccountRestServlet(RestServlet): return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") @@ -669,7 +678,11 @@ class ThreepidAddRestServlet(RestServlet): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 85cf5a14c6..1787562b90 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -18,6 +18,7 @@ import logging from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX +from synapse.handlers.auth import SUCCESS_TEMPLATE from synapse.http.server import finish_request from synapse.http.servlet import RestServlet, parse_string @@ -89,30 +90,6 @@ TERMS_TEMPLATE = """ """ -SUCCESS_TEMPLATE = """ - - -Success! - - - - - -
-<body>
-<div>
-<p>Thank you</p>
-<p>You may now close this window and return to the application</p>
-</div>
-</body>
- - -""" - class AuthRestServlet(RestServlet): """ @@ -130,6 +107,11 @@ class AuthRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + # SSO configuration. + self._saml_enabled = hs.config.saml2_enabled + if self._saml_enabled: + self._saml_handler = hs.get_saml_handler() + def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: @@ -150,6 +132,15 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } + + elif stagetype == LoginType.SSO and self._saml_enabled: + # Display a confirmation page which prompts the user to + # re-authenticate with their SSO provider. + client_redirect_url = "" + sso_redirect_url = self._saml_handler.handle_redirect_request( + client_redirect_url, session + ) + html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) else: raise SynapseError(404, "Unknown auth stage type") @@ -210,6 +201,9 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } + elif stagetype == LoginType.SSO: + # The SSO fallback workflow should not post here, + raise SynapseError(404, "Fallback SSO auth does not support POST requests.") else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 119d979052..c0714fcfb1 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -81,7 +81,11 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove device(s) from your account", ) await self.device_handler.delete_devices( @@ -127,7 +131,11 @@ class DeviceRestServlet(RestServlet): raise await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove a device from your account", ) await self.device_handler.delete_device(requester.user.to_string(), device_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 5eb7ef35a4..8f41a3edbf 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -263,7 +263,11 @@ class SigningKeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a device signing key to your account", ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 66fc8ec179..431ecf4f84 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -505,6 +505,7 @@ class RegisterRestServlet(RestServlet): request, body, self.hs.get_ip_from_request(request), + "register a new account", ) # Check that we're not trying to register a denied 3pid. 
-- cgit 1.4.1 From 529462b5c044f7f7491fd41ab7d0682d01a6236b Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 2 Apr 2020 11:30:49 +0100 Subject: 1.12.1 --- CHANGES.md | 6 ++++++ debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/CHANGES.md b/CHANGES.md index a77768de58..e5608baa11 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,9 @@ +Synapse 1.12.1 (2020-04-02) +=========================== + +No significant changes since 1.12.1rc1. + + Synapse 1.12.1rc1 (2020-03-31) ============================== diff --git a/debian/changelog b/debian/changelog index 39ec9da7ab..8e14e59c0d 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.12.1) stable; urgency=medium + + * New synapse release 1.12.1. + + -- Synapse Packaging team Mon, 02 Apr 2020 11:30:47 +0000 + matrix-synapse-py3 (1.12.0) stable; urgency=medium * New synapse release 1.12.0. diff --git a/synapse/__init__.py b/synapse/__init__.py index c3c5b20f11..5df7d51ab1 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -36,7 +36,7 @@ try: except ImportError: pass -__version__ = "1.12.1rc1" +__version__ = "1.12.1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.4.1 From af47264b78c33698f6a70ce1ce3d32774d65de72 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 2 Apr 2020 12:04:55 +0100 Subject: review comment --- synapse/storage/background_updates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 0e430356cd..510963eb7f 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -132,7 +132,7 @@ class BackgroundUpdater(object): if self._all_done: return True - # obviously, if we have things in our queue, we're not done. + # obviously, if we are currently processing an update, we're not done. if self._current_background_update: return False -- cgit 1.4.1 From daa1ac89a0be4dd3cc941da4caeb2ddcbd701eff Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Apr 2020 10:40:22 +0100 Subject: Fix device list update stream ids going backward (#7158) Occasionally we could get a federation device list update transaction which looked like: ``` [ {'edu_type': 'm.device_list_update', 'content': {'user_id': '@user:test', 'device_id': 'D2', 'prev_id': [], 'stream_id': 12, 'deleted': True}}, {'edu_type': 'm.device_list_update', 'content': {'user_id': '@user:test', 'device_id': 'D1', 'prev_id': [12], 'stream_id': 11, 'deleted': True}}, {'edu_type': 'm.device_list_update', 'content': {'user_id': '@user:test', 'device_id': 'D3', 'prev_id': [11], 'stream_id': 13, 'deleted': True}} ] ``` Having `stream_ids` which are lower than `prev_ids` looks odd. It might work (I'm not actually sure), but in any case it doesn't seem like a reasonable thing to expect other implementations to support. 
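An illustrative sketch (not part of the patch): with the per-device `query_map`
used in the hunk below, iterating over the devices sorted by their `stream_id`
keeps the emitted stream IDs monotonically increasing, so a `prev_id` can no
longer point past the `stream_id` that follows it. The values mirror the example
transaction above.

```
# Editor's sketch only -- data taken from the example transaction above.
query_map = {
    ("@user:test", "D2"): (12, None),  # (stream_id, opentracing_context)
    ("@user:test", "D1"): (11, None),
    ("@user:test", "D3"): (13, None),
}
user_id = "@user:test"
user_devices = {"D1": {}, "D2": {}, "D3": {}}

# The fix below: go through the devices in stream order.
device_ids = sorted(user_devices.keys(), key=lambda i: query_map[(user_id, i)][0])

stream_ids = [query_map[(user_id, device_id)][0] for device_id in device_ids]
assert device_ids == ["D1", "D2", "D3"]
assert stream_ids == sorted(stream_ids)  # 11, 12, 13 -- never goes backwards
```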
--- changelog.d/7158.misc | 1 + synapse/storage/data_stores/main/devices.py | 10 ++++++++-- tests/federation/test_federation_sender.py | 6 ++++++ 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7158.misc (limited to 'synapse') diff --git a/changelog.d/7158.misc b/changelog.d/7158.misc new file mode 100644 index 0000000000..269b8daeb0 --- /dev/null +++ b/changelog.d/7158.misc @@ -0,0 +1 @@ +Fix device list update stream ids going backward. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 20995e1b78..dd3561e9b2 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -165,7 +165,6 @@ class DeviceWorkerStore(SQLBaseStore): # the max stream_id across each set of duplicate entries # # maps (user_id, device_id) -> (stream_id, opentracing_context) - # as long as their stream_id does not match that of the last row # # opentracing_context contains the opentracing metadata for the request # that created the poke @@ -270,7 +269,14 @@ class DeviceWorkerStore(SQLBaseStore): prev_id = yield self._get_last_device_update_for_remote_user( destination, user_id, from_stream_id ) - for device_id, device in iteritems(user_devices): + + # make sure we go through the devices in stream order + device_ids = sorted( + user_devices.keys(), key=lambda i: query_map[(user_id, i)][0], + ) + + for device_id in device_ids: + device = user_devices[device_id] stream_id, opentracing_context = query_map[(user_id, device_id)] result = { "user_id": user_id, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index a5fe5c6880..33105576af 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -297,6 +297,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): c = edu["content"] if stream_id is not None: self.assertEqual(c["prev_id"], [stream_id]) + self.assertGreaterEqual(c["stream_id"], stream_id) stream_id = c["stream_id"] devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2"}, devices) @@ -330,6 +331,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): c.items(), {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(), ) + self.assertGreaterEqual(c["stream_id"], stream_id) stream_id = c["stream_id"] devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2", "D3"}, devices) @@ -366,6 +368,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.assertEqual(edu["edu_type"], "m.device_list_update") c = edu["content"] self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) + if stream_id is not None: + self.assertGreaterEqual(c["stream_id"], stream_id) stream_id = c["stream_id"] devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2", "D3"}, devices) @@ -482,6 +486,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): } self.assertLessEqual(expected.items(), content.items()) + if prev_stream_id is not None: + self.assertGreaterEqual(content["stream_id"], prev_stream_id) return content["stream_id"] def check_signing_key_update_txn(self, txn: JsonDict,) -> None: -- cgit 1.4.1 From fcc2de7a0c0c7a02b935c4b35394f228a5d5a304 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Apr 2020 10:51:32 +0100 Subject: Update docstring per review comments --- synapse/storage/background_updates.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 510963eb7f..59f3394b0a 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -179,7 +179,7 @@ class BackgroundUpdater(object): desired_duration_ms(float): How long we want to spend updating. Returns: - True if there is no more work to do, otherwise False + True if we have finished running all the background updates, otherwise False """ def get_background_updates_txn(txn): -- cgit 1.4.1 From 14a8e7129793be3c7d27ea810b26efe4fba09bc9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Apr 2020 11:28:43 +0100 Subject: Revert "Revert "Improve the UX of the login fallback when using SSO (#7152)"" This reverts commit 8d4cbdeaa9765ae0d87ec82b053f12ed8162f6f5. --- changelog.d/7152.feature | 1 + synapse/static/client/login/index.html | 2 +- synapse/static/client/login/js/login.js | 51 +++++++++++++++++++-------------- 3 files changed, 32 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7152.feature (limited to 'synapse') diff --git a/changelog.d/7152.feature b/changelog.d/7152.feature new file mode 100644 index 0000000000..fafa79c7e7 --- /dev/null +++ b/changelog.d/7152.feature @@ -0,0 +1 @@ +Improve the support for SSO authentication on the login fallback page. diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html index bcb6bc6bb7..712b0e3980 100644 --- a/synapse/static/client/login/index.html +++ b/synapse/static/client/login/index.html @@ -9,7 +9,7 @@

-        <h1>Log in with one of the following methods</h1>
+        <h1 id="title"></h1>

diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index 276c271bbe..debe464371 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -1,37 +1,41 @@ window.matrixLogin = { endpoint: location.origin + "/_matrix/client/r0/login", serverAcceptsPassword: false, - serverAcceptsCas: false, serverAcceptsSso: false, }; +var title_pre_auth = "Log in with one of the following methods"; +var title_post_auth = "Logging in..."; + var submitPassword = function(user, pwd) { console.log("Logging in with password..."); + set_title(title_post_auth); var data = { type: "m.login.password", user: user, password: pwd, }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); matrixLogin.onLogin(response); }).error(errorFunc); }; var submitToken = function(loginToken) { console.log("Logging in with login token..."); + set_title(title_post_auth); var data = { type: "m.login.token", token: loginToken }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); matrixLogin.onLogin(response); }).error(errorFunc); }; var errorFunc = function(err) { - show_login(); + // We want to show the error to the user rather than redirecting immediately to the + // SSO portal (if SSO is the only login option), so we inhibit the redirect. + show_login(true); if (err.responseJSON && err.responseJSON.error) { setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")"); @@ -45,26 +49,33 @@ var setFeedbackString = function(text) { $("#feedback").text(text); }; -var show_login = function() { - $("#loading").hide(); - +var show_login = function(inhibit_redirect) { var this_page = window.location.origin + window.location.pathname; $("#sso_redirect_url").val(this_page); - if (matrixLogin.serverAcceptsPassword) { - $("#password_flow").show(); + // If inhibit_redirect is false, and SSO is the only supported login method, we can + // redirect straight to the SSO page + if (matrixLogin.serverAcceptsSso) { + if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) { + $("#sso_form").submit(); + return; + } + + // Otherwise, show the SSO form + $("#sso_form").show(); } - if (matrixLogin.serverAcceptsSso) { - $("#sso_flow").show(); - } else if (matrixLogin.serverAcceptsCas) { - $("#sso_form").attr("action", "/_matrix/client/r0/login/cas/redirect"); - $("#sso_flow").show(); + if (matrixLogin.serverAcceptsPassword) { + $("#password_flow").show(); } - if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas && !matrixLogin.serverAcceptsSso) { + if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) { $("#no_login_types").show(); } + + set_title(title_pre_auth); + + $("#loading").hide(); }; var show_spinner = function() { @@ -74,17 +85,15 @@ var show_spinner = function() { $("#loading").show(); }; +var set_title = function(title) { + $("#title").text(title); +}; var fetch_info = function(cb) { $.get(matrixLogin.endpoint, function(response) { var serverAcceptsPassword = false; - var serverAcceptsCas = false; for (var i=0; i Date: Fri, 3 Apr 2020 11:28:49 +0100 Subject: Revert "Revert "Merge pull request #7153 from matrix-org/babolivier/sso_whitelist_login_fallback"" This reverts commit 0122ef1037c8bfe826ea09d9fc7cd63fb9c59fd1. 
--- changelog.d/7153.feature | 1 + docs/sample_config.yaml | 4 ++++ synapse/config/sso.py | 15 +++++++++++++++ tests/rest/client/v1/test_login.py | 9 ++++++++- 4 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7153.feature (limited to 'synapse') diff --git a/changelog.d/7153.feature b/changelog.d/7153.feature new file mode 100644 index 0000000000..414ebe1f69 --- /dev/null +++ b/changelog.d/7153.feature @@ -0,0 +1 @@ +Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2ff0dd05a2..07e922dc27 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1392,6 +1392,10 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # # By default, this list is empty. # #client_whitelist: diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 95762689bc..ec3dca9efc 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -39,6 +39,17 @@ class SSOConfig(Config): self.sso_client_whitelist = sso_config.get("client_whitelist") or [] + # Attempt to also whitelist the server's login fallback, since that fallback sets + # the redirect URL to itself (so it can process the login token then return + # gracefully to the client). This would make it pointless to ask the user for + # confirmation, since the URL the confirmation page would be showing wouldn't be + # the client's. + # public_baseurl is an optional setting, so we only add the fallback's URL to the + # list if it's provided (because we can't figure out what that URL is otherwise). + if self.public_baseurl: + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) + def generate_config_section(self, **kwargs): return """\ # Additional settings to use with single-sign on systems such as SAML2 and CAS. @@ -54,6 +65,10 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # # By default, this list is empty. 
# #client_whitelist: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index da2c9bfa1e..aed8853d6e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -350,7 +350,14 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): def test_cas_redirect_whitelisted(self): """Tests that the SSO login flow serves a redirect to a whitelisted url """ - redirect_url = "https://legit-site.com/" + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" cas_ticket_url = ( "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" % (urllib.parse.quote(redirect_url)) -- cgit 1.4.1 From bae32740daa5551b6613cafafb5d5bc1a73141ec Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Apr 2020 12:29:30 +0100 Subject: Remove some `run_in_background` calls in replication code (#7203) By running this stuff with `run_in_background`, it won't be correctly reported against the relevant CPU usage stats. Fixes #7202 --- changelog.d/7203.bugfix | 1 + synapse/app/generic_worker.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7203.bugfix (limited to 'synapse') diff --git a/changelog.d/7203.bugfix b/changelog.d/7203.bugfix new file mode 100644 index 0000000000..8b383952e5 --- /dev/null +++ b/changelog.d/7203.bugfix @@ -0,0 +1 @@ +Fix some worker-mode replication handling not being correctly recorded in CPU usage stats. 
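A condensed sketch of the pattern this commit applies (asyncio stands in for
Synapse's Twisted/logcontext machinery; the class is a toy, not the real
handler): awaiting the coroutine directly keeps the work inside the calling
context, whereas the old `run_in_background` call detached it, so its CPU time
was never attributed to the replication handler.

```
import asyncio


class ReplicationHandler:
    """Toy stand-in for the worker replication handler changed below."""

    async def process_and_notify(self, stream_name, token, rows):
        await asyncio.sleep(0)  # stand-in for the real processing work

    async def on_rdata(self, stream_name, token, rows):
        # Before: run_in_background(self.process_and_notify, ...) ran the work
        # outside the current logcontext, so its CPU usage was not recorded
        # against this handler.  Awaiting it directly keeps it in context:
        await self.process_and_notify(stream_name, token, rows)


asyncio.run(ReplicationHandler().on_rdata("events", 1, []))
```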
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1ee266f7c5..174bef360f 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -42,7 +42,7 @@ from synapse.handlers.presence import PresenceHandler, get_interested_parties from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background +from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ @@ -635,7 +635,7 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): await super(GenericWorkerReplicationHandler, self).on_rdata( stream_name, token, rows ) - run_in_background(self.process_and_notify, stream_name, token, rows) + await self.process_and_notify(stream_name, token, rows) def get_streams_to_replicate(self): args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate() @@ -650,7 +650,9 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: - self.send_handler.process_replication_rows(stream_name, token, rows) + await self.send_handler.process_replication_rows( + stream_name, token, rows + ) if stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so @@ -782,12 +784,12 @@ class FederationSenderHandler(object): def stream_positions(self): return {"federation": self.federation_position} - def process_replication_rows(self, stream_name, token, rows): + async def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. if stream_name == "federation": send_queue.process_rows_for_federation(self.federation_sender, rows) - run_in_background(self.update_token, token) + await self.update_token(token) # We also need to poke the federation sender when new events happen elif stream_name == "events": @@ -795,9 +797,7 @@ class FederationSenderHandler(object): # ... and when new receipts happen elif stream_name == ReceiptsStream.NAME: - run_as_background_process( - "process_receipts_for_federation", self._on_new_receipts, rows - ) + await self._on_new_receipts(rows) # ... as well as device updates and messages elif stream_name == DeviceListsStream.NAME: -- cgit 1.4.1 From 0f05fd15304f1931ef167351de63cc8ffa1d3a98 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Apr 2020 13:21:30 +0100 Subject: Reduce the number of calls to `resource.getrusage` (#7183) Let's just call `getrusage` once on each logcontext change, rather than twice. --- changelog.d/7183.misc | 1 + synapse/logging/context.py | 102 ++++++++++++++++++++++++++++----------------- 2 files changed, 64 insertions(+), 39 deletions(-) create mode 100644 changelog.d/7183.misc (limited to 'synapse') diff --git a/changelog.d/7183.misc b/changelog.d/7183.misc new file mode 100644 index 0000000000..731f4dcb52 --- /dev/null +++ b/changelog.d/7183.misc @@ -0,0 +1 @@ +Clean up some LoggingContext code. 
diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 3254d6a8df..a8f674d13d 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -51,7 +51,7 @@ try: is_thread_resource_usage_supported = True - def get_thread_resource_usage(): + def get_thread_resource_usage() -> "Optional[resource._RUsage]": return resource.getrusage(RUSAGE_THREAD) @@ -60,7 +60,7 @@ except Exception: # won't track resource usage. is_thread_resource_usage_supported = False - def get_thread_resource_usage(): + def get_thread_resource_usage() -> "Optional[resource._RUsage]": return None @@ -201,10 +201,10 @@ class _Sentinel(object): record["request"] = None record["scope"] = None - def start(self): + def start(self, rusage: "Optional[resource._RUsage]"): pass - def stop(self): + def stop(self, rusage: "Optional[resource._RUsage]"): pass def add_database_transaction(self, duration_sec): @@ -261,7 +261,7 @@ class LoggingContext(object): # The thread resource usage when the logcontext became active. None # if the context is not currently active. - self.usage_start = None + self.usage_start = None # type: Optional[resource._RUsage] self.main_thread = get_thread_id() self.request = None @@ -336,7 +336,17 @@ class LoggingContext(object): record["request"] = self.request record["scope"] = self.scope - def start(self) -> None: + def start(self, rusage: "Optional[resource._RUsage]") -> None: + """ + Record that this logcontext is currently running. + + This should not be called directly: use set_current_context + + Args: + rusage: the resources used by the current thread, at the point of + switching to this logcontext. May be None if this platform doesn't + support getrusuage. + """ if get_thread_id() != self.main_thread: logger.warning("Started logcontext %s on different thread", self) return @@ -349,36 +359,48 @@ class LoggingContext(object): if self.usage_start: logger.warning("Re-starting already-active log context %s", self) else: - self.usage_start = get_thread_resource_usage() + self.usage_start = rusage - def stop(self) -> None: - if get_thread_id() != self.main_thread: - logger.warning("Stopped logcontext %s on different thread", self) - return + def stop(self, rusage: "Optional[resource._RUsage]") -> None: + """ + Record that this logcontext is no longer running. + + This should not be called directly: use set_current_context + + Args: + rusage: the resources used by the current thread, at the point of + switching away from this logcontext. May be None if this platform + doesn't support getrusuage. 
+ """ + + try: + if get_thread_id() != self.main_thread: + logger.warning("Stopped logcontext %s on different thread", self) + return + + if not rusage: + return - # When we stop, let's record the cpu used since we started - if not self.usage_start: - # Log a warning on platforms that support thread usage tracking - if is_thread_resource_usage_supported: + # Record the cpu used since we started + if not self.usage_start: logger.warning( - "Called stop on logcontext %s without calling start", self + "Called stop on logcontext %s without recording a start rusage", + self, ) - return - - utime_delta, stime_delta = self._get_cputime() - self._resource_usage.ru_utime += utime_delta - self._resource_usage.ru_stime += stime_delta + return - self.usage_start = None + utime_delta, stime_delta = self._get_cputime(rusage) + self._resource_usage.ru_utime += utime_delta + self._resource_usage.ru_stime += stime_delta - # if we have a parent, pass our CPU usage stats on - if self.parent_context is not None and hasattr( - self.parent_context, "_resource_usage" - ): - self.parent_context._resource_usage += self._resource_usage + # if we have a parent, pass our CPU usage stats on + if self.parent_context: + self.parent_context._resource_usage += self._resource_usage - # reset them in case we get entered again - self._resource_usage.reset() + # reset them in case we get entered again + self._resource_usage.reset() + finally: + self.usage_start = None def get_resource_usage(self) -> ContextResourceUsage: """Get resources used by this logcontext so far. @@ -394,24 +416,24 @@ class LoggingContext(object): # can include resource usage so far. is_main_thread = get_thread_id() == self.main_thread if self.usage_start and is_main_thread: - utime_delta, stime_delta = self._get_cputime() + rusage = get_thread_resource_usage() + assert rusage is not None + utime_delta, stime_delta = self._get_cputime(rusage) res.ru_utime += utime_delta res.ru_stime += stime_delta return res - def _get_cputime(self) -> Tuple[float, float]: - """Get the cpu usage time so far + def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]: + """Get the cpu usage time between start() and the given rusage + + Args: + rusage: the current resource usage Returns: Tuple[float, float]: seconds in user mode, seconds in system mode """ assert self.usage_start is not None - current = get_thread_resource_usage() - - # Indicate to mypy that we know that self.usage_start is None. 
- assert self.usage_start is not None - utime_delta = current.ru_utime - self.usage_start.ru_utime stime_delta = current.ru_stime - self.usage_start.ru_stime @@ -547,9 +569,11 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe current = current_context() if current is not context: - current.stop() + rusage = get_thread_resource_usage() + current.stop(rusage) _thread_local.current_context = context - context.start() + context.start(rusage) + return current -- cgit 1.4.1 From 07b88c546de1b24f5cbc9b4cb6da98400a8155af Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 3 Apr 2020 14:26:07 +0100 Subject: Convert http.HTTPStatus objects to their int equivalent (#7188) --- changelog.d/7188.misc | 1 + synapse/api/errors.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7188.misc (limited to 'synapse') diff --git a/changelog.d/7188.misc b/changelog.d/7188.misc new file mode 100644 index 0000000000..f72955b95b --- /dev/null +++ b/changelog.d/7188.misc @@ -0,0 +1 @@ +Fix consistency of HTTP status codes reported in log lines. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 11da016ac5..d54dfb385d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -86,7 +86,14 @@ class CodeMessageException(RuntimeError): def __init__(self, code, msg): super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) - self.code = code + + # Some calls to this method pass instances of http.HTTPStatus for `code`. + # While HTTPStatus is a subclass of int, it has magic __str__ methods + # which emit `HTTPStatus.FORBIDDEN` when converted to a str, instead of `403`. + # This causes inconsistency in our log lines. + # + # To eliminate this behaviour, we convert them to their integer equivalents here. + self.code = int(code) self.msg = msg -- cgit 1.4.1 From b0db928c633ad2e225623cffb20293629c5d5a43 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Fri, 3 Apr 2020 17:57:34 +0200 Subject: Extend web_client_location to handle absolute URLs (#7006) Log warning when filesystem path is used. Signed-off-by: Martin Milata --- changelog.d/7006.feature | 1 + docs/sample_config.yaml | 11 ++++++++--- synapse/app/homeserver.py | 16 +++++++++++++--- synapse/config/server.py | 11 ++++++++--- 4 files changed, 30 insertions(+), 9 deletions(-) create mode 100644 changelog.d/7006.feature (limited to 'synapse') diff --git a/changelog.d/7006.feature b/changelog.d/7006.feature new file mode 100644 index 0000000000..d2ce9dbaca --- /dev/null +++ b/changelog.d/7006.feature @@ -0,0 +1 @@ +Extend the `web_client_location` option to accept an absolute URL to use as a redirect. Adds a warning when running the web client on the same hostname as homeserver. Contributed by Martin Milata. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 6a770508f9..be742969cc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -33,10 +33,15 @@ server_name: "SERVERNAME" # pid_file: DATADIR/homeserver.pid -# The path to the web client which will be served at /_matrix/client/ -# if 'webclient' is configured under the 'listeners' configuration. +# The absolute URL to the web client which /_matrix/client will redirect +# to if 'webclient' is configured under the 'listeners' configuration. 
# -#web_client_location: "/path/to/web/root" +# This option can be also set to the filesystem path to the web client +# which will be served at /_matrix/client/ if 'webclient' is configured +# under the 'listeners' configuration, however this is a security risk: +# https://github.com/matrix-org/synapse#security-note +# +#web_client_location: https://riot.example.com/ # The public-facing base URL that clients use to access this HS # (not including _matrix/...). This is the same URL a user would diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index f2b56a636f..49df63acd0 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -241,16 +241,26 @@ class SynapseHomeServer(HomeServer): resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) if name == "webclient": - webclient_path = self.get_config().web_client_location + webclient_loc = self.get_config().web_client_location - if webclient_path is None: + if webclient_loc is None: logger.warning( "Not enabling webclient resource, as web_client_location is unset." ) + elif webclient_loc.startswith("http://") or webclient_loc.startswith( + "https://" + ): + resources[WEB_CLIENT_PREFIX] = RootRedirect(webclient_loc) else: + logger.warning( + "Running webclient on the same domain is not recommended: " + "https://github.com/matrix-org/synapse#security-note - " + "after you move webclient to different host you can set " + "web_client_location to its full URL to enable redirection." + ) # GZip is disabled here due to # https://twistedmatrix.com/trac/ticket/7678 - resources[WEB_CLIENT_PREFIX] = File(webclient_path) + resources[WEB_CLIENT_PREFIX] = File(webclient_loc) if name == "metrics" and self.get_config().enable_metrics: resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) diff --git a/synapse/config/server.py b/synapse/config/server.py index 7525765fee..28e2a031fb 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -604,10 +604,15 @@ class ServerConfig(Config): # pid_file: %(pid_file)s - # The path to the web client which will be served at /_matrix/client/ - # if 'webclient' is configured under the 'listeners' configuration. + # The absolute URL to the web client which /_matrix/client will redirect + # to if 'webclient' is configured under the 'listeners' configuration. # - #web_client_location: "/path/to/web/root" + # This option can be also set to the filesystem path to the web client + # which will be served at /_matrix/client/ if 'webclient' is configured + # under the 'listeners' configuration, however this is a security risk: + # https://github.com/matrix-org/synapse#security-note + # + #web_client_location: https://riot.example.com/ # The public-facing base URL that clients use to access this HS # (not including _matrix/...). This is the same URL a user would -- cgit 1.4.1 From 694d8bed0e56366f080a49db0f930d635ca6cdf4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Apr 2020 15:35:05 -0400 Subject: Support CAS in UI Auth flows. 
(#7186) --- changelog.d/7186.feature | 1 + synapse/handlers/auth.py | 4 +- synapse/handlers/cas_handler.py | 161 +++++++++++++++++++---------------- synapse/rest/client/v1/login.py | 20 ++++- synapse/rest/client/v2_alpha/auth.py | 28 ++++-- 5 files changed, 131 insertions(+), 83 deletions(-) create mode 100644 changelog.d/7186.feature (limited to 'synapse') diff --git a/changelog.d/7186.feature b/changelog.d/7186.feature new file mode 100644 index 0000000000..01057aa396 --- /dev/null +++ b/changelog.d/7186.feature @@ -0,0 +1 @@ +Support SSO in the user interactive authentication workflow. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7c09d15a72..892adb00b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -116,7 +116,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - self._saml2_enabled = hs.config.saml2_enabled + self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first @@ -136,7 +136,7 @@ class AuthHandler(BaseHandler): # necessarily identical. Login types have SSO (and other login types) # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. ui_auth_types = login_types.copy() - if self._saml2_enabled: + if self._sso_enabled: ui_auth_types.append(LoginType.SSO) self._supported_ui_auth_types = ui_auth_types diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index f8dc274b78..d977badf35 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -15,7 +15,7 @@ import logging import xml.etree.ElementTree as ET -from typing import AnyStr, Dict, Optional, Tuple +from typing import Dict, Optional, Tuple from six.moves import urllib @@ -48,26 +48,47 @@ class CasHandler: self._http_client = hs.get_proxied_http_client() - def _build_service_param(self, client_redirect_url: AnyStr) -> str: + def _build_service_param(self, args: Dict[str, str]) -> str: + """ + Generates a value to use as the "service" parameter when redirecting or + querying the CAS service. + + Args: + args: Additional arguments to include in the final redirect URL. + + Returns: + The URL to use as a "service" parameter. + """ return "%s%s?%s" % ( self._cas_service_url, "/_matrix/client/r0/login/cas/ticket", - urllib.parse.urlencode({"redirectUrl": client_redirect_url}), + urllib.parse.urlencode(args), ) - async def _handle_cas_response( - self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str - ) -> None: + async def _validate_ticket( + self, ticket: str, service_args: Dict[str, str] + ) -> Tuple[str, Optional[str]]: """ - Retrieves the user and display name from the CAS response and continues with the authentication. + Validate a CAS ticket with the server, parse the response, and return the user and display name. Args: - request: The original client request. - cas_response_body: The response from the CAS server. - client_redirect_url: The URl to redirect the client to when - everything is done. + ticket: The CAS ticket from the client. + service_args: Additional arguments to include in the service URL. + Should be the same as those passed to `get_redirect_url`. 
""" - user, attributes = self._parse_cas_response(cas_response_body) + uri = self._cas_server_url + "/proxyValidate" + args = { + "ticket": ticket, + "service": self._build_service_param(service_args), + } + try: + body = await self._http_client.get_raw(uri, args) + except PartialDownloadError as pde: + # Twisted raises this error if the connection is closed, + # even if that's being used old-http style to signal end-of-data + body = pde.response + + user, attributes = self._parse_cas_response(body) displayname = attributes.pop(self._cas_displayname_attribute, None) for required_attribute, required_value in self._cas_required_attributes.items(): @@ -82,7 +103,7 @@ class CasHandler: if required_value != actual_value: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - await self._on_successful_auth(user, request, client_redirect_url, displayname) + return user, displayname def _parse_cas_response( self, cas_response_body: str @@ -127,78 +148,74 @@ class CasHandler: ) return user, attributes - async def _on_successful_auth( - self, - username: str, - request: SynapseRequest, - client_redirect_url: str, - user_display_name: Optional[str] = None, - ) -> None: - """Called once the user has successfully authenticated with the SSO. - - Registers the user if necessary, and then returns a redirect (with - a login token) to the client. + def get_redirect_url(self, service_args: Dict[str, str]) -> str: + """ + Generates a URL for the CAS server where the client should be redirected. Args: - username: the remote user id. We'll map this onto - something sane for a MXID localpath. + service_args: Additional arguments to include in the final redirect URL. - request: the incoming request from the browser. We'll - respond to it with a redirect. + Returns: + The URL to redirect the client to. + """ + args = urllib.parse.urlencode( + {"service": self._build_service_param(service_args)} + ) - client_redirect_url: the redirect_url the client gave us when - it first started the process. + return "%s/login?%s" % (self._cas_server_url, args) - user_display_name: if set, and we have to register a new user, - we will set their displayname to this. + async def handle_ticket( + self, + request: SynapseRequest, + ticket: str, + client_redirect_url: Optional[str], + session: Optional[str], + ) -> None: """ - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) - if not registered_user_id: - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name - ) + Called once the user has successfully authenticated with the SSO. + Validates a CAS ticket sent by the client and completes the auth process. - self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url - ) + If the user interactive authentication session is provided, marks the + UI Auth session as complete, then returns an HTML page notifying the + user they are done. - def handle_redirect_request(self, client_redirect_url: bytes) -> bytes: - """ - Generates a URL to the CAS server where the client should be redirected. + Otherwise, this registers the user if necessary, and then returns a + redirect (with a login token) to the client. Args: - client_redirect_url: The final URL the client should go to after the - user has negotiated SSO. + request: the incoming request from the browser. 
We'll + respond to it with a redirect or an HTML page. - Returns: - The URL to redirect to. - """ - args = urllib.parse.urlencode( - {"service": self._build_service_param(client_redirect_url)} - ) + ticket: The CAS ticket provided by the client. - return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii") + client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given. + This should be the same as the redirectUrl from the original `/login/sso/redirect` request. - async def handle_ticket_request( - self, request: SynapseRequest, client_redirect_url: str, ticket: str - ) -> None: + session: The session parameter from the `/cas/ticket` HTTP request, if given. + This should be the UI Auth session id. """ - Validates a CAS ticket sent by the client for login/registration. + args = {} + if client_redirect_url: + args["redirectUrl"] = client_redirect_url + if session: + args["session"] = session + username, user_display_name = await self._validate_ticket(ticket, args) - On a successful request, writes a redirect to the request. - """ - uri = self._cas_server_url + "/proxyValidate" - args = { - "ticket": ticket, - "service": self._build_service_param(client_redirect_url), - } - try: - body = await self._http_client.get_raw(uri, args) - except PartialDownloadError as pde: - # Twisted raises this error if the connection is closed, - # even if that's being used old-http style to signal end-of-data - body = pde.response + localpart = map_username_to_mxid_localpart(username) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) - await self._handle_cas_response(request, body, client_redirect_url) + if session: + self._auth_handler.complete_sso_ui_auth( + registered_user_id, session, request, + ) + + else: + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=user_display_name + ) + + self._auth_handler.complete_sso_login( + registered_user_id, request, client_redirect_url + ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 59593cbf6e..4de2f97d06 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -425,7 +425,9 @@ class CasRedirectServlet(BaseSSORedirectServlet): self._cas_handler = hs.get_cas_handler() def get_sso_url(self, client_redirect_url: bytes) -> bytes: - return self._cas_handler.handle_redirect_request(client_redirect_url) + return self._cas_handler.get_redirect_url( + {"redirectUrl": client_redirect_url} + ).encode("ascii") class CasTicketServlet(RestServlet): @@ -436,10 +438,20 @@ class CasTicketServlet(RestServlet): self._cas_handler = hs.get_cas_handler() async def on_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl", required=True) + client_redirect_url = parse_string(request, "redirectUrl") ticket = parse_string(request, "ticket", required=True) - await self._cas_handler.handle_ticket_request( - request, client_redirect_url, ticket + + # Maybe get a session ID (if this ticket is from user interactive + # authentication). + session = parse_string(request, "session") + + # Either client_redirect_url or session must be provided. 
+ if not client_redirect_url and not session: + message = "Missing string query parameter redirectUrl or session" + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + + await self._cas_handler.handle_ticket( + request, ticket, client_redirect_url, session ) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 1787562b90..13f9604407 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -111,6 +111,11 @@ class AuthRestServlet(RestServlet): self._saml_enabled = hs.config.saml2_enabled if self._saml_enabled: self._saml_handler = hs.get_saml_handler() + self._cas_enabled = hs.config.cas_enabled + if self._cas_enabled: + self._cas_handler = hs.get_cas_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url def on_GET(self, request, stagetype): session = parse_string(request, "session") @@ -133,14 +138,27 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.TERMS), } - elif stagetype == LoginType.SSO and self._saml_enabled: + elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to # re-authenticate with their SSO provider. - client_redirect_url = "" - sso_redirect_url = self._saml_handler.handle_redirect_request( - client_redirect_url, session - ) + if self._cas_enabled: + # Generate a request to CAS that redirects back to an endpoint + # to verify the successful authentication. + sso_redirect_url = self._cas_handler.get_redirect_url( + {"session": session}, + ) + + elif self._saml_enabled: + client_redirect_url = "" + sso_redirect_url = self._saml_handler.handle_redirect_request( + client_redirect_url, session + ) + + else: + raise SynapseError(400, "Homeserver not configured for SSO.") + html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + else: raise SynapseError(404, "Unknown auth stage type") -- cgit 1.4.1 From d73bf18d13031d9f9c0375b83f2cc5ff6f415251 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Sat, 4 Apr 2020 17:27:45 +0200 Subject: Server notices: Dissociate room creation/lookup from invite (#7199) Fixes #6815 Before figuring out whether we should alert a user on MAU, we call get_notice_room_for_user to get some info on the existing server notices room for this user. This function, if the room doesn't exist, creates it and invites the user in it. This means that, if we decide later that no server notice is needed, the user gets invited in a room with no message in it. This happens at every restart of the server, since the room ID returned by get_notice_room_for_user is cached. This PR fixes that by moving the inviting bit to a dedicated function, that's only called when the server actually needs to send a notice to the user. A potential issue with this approach is that the room that's created by get_notice_room_for_user doesn't match how that same function looks for an existing room (i.e. it creates a room that doesn't have an invite or a join for the current user in it, so it could lead to a new room being created each time a user syncs), but I'm not sure this is a problem given it's cached until the server restarts, so that function won't run very often. It also renames get_notice_room_for_user into get_or_create_notice_room_for_user to make what it does clearer. 
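A minimal sketch of the resulting flow (the manager class below is a stand-in;
only the method names match the diff): looking up or creating the notices room
no longer invites the user as a side effect, and the invite is issued only once
a notice is actually about to be sent.

```
import asyncio


class FakeServerNoticesManager:
    """Stand-in for ServerNoticesManager, showing the new call order only."""

    async def get_or_create_notice_room_for_user(self, user_id):
        # Find (or lazily create) the per-user server notices room, without
        # inviting the user as a side effect any more.
        return "!notices:example.org"

    async def maybe_invite_user_to_room(self, user_id, room_id):
        # Invite the user unless they are already invited/joined, so repeated
        # checks (e.g. on every restart) no longer produce empty invites.
        print("inviting %s to %s" % (user_id, room_id))


async def send_notice(manager, user_id):
    room_id = await manager.get_or_create_notice_room_for_user(user_id)
    # The invite happens here, only when a notice is really being sent.
    await manager.maybe_invite_user_to_room(user_id, room_id)
    # ... then build and send the notice event into room_id.


asyncio.run(send_notice(FakeServerNoticesManager(), "@user:example.org"))
```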
--- changelog.d/7199.bugfix | 1 + .../resource_limits_server_notices.py | 4 +- synapse/server_notices/server_notices_manager.py | 51 +++++++-- .../test_resource_limits_server_notices.py | 120 ++++++++++++++++++--- 4 files changed, 154 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7199.bugfix (limited to 'synapse') diff --git a/changelog.d/7199.bugfix b/changelog.d/7199.bugfix new file mode 100644 index 0000000000..b234163ea8 --- /dev/null +++ b/changelog.d/7199.bugfix @@ -0,0 +1 @@ +Fix a bug that could cause a user to be invited to a server notices (aka System Alerts) room without any notice being sent. diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 9fae2e0afe..ce4a828894 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -80,7 +80,9 @@ class ResourceLimitsServerNotices(object): # In practice, not sure we can ever get here return - room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) + room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user( + user_id + ) if not room_id: logger.warning("Failed to get server notices room") diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index f7432c8d2f..bf0943f265 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, RoomCreationPreset -from synapse.types import create_requester +from synapse.types import UserID, create_requester from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -36,10 +36,12 @@ class ServerNoticesManager(object): self._store = hs.get_datastore() self._config = hs.config self._room_creation_handler = hs.get_room_creation_handler() + self._room_member_handler = hs.get_room_member_handler() self._event_creation_handler = hs.get_event_creation_handler() self._is_mine_id = hs.is_mine_id self._notifier = hs.get_notifier() + self.server_notices_mxid = self._config.server_notices_mxid def is_enabled(self): """Checks if server notices are enabled on this server. @@ -66,7 +68,8 @@ class ServerNoticesManager(object): Returns: Deferred[FrozenEvent] """ - room_id = yield self.get_notice_room_for_user(user_id) + room_id = yield self.get_or_create_notice_room_for_user(user_id) + yield self.maybe_invite_user_to_room(user_id, room_id) system_mxid = self._config.server_notices_mxid requester = create_requester(system_mxid) @@ -89,10 +92,11 @@ class ServerNoticesManager(object): return res @cachedInlineCallbacks() - def get_notice_room_for_user(self, user_id): + def get_or_create_notice_room_for_user(self, user_id): """Get the room for notices for a given user - If we have not yet created a notice room for this user, create it + If we have not yet created a notice room for this user, create it, but don't + invite the user to it. 
Args: user_id (str): complete user id for the user we want a room for @@ -108,7 +112,6 @@ class ServerNoticesManager(object): rooms = yield self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) - system_mxid = self._config.server_notices_mxid for room in rooms: # it's worth noting that there is an asymmetry here in that we # expect the user to be invited or joined, but the system user must @@ -116,10 +119,14 @@ class ServerNoticesManager(object): # manages to invite the system user to a room, that doesn't make it # the server notices room. user_ids = yield self._store.get_users_in_room(room.room_id) - if system_mxid in user_ids: + if self.server_notices_mxid in user_ids: # we found a room which our user shares with the system notice # user - logger.info("Using room %s", room.room_id) + logger.info( + "Using existing server notices room %s for user %s", + room.room_id, + user_id, + ) return room.room_id # apparently no existing notice room: create a new one @@ -138,14 +145,13 @@ class ServerNoticesManager(object): "avatar_url": self._config.server_notices_mxid_avatar_url, } - requester = create_requester(system_mxid) + requester = create_requester(self.server_notices_mxid) info = yield self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, "name": self._config.server_notices_room_name, "power_level_content_override": {"users_default": -10}, - "invite": (user_id,), }, ratelimit=False, creator_join_profile=join_profile, @@ -159,3 +165,30 @@ class ServerNoticesManager(object): logger.info("Created server notices room %s for %s", room_id, user_id) return room_id + + @defer.inlineCallbacks + def maybe_invite_user_to_room(self, user_id: str, room_id: str): + """Invite the given user to the given server room, unless the user has already + joined or been invited to it. + + Args: + user_id: The ID of the user to invite. + room_id: The ID of the room to invite the user to. + """ + requester = create_requester(self.server_notices_mxid) + + # Check whether the user has already joined or been invited to this room. If + # that's the case, there is no need to re-invite them. 
+ joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + user_id, [Membership.INVITE, Membership.JOIN] + ) + for room in joined_rooms: + if room.room_id == room_id: + return + + yield self._room_member_handler.update_membership( + requester=requester, + target=UserID.from_string(user_id), + room_id=room_id, + action="invite", + ) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 0d27b92a86..93eb053b8c 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -19,6 +19,9 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) @@ -67,7 +70,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): # self.server_notices_mxid_avatar_url = None # self.server_notices_room_name = "Server Notices" - self._rlsn._server_notices_manager.get_notice_room_for_user = Mock( + self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( returnValue="" ) self._rlsn._store.add_tag_to_room = Mock() @@ -215,6 +218,26 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def default_config(self): + c = super().default_config() + c["server_notices"] = { + "system_mxid_localpart": "server", + "system_mxid_display_name": None, + "system_mxid_avatar_url": None, + "room_name": "Test Server Notice Room", + } + c["limit_usage_by_mau"] = True + c["max_mau_value"] = 5 + c["admin_contact"] = "mailto:user@test.com" + return c + def prepare(self, reactor, clock, hs): self.store = self.hs.get_datastore() self.server_notices_sender = self.hs.get_server_notices_sender() @@ -228,18 +251,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): if not isinstance(self._rlsn, ResourceLimitsServerNotices): raise Exception("Failed to find reference to ResourceLimitsServerNotices") - self.hs.config.limit_usage_by_mau = True - self.hs.config.hs_disabled = False - self.hs.config.max_mau_value = 5 - self.hs.config.server_notices_mxid = "@server:test" - self.hs.config.server_notices_mxid_display_name = None - self.hs.config.server_notices_mxid_avatar_url = None - self.hs.config.server_notices_room_name = "Test Server Notice Room" - self.user_id = "@user_id:test" - self.hs.config.admin_contact = "mailto:user@test.com" - def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) @@ -253,7 +266,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): # Now lets get the last load of messages in the service notice room and # check that there is only one server notice room_id = self.get_success( - self.server_notices_manager.get_notice_room_for_user(self.user_id) + self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id) ) token = self.get_success(self.event_source.get_current_token()) @@ -273,3 +286,86 @@ class 
TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): count += 1 self.assertEqual(count, 1) + + def test_no_invite_without_notice(self): + """Tests that a user doesn't get invited to a server notices room without a + server notice being sent. + + The scenario for this test is a single user on a server where the MAU limit + hasn't been reached (since it's the only user and the limit is 5), so users + shouldn't receive a server notice. + """ + self.register_user("user", "password") + tok = self.login("user", "password") + + request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + self.render(request) + + invites = channel.json_body["rooms"]["invite"] + self.assertEqual(len(invites), 0, invites) + + def test_invite_with_notice(self): + """Tests that, if the MAU limit is hit, the server notices user invites each user + to a room in which it has sent a notice. + """ + user_id, tok, room_id = self._trigger_notice_and_join() + + # Sync again to retrieve the events in the room, so we can check whether this + # room has a notice in it. + request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + self.render(request) + + # Scan the events in the room to search for a message from the server notices + # user. + events = channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + notice_in_room = False + for event in events: + if ( + event["type"] == EventTypes.Message + and event["sender"] == self.hs.config.server_notices_mxid + ): + notice_in_room = True + + self.assertTrue(notice_in_room, "No server notice in room") + + def _trigger_notice_and_join(self): + """Creates enough active users to hit the MAU limit and trigger a system notice + about it, then joins the system notices room with one of the users created. + + Returns: + user_id (str): The ID of the user that joined the room. + tok (str): The access token of the user that joined the room. + room_id (str): The ID of the room that's been joined. + """ + user_id = None + tok = None + invites = [] + + # Register as many users as the MAU limit allows. + for i in range(self.hs.config.max_mau_value): + localpart = "user%d" % i + user_id = self.register_user(localpart, "password") + tok = self.login(localpart, "password") + + # Sync with the user's token to mark the user as active. + request, channel = self.make_request( + "GET", "/sync?timeout=0", access_token=tok, + ) + self.render(request) + + # Also retrieves the list of invites for this user. We don't care about that + # one except if we're processing the last user, which should have received an + # invite to a room with a server notice about the MAU limit being reached. + # We could also pick another user and sync with it, which would return an + # invite to a system notices room, but it doesn't matter which user we're + # using so we use the last one because it saves us an extra sync. + invites = channel.json_body["rooms"]["invite"] + + # Make sure we have an invite to process. + self.assertEqual(len(invites), 1, invites) + + # Join the room. 
+ room_id = list(invites.keys())[0] + self.helper.join(room=room_id, user=user_id, tok=tok) + + return user_id, tok, room_id -- cgit 1.4.1 From 5016b162fcf0372fe35404c64f80aeaf21461f31 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Apr 2020 09:58:42 +0100 Subject: Move client command handling out of TCP protocol (#7185) The aim here is to move the command handling out of the TCP protocol classes and to also merge the client and server command handling (so that we can reuse them for redis protocol). This PR simply moves the client paths to the new `ReplicationCommandHandler`, a future PR will move the server paths too. --- changelog.d/7185.misc | 1 + synapse/app/admin_cmd.py | 12 -- synapse/app/generic_worker.py | 9 +- synapse/replication/tcp/__init__.py | 30 ++- synapse/replication/tcp/client.py | 179 +++--------------- synapse/replication/tcp/handler.py | 252 +++++++++++++++++++++++++ synapse/replication/tcp/protocol.py | 197 +++---------------- synapse/server.py | 8 +- synapse/server.pyi | 7 +- tests/replication/slave/storage/_base.py | 15 +- tests/replication/tcp/streams/_base.py | 38 ++-- tests/replication/tcp/streams/test_receipts.py | 1 - 12 files changed, 378 insertions(+), 371 deletions(-) create mode 100644 changelog.d/7185.misc create mode 100644 synapse/replication/tcp/handler.py (limited to 'synapse') diff --git a/changelog.d/7185.misc b/changelog.d/7185.misc new file mode 100644 index 0000000000..deb9ca7021 --- /dev/null +++ b/changelog.d/7185.misc @@ -0,0 +1 @@ +Move client command handling out of TCP protocol. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 1c7c6ec0c8..a37818fe9a 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -43,7 +43,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -79,17 +78,6 @@ class AdminCmdServer(HomeServer): def start_listening(self, listeners): pass - def build_tcp_replication(self): - return AdminCmdReplicationHandler(self) - - -class AdminCmdReplicationHandler(ReplicationClientHandler): - async def on_rdata(self, stream_name, token, rows): - pass - - def get_streams_to_replicate(self): - return {} - @defer.inlineCallbacks def export_data_command(hs, args): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 174bef360f..dcd0709a02 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -64,7 +64,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, @@ -603,7 +603,7 @@ class GenericWorkerServer(HomeServer): def remove_pusher(self, app_id, push_key, user_id): 
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - def build_tcp_replication(self): + def build_replication_data_handler(self): return GenericWorkerReplicationHandler(self) def build_presence_handler(self): @@ -613,7 +613,7 @@ class GenericWorkerServer(HomeServer): return GenericWorkerTyping(self) -class GenericWorkerReplicationHandler(ReplicationClientHandler): +class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore()) @@ -644,9 +644,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args - def get_currently_syncing_users(self): - return self.presence_handler.get_currently_syncing_users() - async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py index 81c2ea7ee9..523a1358d4 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py @@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst Structure of the module: - * client.py - the client classes used for workers to connect to master + * handler.py - the classes used to handle sending/receiving commands to + replication * command.py - the definitions of all the valid commands - * protocol.py - contains bot the client and server protocol implementations, - these should not be used directly - * resource.py - the server classes that accepts and handle client connections - * streams.py - the definitons of all the valid streams + * protocol.py - the TCP protocol classes + * resource.py - handles streaming stream updates to replications + * streams/ - the definitons of all the valid streams + +The general interaction of the classes are: + + +---------------------+ + | ReplicationStreamer | + +---------------------+ + | + v + +---------------------------+ +----------------------+ + | ReplicationCommandHandler |---->|ReplicationDataHandler| + +---------------------------+ +----------------------+ + | ^ + v | + +-------------+ + | Protocols | + | (TCP/redis) | + +-------------+ + +Where the ReplicationDataHandler (or subclasses) handles incoming stream +updates. """ diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e86d9805f1..700ae79158 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,26 +16,16 @@ """ import logging -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict -from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.protocol import ( - AbstractReplicationClientHandler, - ClientReplicationStreamProtocol, -) - -from .commands import ( - Command, - FederationAckCommand, - InvalidateCacheCommand, - RemoteServerUpCommand, - RemovePusherCommand, - UserIpCommand, - UserSyncCommand, -) +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.replication.tcp.handler import ReplicationCommandHandler logger = logging.getLogger(__name__) @@ -44,16 +34,20 @@ class ReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. 
- Accepts a handler that will be called when new data is available or data - is required. + Accepts a handler that is passed to `ClientReplicationStreamProtocol`. """ initialDelay = 0.1 maxDelay = 1 # Try at least once every N seconds - def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): + def __init__( + self, + hs: "HomeServer", + client_name: str, + command_handler: "ReplicationCommandHandler", + ): self.client_name = client_name - self.handler = handler + self.command_handler = command_handler self.server_name = hs.config.server_name self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class @@ -66,7 +60,11 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.hs, self.client_name, self.server_name, self._clock, self.handler, + self.hs, + self.client_name, + self.server_name, + self._clock, + self.command_handler, ) def clientConnectionLost(self, connector, reason): @@ -78,41 +76,17 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(AbstractReplicationClientHandler): - """A base handler that can be passed to the ReplicationClientFactory. +class ReplicationDataHandler: + """Handles incoming stream updates from replication. - By default proxies incoming replication data to the SlaveStore. + This instance notifies the slave data store about updates. Can be subclassed + to handle updates in additional ways. """ def __init__(self, store: BaseSlavedStore): self.store = store - # The current connection. None if we are currently (re)connecting - self.connection = None - - # Any pending commands to be sent once a new connection has been - # established - self.pending_commands = [] # type: List[Command] - - # Map from string -> deferred, to wake up when receiveing a SYNC with - # the given string. - # Used for tests. - self.awaiting_syncs = {} # type: Dict[str, defer.Deferred] - - # The factory used to create connections. - self.factory = None # type: Optional[ReplicationClientFactory] - - def start_replication(self, hs): - """Helper method to start a replication connection to the remote server - using TCP. - """ - client_name = hs.config.worker_name - self.factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self.factory) - - async def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name: str, token: int, rows: list): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -124,30 +98,8 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s -> %s", stream_name, token) self.store.process_replication_rows(stream_name, token, rows) - async def on_position(self, stream_name, token): - """Called when we get new position data. By default this just pokes - the slave store. - - Can be overriden in subclasses to handle more. 
- """ - self.store.process_replication_rows(stream_name, token, []) - - def on_sync(self, data): - """When we received a SYNC we wake up any deferreds that were waiting - for the sync with the given data. - - Used by tests. - """ - d = self.awaiting_syncs.pop(data, None) - if d: - d.callback(data) - - def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. @@ -163,85 +115,10 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): args["account_data"] = user_account_data elif room_account_data: args["account_data"] = room_account_data - return args - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. (Overriden by the synchrotron's only) - """ - return [] - - def send_command(self, cmd): - """Send a command to master (when we get establish a connection if we - don't have one already.) - """ - if self.connection: - self.connection.send_command(cmd) - else: - logger.warning("Queuing command as not connected: %r", cmd.NAME) - self.pending_commands.append(cmd) - - def send_federation_ack(self, token): - """Ack data for the federation stream. This allows the master to drop - data stored purely in memory. - """ - self.send_command(FederationAckCommand(token)) - - def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """Poke the master that a user has started/stopped syncing. - """ - self.send_command( - UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) - ) - - def send_remove_pusher(self, app_id, push_key, user_id): - """Poke the master to remove a pusher for a user - """ - cmd = RemovePusherCommand(app_id, push_key, user_id) - self.send_command(cmd) - - def send_invalidate_cache(self, cache_func, keys): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - - def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): - """Tell the master that the user made a request. - """ - cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) - self.send_command(cmd) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def await_sync(self, data): - """Returns a deferred that is resolved when we receive a SYNC command - with given data. - - [Not currently] used by tests. - """ - return self.awaiting_syncs.setdefault(data, defer.Deferred()) - - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - self.connection = connection - if connection: - for cmd in self.pending_commands: - connection.send_command(cmd) - self.pending_commands = [] - - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - logger.info("Finished connecting to server") + async def on_position(self, stream_name: str, token: int): + self.store.process_replication_rows(stream_name, token, []) - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. 
- if self.factory: - self.factory.resetDelay() + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py new file mode 100644 index 0000000000..12a1cfd6d1 --- /dev/null +++ b/synapse/replication/tcp/handler.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Callable, Dict, List, Optional, Set + +from prometheus_client import Counter + +from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.commands import ( + Command, + FederationAckCommand, + InvalidateCacheCommand, + PositionCommand, + RdataCommand, + RemoteServerUpCommand, + RemovePusherCommand, + SyncCommand, + UserIpCommand, + UserSyncCommand, +) +from synapse.replication.tcp.streams import STREAMS_MAP, Stream +from synapse.util.async_helpers import Linearizer + +logger = logging.getLogger(__name__) + + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] +) + + +class ReplicationCommandHandler: + """Handles incoming commands from replication as well as sending commands + back out to connections. + """ + + def __init__(self, hs): + self._replication_data_handler = hs.get_replication_data_handler() + self._presence_handler = hs.get_presence_handler() + + # Set of streams that we've caught up with. + self._streams_connected = set() # type: Set[str] + + self._streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + + self._position_linearizer = Linearizer("replication_position") + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. + self._pending_batches = {} # type: Dict[str, List[Any]] + + # The factory used to create connections. + self._factory = None # type: Optional[ReplicationClientFactory] + + # The current connection. None if we are currently (re)connecting + self._connection = None + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. + """ + client_name = hs.config.worker_name + self._factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) + + async def on_RDATA(self, cmd: RdataCommand): + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception: + logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) + raise + + if cmd.token is None or stream_name not in self._streams_connected: + # I.e. 
either this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token) or we're currently connecting so we queue up rows. + self._pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + await self.on_rdata(stream_name, cmd.token, rows) + + async def on_rdata(self, stream_name: str, token: int, rows: list): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name: name of the replication stream for this batch of rows + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + """ + logger.debug("Received rdata %s -> %s", stream_name, token) + await self._replication_data_handler.on_rdata(stream_name, token, rows) + + async def on_POSITION(self, cmd: PositionCommand): + stream = self._streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # We protect catching up with a linearizer in case the replication + # connection reconnects under us. + with await self._position_linearizer.queue(cmd.stream_name): + # We're about to go and catch up with the stream, so mark as connecting + # to stop RDATA being handled at the same time by removing stream from + # list of connected streams. We also clear any batched up RDATA from + # before we got the POSITION. + self._streams_connected.discard(cmd.stream_name) + self._pending_batches.clear() + + # Find where we previously streamed up to. + current_token = self._replication_data_handler.get_streams_to_replicate().get( + cmd.stream_name + ) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", + cmd.stream_name, + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + if updates: + await self.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) + + # Handle any RDATA that came in while we were catching up. + rows = self._pending_batches.pop(cmd.stream_name, []) + if rows: + await self._replication_data_handler.on_rdata( + cmd.stream_name, rows[-1].token, rows + ) + + self._streams_connected.add(cmd.stream_name) + + async def on_SYNC(self, cmd: SyncCommand): + pass + + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + """"Called when get a new REMOTE_SERVER_UP command.""" + self._replication_data_handler.on_remote_server_up(cmd.data) + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. + """ + return self._presence_handler.get_currently_syncing_users() + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + self._connection = connection + + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. 
+ """ + logger.info("Finished connecting to server") + + # We don't reset the delay any earlier as otherwise if there is a + # problem during start up we'll end up tight looping connecting to the + # server. + if self._factory: + self._factory.resetDelay() + + def send_command(self, cmd: Command): + """Send a command to master (when we get establish a connection if we + don't have one already.) + """ + if self._connection: + self._connection.send_command(cmd) + else: + logger.warning("Dropping command as not connected: %r", cmd.NAME) + + def send_federation_ack(self, token: int): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) + + def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func: Callable, keys: tuple): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func.__name__, keys) + self.send_command(cmd) + + def send_user_ip( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): + """Tell the master that the user made a request. + """ + cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) + self.send_command(cmd) + + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index dae246825f..f2a37f568e 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ -import abc import fcntl import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set +from typing import TYPE_CHECKING, DefaultDict, List from six import iteritems @@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string -MYPY = False -if MYPY: +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.server import HomeServer @@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer.lost_connection(self) -class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): - """ - The interface for the handler that should be passed to - ClientReplicationStreamProtocol - """ - - @abc.abstractmethod - async def on_rdata(self, stream_name, token, rows): - """Called to handle a batch of replication data with a given stream token. - - Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. 
- """ - raise NotImplementedError() - - @abc.abstractmethod - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - raise NotImplementedError() - - @abc.abstractmethod - def on_sync(self, data): - """Called when get a new SYNC command.""" - raise NotImplementedError() - - @abc.abstractmethod - async def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - raise NotImplementedError() - - @abc.abstractmethod - def get_streams_to_replicate(self): - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users.""" - raise NotImplementedError() - - @abc.abstractmethod - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - raise NotImplementedError() - - @abc.abstractmethod - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - raise NotImplementedError() - - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS @@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): client_name: str, server_name: str, clock: Clock, - handler: AbstractReplicationClientHandler, + command_handler: "ReplicationCommandHandler", ): BaseReplicationStreamProtocol.__init__(self, clock) @@ -558,20 +491,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.client_name = client_name self.server_name = server_name - self.handler = handler - - self.streams = { - stream.NAME: stream(hs) for stream in STREAMS_MAP.values() - } # type: Dict[str, Stream] - - # Set of stream names that have been subscribe to, but haven't yet - # caught up with. This is used to track when the client has been fully - # connected to the remote. - self.streams_connecting = set(STREAMS_MAP) # type: Set[str] - - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. - self.pending_batches = {} # type: Dict[str, List[Any]] + self.handler = command_handler def connectionMade(self): self.send_command(NameCommand(self.client_name)) @@ -589,89 +509,39 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) + self.handler.finished_connecting() - async def on_SERVER(self, cmd): - if cmd.data != self.server_name: - logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) - self.send_error("Wrong remote") - - async def on_RDATA(self, cmd): - stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() - - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception( - "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row - ) - raise - - if cmd.token is None or stream_name in self.streams_connecting: - # I.e. this is part of a batch of updates for this stream. 
Batch - # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) - rows.append(row) - await self.handler.on_rdata(stream_name, cmd.token, rows) - - async def on_POSITION(self, cmd: PositionCommand): - stream = self.streams.get(cmd.stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) - return - - # Find where we previously streamed up to. - current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name - ) - return - - # Fetch all updates between then and now. - limited = True - while limited: - updates, current_token, limited = await stream.get_updates_since( - current_token, cmd.token - ) - - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return - - if updates: - await self.handler.on_rdata( - cmd.stream_name, - current_token, - [stream.parse_row(update[1]) for update in updates], - ) + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. - # We've now caught up to position sent to us, notify handler. - await self.handler.on_position(cmd.stream_name, cmd.token) + Delegates to `command_handler.on_`, which must return an + awaitable. - self.streams_connecting.discard(cmd.stream_name) - if not self.streams_connecting: - self.handler.finished_connecting() + Args: + cmd: received command + """ + handled = False - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True - # Handle any RDATA that came in while we were catching up. - rows = self.pending_batches.pop(cmd.stream_name, []) - if rows: - await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) + # Then call out to the handler. 
+ cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True - async def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) + if not handled: + logger.warning("Unhandled command: %r", cmd) - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.handler.on_remote_server_up(cmd.data) + async def on_SERVER(self, cmd): + if cmd.data != self.server_name: + logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) + self.send_error("Wrong remote") def replicate(self): """Send the subscription request to the server @@ -768,8 +638,3 @@ tcp_outbound_commands = LaterGauge( for k, count in iteritems(p.outbound_commands_counter) }, ) - -# number of updates received for each RDATA stream -inbound_rdata_count = Counter( - "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] -) diff --git a/synapse/server.py b/synapse/server.py index 9228e1c892..9d273c980c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -87,6 +87,8 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, @@ -206,6 +208,7 @@ class HomeServer(object): "password_policy_handler", "storage", "replication_streamer", + "replication_data_handler", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -468,7 +471,7 @@ class HomeServer(object): return ReadMarkerHandler(self) def build_tcp_replication(self): - raise NotImplementedError() + return ReplicationCommandHandler(self) def build_action_generator(self): return ActionGenerator(self) @@ -562,6 +565,9 @@ class HomeServer(object): def build_replication_streamer(self) -> ReplicationStreamer: return ReplicationStreamer(self) + def build_replication_data_handler(self): + return ReplicationDataHandler(self.get_datastore()) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9d1dfa71e7..9013e9bac9 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -19,6 +19,7 @@ import synapse.handlers.set_password import synapse.http.client import synapse.notifier import synapse.replication.tcp.client +import synapse.replication.tcp.handler import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -106,7 +107,11 @@ class HomeServer(object): pass def get_tcp_replication( self, - ) -> synapse.replication.tcp.client.ReplicationClientHandler: + ) -> synapse.replication.tcp.handler.ReplicationCommandHandler: + pass + def get_replication_data_handler( + self, + ) -> synapse.replication.tcp.client.ReplicationDataHandler: pass def get_federation_registry( self, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 2a1e7c7166..8902a5ab69 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -17,8 +17,9 @@ from mock import Mock, NonCallableMock from synapse.replication.tcp.client import ( ReplicationClientFactory, - 
ReplicationClientHandler, + ReplicationDataHandler, ) +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.storage.database import make_conn @@ -51,15 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + self.streamer = hs.get_replication_streamer() - handler_factory = Mock() - self.replication_handler = ReplicationClientHandler(self.slaved_store) - self.replication_handler.factory = handler_factory + # We now do some gut wrenching so that we have a client that is based + # off of the slave store rather than the main store. + self.replication_handler = ReplicationCommandHandler(self.hs) + self.replication_handler._replication_data_handler = ReplicationDataHandler( + self.slaved_store + ) client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) + client_factory.handler = self.replication_handler server = server_factory.buildProtocol(None) client = client_factory.buildProtocol(None) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index a755fe2879..32238fe79a 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -15,7 +15,7 @@ from mock import Mock -from synapse.replication.tcp.commands import ReplicateCommand +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -26,15 +26,20 @@ from tests.server import FakeTransport class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def make_homeserver(self, reactor, clock): + self.test_handler = Mock(wraps=TestReplicationDataHandler()) + return self.setup_test_homeserver(replication_data_handler=self.test_handler) + def prepare(self, reactor, clock, hs): # build a replication server - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() self.server = server_factory.buildProtocol(None) - self.test_handler = Mock(wraps=TestReplicationClientHandler()) + repl_handler = ReplicationCommandHandler(hs) + repl_handler.handler = self.test_handler self.client = ClientReplicationStreamProtocol( - hs, "client", "test", clock, self.test_handler, + hs, "client", "test", clock, repl_handler, ) self._client_transport = None @@ -69,13 +74,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self): - """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand()) - -class TestReplicationClientHandler(object): - """Drop-in for ReplicationClientHandler which just collects RDATA rows""" +class TestReplicationDataHandler: + """Drop-in for ReplicationDataHandler which just collects RDATA rows""" def __init__(self): self.streams = set() @@ -88,18 +89,9 @@ class TestReplicationClientHandler(object): positions[stream] = max(token, positions.get(stream, 0)) return positions - def get_currently_syncing_users(self): - return [] - - def update_connection(self, connection): - pass - - def 
finished_connecting(self): - pass - - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - async def on_rdata(self, stream_name, token, rows): for r in rows: self._received_rdata_rows.append((stream_name, token, r)) + + async def on_position(self, stream_name, token): + pass diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 0ec0825a0e..a0206f7363 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -24,7 +24,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.reconnect() # make the client subscribe to the receipts stream - self.replicate_stream() self.test_handler.streams.add("receipts") # tell the master to send a new receipt -- cgit 1.4.1 From b21000a44fa8b6f5d28a2089033f76767dff868b Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 6 Apr 2020 12:35:30 +0100 Subject: Improve error responses when a remote server doesn't allow you to access its public rooms list (#6899) --- changelog.d/6899.bugfix | 1 + synapse/handlers/room_list.py | 23 ++++++++++++----------- synapse/rest/client/v1/room.py | 33 ++++++++++++++++++++------------- 3 files changed, 33 insertions(+), 24 deletions(-) create mode 100644 changelog.d/6899.bugfix (limited to 'synapse') diff --git a/changelog.d/6899.bugfix b/changelog.d/6899.bugfix new file mode 100644 index 0000000000..efa8a40b1f --- /dev/null +++ b/changelog.d/6899.bugfix @@ -0,0 +1 @@ +Improve error responses when accessing remote public room lists. \ No newline at end of file diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 0b7d3da680..59c9906b31 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Any, Dict, Optional from six import iteritems @@ -105,22 +106,22 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def _get_public_room_list( self, - limit=None, - since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False, - ): + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[Dict] = None, + network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + from_federation: bool = False, + ) -> Dict[str, Any]: """Generate a public room list. Args: - limit (int|None): Maximum amount of rooms to return. - since_token (str|None) - search_filter (dict|None): Dictionary to filter rooms by. - network_tuple (ThirdPartyInstanceID): Which public list to use. + limit: Maximum amount of rooms to return. + since_token: + search_filter: Dictionary to filter rooms by. + network_tuple: Which public list to use. This can be (None, None) to indicate the main list, or a particular appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. - from_federation (bool): Whether this request originated from a + from_federation: Whether this request originated from a federating server or a client. Used for room filtering. 
""" diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index bffd43de5f..6b5830cc3f 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, + HttpResponseException, InvalidClientCredentialsError, SynapseError, ) @@ -364,10 +365,13 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = None handler = self.hs.get_room_list_handler() - if server: - data = await handler.get_remote_public_room_list( - server, limit=limit, since_token=since_token - ) + if server and server != self.hs.config.server_name: + try: + data = await handler.get_remote_public_room_list( + server, limit=limit, since_token=since_token + ) + except HttpResponseException as e: + raise e.to_synapse_error() else: data = await handler.get_local_public_room_list( limit=limit, since_token=since_token @@ -404,15 +408,18 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = None handler = self.hs.get_room_list_handler() - if server: - data = await handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) + if server and server != self.hs.config.server_name: + try: + data = await handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + except HttpResponseException as e: + raise e.to_synapse_error() else: data = await handler.get_local_public_room_list( limit=limit, -- cgit 1.4.1 From 4b0f00ad0c6bbe153f82b95980a2ba16238b4449 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 6 Apr 2020 12:40:34 +0100 Subject: Remove stream before/after debug log lines (#7207) --- changelog.d/7207.misc | 1 + synapse/storage/data_stores/main/stream.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) create mode 100644 changelog.d/7207.misc (limited to 'synapse') diff --git a/changelog.d/7207.misc b/changelog.d/7207.misc new file mode 100644 index 0000000000..4f9b6a1089 --- /dev/null +++ b/changelog.d/7207.misc @@ -0,0 +1 @@ +Remove some extraneous debugging log lines. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index ada5cce6c2..e89f0bffb5 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -481,11 +481,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id, limit, end_token ) - logger.debug("stream before") events = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) - logger.debug("stream after") self._set_before_and_after(events, rows) -- cgit 1.4.1 From 82498ee9019747eb86ed753e08fac0990d4ac8b9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Apr 2020 10:51:07 +0100 Subject: Move server command handling out of TCP protocol (#7187) This completes the merging of server and client command processing. 
--- changelog.d/7187.misc | 1 + synapse/replication/tcp/handler.py | 177 ++++++++++++++++++++++++++++++++---- synapse/replication/tcp/protocol.py | 165 +++++++++++---------------------- synapse/replication/tcp/resource.py | 163 ++++++--------------------------- 4 files changed, 237 insertions(+), 269 deletions(-) create mode 100644 changelog.d/7187.misc (limited to 'synapse') diff --git a/changelog.d/7187.misc b/changelog.d/7187.misc new file mode 100644 index 0000000000..60d68ae877 --- /dev/null +++ b/changelog.d/7187.misc @@ -0,0 +1 @@ +Move server command handling out of TCP protocol. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 12a1cfd6d1..8ec0119697 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -19,8 +19,10 @@ from typing import Any, Callable, Dict, List, Optional, Set from prometheus_client import Counter +from synapse.metrics import LaterGauge from synapse.replication.tcp.client import ReplicationClientFactory from synapse.replication.tcp.commands import ( + ClearUserSyncsCommand, Command, FederationAckCommand, InvalidateCacheCommand, @@ -28,10 +30,12 @@ from synapse.replication.tcp.commands import ( RdataCommand, RemoteServerUpCommand, RemovePusherCommand, + ReplicateCommand, SyncCommand, UserIpCommand, UserSyncCommand, ) +from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.util.async_helpers import Linearizer @@ -42,6 +46,13 @@ logger = logging.getLogger(__name__) inbound_rdata_count = Counter( "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] ) +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") class ReplicationCommandHandler: @@ -52,6 +63,10 @@ class ReplicationCommandHandler: def __init__(self, hs): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() + self._store = hs.get_datastore() + self._notifier = hs.get_notifier() + self._clock = hs.get_clock() + self._instance_id = hs.get_instance_id() # Set of streams that we've caught up with. self._streams_connected = set() # type: Set[str] @@ -69,8 +84,26 @@ class ReplicationCommandHandler: # The factory used to create connections. self._factory = None # type: Optional[ReplicationClientFactory] - # The current connection. None if we are currently (re)connecting - self._connection = None + # The currently connected connections. 
+ self._connections = [] # type: List[AbstractConnection] + + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self._connections), + ) + + self._is_master = hs.config.worker_app is None + + self._federation_sender = None + if self._is_master and not hs.config.send_federation: + self._federation_sender = hs.get_federation_sender() + + self._server_notices_sender = None + if self._is_master: + self._server_notices_sender = hs.get_server_notices_sender() + self._notifier.add_remote_server_up_callback(self.send_remote_server_up) def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -82,6 +115,70 @@ class ReplicationCommandHandler: port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self._factory) + async def on_REPLICATE(self, cmd: ReplicateCommand): + # We only want to announce positions by the writer of the streams. + # Currently this is just the master process. + if not self._is_master: + return + + for stream_name, stream in self._streams.items(): + current_token = stream.current_token() + self.send_command(PositionCommand(stream_name, current_token)) + + async def on_USER_SYNC(self, cmd: UserSyncCommand): + user_sync_counter.inc() + + if self._is_master: + await self._presence_handler.update_external_syncs_row( + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + ) + + async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + if self._is_master: + await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + + async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + federation_ack_counter.inc() + + if self._federation_sender: + self._federation_sender.federation_ack(cmd.token) + + async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + remove_pusher_counter.inc() + + if self._is_master: + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) + + self._notifier.on_new_replication_data() + + async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + invalidate_cache_counter.inc() + + if self._is_master: + # We invalidate the cache locally, but then also stream that to other + # workers. + await self._store.invalidate_cache_and_stream( + cmd.cache_func, tuple(cmd.keys) + ) + + async def on_USER_IP(self, cmd: UserIpCommand): + user_ip_cache_counter.inc() + + if self._is_master: + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) + + if self._server_notices_sender: + await self._server_notices_sender.on_user_ip(cmd.user_id) + async def on_RDATA(self, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -174,6 +271,9 @@ class ReplicationCommandHandler: """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) + if self._is_master: + self._notifier.notify_remote_server_up(cmd.data) + def get_currently_syncing_users(self): """Get the list of currently syncing users (if any). This is called when a connection has been established and we need to send the @@ -181,29 +281,63 @@ class ReplicationCommandHandler: """ return self._presence_handler.get_currently_syncing_users() - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). 
+ def new_connection(self, connection: AbstractConnection): + """Called when we have a new connection. """ - self._connection = connection + self._connections.append(connection) + + # If we are connected to replication as a client (rather than a server) + # we need to reset the reconnection delay on the client factory (which + # is used to do exponential back off when the connection drops). + # + # Ideally we would reset the delay when we've "fully established" the + # connection (for some definition thereof) to stop us from tightlooping + # on reconnection if something fails after this point and we drop the + # connection. Unfortunately, we don't really have a better definition of + # "fully established" than the connection being established. + if self._factory: + self._factory.resetDelay() + + # Tell the server if we have any users currently syncing (should only + # happen on synchrotrons) + currently_syncing = self.get_currently_syncing_users() + now = self._clock.time_msec() + for user_id in currently_syncing: + connection.send_command( + UserSyncCommand(self._instance_id, user_id, True, now) + ) - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. + def lost_connection(self, connection: AbstractConnection): + """Called when a connection is closed/lost. """ - logger.info("Finished connecting to server") + try: + self._connections.remove(connection) + except ValueError: + pass - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. - if self._factory: - self._factory.resetDelay() + def connected(self) -> bool: + """Do we have any replication connections open? + + Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected. + """ + return bool(self._connections) def send_command(self, cmd: Command): - """Send a command to master (when we get establish a connection if we - don't have one already.) + """Send a command to all connected connections. """ - if self._connection: - self._connection.send_command(cmd) + if self._connections: + for connection in self._connections: + try: + connection.send_command(cmd) + except Exception: + # We probably want to catch some types of exceptions here + # and log them as warnings (e.g. connection gone), but I + # can't find what those exception types they would be. + logger.exception( + "Failed to write command %s to connection %s", + cmd.NAME, + connection, + ) else: logger.warning("Dropping command as not connected: %r", cmd.NAME) @@ -250,3 +384,10 @@ class ReplicationCommandHandler: def send_remote_server_up(self, server: str): self.send_command(RemoteServerUpCommand(server)) + + def stream_update(self, stream_name: str, token: str, data: Any): + """Called when a new update is available to stream to clients. 
+ + We need to check if the client is interested in the stream or not + """ + self.send_command(RdataCommand(stream_name, token, data)) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index f2a37f568e..9aabb9c586 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,6 +46,7 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ +import abc import fcntl import logging import struct @@ -69,13 +70,8 @@ from synapse.replication.tcp.commands import ( ErrorCommand, NameCommand, PingCommand, - PositionCommand, - RdataCommand, - RemoteServerUpCommand, ReplicateCommand, ServerCommand, - SyncCommand, - UserSyncCommand, ) from synapse.types import Collection from synapse.util import Clock @@ -118,7 +114,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): are only sent by the server. On receiving a new command it calls `on_` with the parsed - command. + command before delegating to `ReplicationCommandHandler.on_`. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -134,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): max_line_buffer = 10000 - def __init__(self, clock): + def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"): self.clock = clock + self.command_handler = handler self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 @@ -175,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # can time us out. self.send_command(PingCommand(self.clock.time_msec())) + self.command_handler.new_connection(self) + def send_ping(self): """Periodically sends a ping and checks if we should close the connection due to the other side timing out. @@ -243,13 +242,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_, which should return an awaitable. + First calls `self.on_` if it exists, then calls + `self.command_handler.on_` if it exists. This allows for + protocol level handling of commands (e.g. PINGs), before delegating to + the handler. Args: cmd: received command """ - handler = getattr(self, "on_%s" % (cmd.NAME,)) - await handler(cmd) + handled = False + + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. 
+ cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -378,6 +395,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.state = ConnectionStates.CLOSED self.pending_commands = [] + self.command_handler.lost_connection(self) + if self.transport: self.transport.unregisterProducer() @@ -404,74 +423,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - def __init__(self, server_name, clock, streamer): - BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + def __init__( + self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + ): + super().__init__(clock, handler) self.server_name = server_name - self.streamer = streamer def connectionMade(self): self.send_command(ServerCommand(self.server_name)) - BaseReplicationStreamProtocol.connectionMade(self) - self.streamer.new_connection(self) + super().connectionMade() async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - async def on_USER_SYNC(self, cmd): - await self.streamer.on_user_sync( - cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms - ) - - async def on_CLEAR_USER_SYNC(self, cmd): - await self.streamer.on_clear_user_syncs(cmd.instance_id) - - async def on_REPLICATE(self, cmd): - # Subscribe to all streams we're publishing to. - for stream_name in self.streamer.streams_by_name: - current_token = self.streamer.get_stream_token(stream_name) - self.send_command(PositionCommand(stream_name, current_token)) - - async def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) - - async def on_REMOVE_PUSHER(self, cmd): - await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - - async def on_INVALIDATE_CACHE(self, cmd): - await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.streamer.on_remote_server_up(cmd.data) - - async def on_USER_IP(self, cmd): - self.streamer.on_user_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) - - def stream_update(self, stream_name, token, data): - """Called when a new update is available to stream to clients. 
- - We need to check if the client is interested in the stream or not - """ - self.send_command(RdataCommand(stream_name, token, data)) - - def send_sync(self, data): - self.send_command(SyncCommand(data)) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.streamer.lost_connection(self) - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS @@ -485,59 +451,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): clock: Clock, command_handler: "ReplicationCommandHandler", ): - BaseReplicationStreamProtocol.__init__(self, clock) - - self.instance_id = hs.get_instance_id() + super().__init__(clock, command_handler) self.client_name = client_name self.server_name = server_name - self.handler = command_handler def connectionMade(self): self.send_command(NameCommand(self.client_name)) - BaseReplicationStreamProtocol.connectionMade(self) + super().connectionMade() # Once we've connected subscribe to the necessary streams self.replicate() - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.handler.get_currently_syncing_users() - now = self.clock.time_msec() - for user_id in currently_syncing: - self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) - - # We've now finished connecting to so inform the client handler - self.handler.update_connection(self) - self.handler.finished_connecting() - - async def handle_command(self, cmd: Command): - """Handle a command we have received over the replication stream. - - Delegates to `command_handler.on_`, which must return an - awaitable. - - Args: - cmd: received command - """ - handled = False - - # First call any command handlers on this instance. These are for TCP - # specific handling. - cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - - # Then call out to the handler. - cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - - if not handled: - logger.warning("Unhandled command: %r", cmd) - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -550,9 +475,21 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ReplicateCommand()) - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.handler.update_connection(None) + +class AbstractConnection(abc.ABC): + """An interface for replication connections. + """ + + @abc.abstractmethod + def send_command(self, cmd: Command): + """Send the command down the connection + """ + pass + + +# This tells python that `BaseReplicationStreamProtocol` implements the +# interface. 
+AbstractConnection.register(BaseReplicationStreamProtocol) # The following simply registers metrics for the replication connections diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 30021ee309..b2d6baa2a2 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import Any, Dict, List +from typing import Dict from six import itervalues @@ -25,24 +25,14 @@ from prometheus_client import Counter from twisted.internet.protocol import Factory -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.metrics import Measure, measure_func - -from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP, Stream -from .streams.federation import FederationStream +from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol +from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream +from synapse.util.metrics import Measure stream_updates_counter = Counter( "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] ) -user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") -federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") -remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) -user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -52,13 +42,23 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = hs.get_replication_streamer() + self.command_handler = hs.get_tcp_replication() self.clock = hs.get_clock() self.server_name = hs.config.server_name + # If we've created a `ReplicationStreamProtocolFactory` then we're + # almost certainly registering a replication listener, so let's ensure + # that we've started a `ReplicationStreamer` instance to actually push + # data. + # + # (This is a bit of a weird place to do this, but the alternatives such + # as putting this in `HomeServer.setup()`, requires either passing the + # listener config again or always starting a `ReplicationStreamer`.) + hs.get_replication_streamer() + def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, self.clock, self.streamer + self.server_name, self.clock, self.command_handler ) @@ -78,16 +78,6 @@ class ReplicationStreamer(object): self._replication_torture_level = hs.config.replication_torture_level - # Current connections. - self.connections = [] # type: List[ServerReplicationStreamProtocol] - - LaterGauge( - "synapse_replication_tcp_resource_total_connections", - "", - [], - lambda: len(self.connections), - ) - # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. 
@@ -104,18 +94,12 @@ class ReplicationStreamer(object): self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) - self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False self.pending_updates = False - hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) - - def on_shutdown(self): - # close all connections on shutdown - for conn in self.connections: - conn.send_error("server shutting down") + self.command_handler = hs.get_tcp_replication() def get_streams(self) -> Dict[str, Stream]: """Get a mapp from stream name to stream instance. @@ -129,7 +113,7 @@ class ReplicationStreamer(object): This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed """ - if not self.connections: + if not self.command_handler.connected(): # Don't bother if nothing is listening. We still need to advance # the stream tokens otherwise they'll fall beihind forever for stream in self.streams: @@ -186,9 +170,7 @@ class ReplicationStreamer(object): raise logger.debug( - "Sending %d updates to %d connections", - len(updates), - len(self.connections), + "Sending %d updates", len(updates), ) if updates: @@ -204,112 +186,19 @@ class ReplicationStreamer(object): # token. See RdataCommand for more details. batched_updates = _batch_updates(updates) - for conn in self.connections: - for token, row in batched_updates: - try: - conn.stream_update(stream.NAME, token, row) - except Exception: - logger.exception("Failed to replicate") + for token, row in batched_updates: + try: + self.command_handler.stream_update( + stream.NAME, token, row + ) + except Exception: + logger.exception("Failed to replicate") logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False self.is_looping = False - def get_stream_token(self, stream_name): - """For a given stream get all updates since token. This is called when - a client first subscribes to a stream. - """ - stream = self.streams_by_name.get(stream_name, None) - if not stream: - raise Exception("unknown stream %s", stream_name) - - return stream.current_token() - - @measure_func("repl.federation_ack") - def federation_ack(self, token): - """We've received an ack for federation stream from a client. - """ - federation_ack_counter.inc() - if self.federation_sender: - self.federation_sender.federation_ack(token) - - @measure_func("repl.on_user_sync") - async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """A client has started/stopped syncing on a worker. - """ - user_sync_counter.inc() - await self.presence_handler.update_external_syncs_row( - instance_id, user_id, is_syncing, last_sync_ms - ) - - async def on_clear_user_syncs(self, instance_id): - """A replication client wants us to drop all their UserSync data. 
- """ - await self.presence_handler.update_external_syncs_clear(instance_id) - - @measure_func("repl.on_remove_pusher") - async def on_remove_pusher(self, app_id, push_key, user_id): - """A client has asked us to remove a pusher - """ - remove_pusher_counter.inc() - await self.store.delete_pusher_by_app_id_pushkey_user_id( - app_id=app_id, pushkey=push_key, user_id=user_id - ) - - self.notifier.on_new_replication_data() - - @measure_func("repl.on_invalidate_cache") - async def on_invalidate_cache(self, cache_func: str, keys: List[Any]): - """The client has asked us to invalidate a cache - """ - invalidate_cache_counter.inc() - - # We invalidate the cache locally, but then also stream that to other - # workers. - await self.store.invalidate_cache_and_stream(cache_func, tuple(keys)) - - @measure_func("repl.on_user_ip") - async def on_user_ip( - self, user_id, access_token, ip, user_agent, device_id, last_seen - ): - """The client saw a user request - """ - user_ip_cache_counter.inc() - await self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen - ) - await self._server_notices_sender.on_user_ip(user_id) - - @measure_func("repl.on_remote_server_up") - def on_remote_server_up(self, server: str): - self.notifier.notify_remote_server_up(server) - - def send_remote_server_up(self, server: str): - for conn in self.connections: - conn.send_remote_server_up(server) - - def send_sync_to_all_connections(self, data): - """Sends a SYNC command to all clients. - - Used in tests. - """ - for conn in self.connections: - conn.send_sync(data) - - def new_connection(self, connection): - """A new client connection has been established - """ - self.connections.append(connection) - - def lost_connection(self, connection): - """A client connection has been lost - """ - try: - self.connections.remove(connection) - except ValueError: - pass - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to -- cgit 1.4.1 From ce72355d7f67a986d60a7d86489b1b40f93fb152 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Apr 2020 11:01:04 +0100 Subject: Fix race in replication (#7226) Fixes a race between handling `POSITION` and `RDATA` commands. We do this by simply linearizing handling of them. --- changelog.d/7226.misc | 1 + synapse/replication/tcp/handler.py | 73 +++++++++++++++++---------- synapse/replication/tcp/streams/_base.py | 3 +- synapse/storage/data_stores/main/push_rule.py | 40 +++++++-------- 4 files changed, 68 insertions(+), 49 deletions(-) create mode 100644 changelog.d/7226.misc (limited to 'synapse') diff --git a/changelog.d/7226.misc b/changelog.d/7226.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7226.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 8ec0119697..dd71d1bc34 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -189,16 +189,34 @@ class ReplicationCommandHandler: logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) raise - if cmd.token is None or stream_name not in self._streams_connected: - # I.e. either this is part of a batch of updates for this stream (in - # which case batch until we get an update for the stream with a non - # None token) or we're currently connecting so we queue up rows. 
- self._pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self._pending_batches.pop(stream_name, []) - rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + # We linearize here for two reasons: + # 1. so we don't try and concurrently handle multiple rows for the + # same stream, and + # 2. so we don't race with getting a POSITION command and fetching + # missing RDATA. + with await self._position_linearizer.queue(cmd.stream_name): + if stream_name not in self._streams_connected: + # If the stream isn't marked as connected then we haven't seen a + # `POSITION` command yet, and so we may have missed some rows. + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.warning( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). + self._pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + await self.on_rdata(stream_name, cmd.token, rows) async def on_rdata(self, stream_name: str, token: int, rows: list): """Called to handle a batch of replication data with a given stream token. @@ -221,12 +239,13 @@ class ReplicationCommandHandler: # We protect catching up with a linearizer in case the replication # connection reconnects under us. with await self._position_linearizer.queue(cmd.stream_name): - # We're about to go and catch up with the stream, so mark as connecting - # to stop RDATA being handled at the same time by removing stream from - # list of connected streams. We also clear any batched up RDATA from - # before we got the POSITION. + # We're about to go and catch up with the stream, so remove from set + # of connected streams. self._streams_connected.discard(cmd.stream_name) - self._pending_batches.clear() + + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(cmd.stream_name, []) # Find where we previously streamed up to. current_token = self._replication_data_handler.get_streams_to_replicate().get( @@ -239,12 +258,17 @@ class ReplicationCommandHandler: ) return - # Fetch all updates between then and now. - limited = True - while limited: - updates, current_token, limited = await stream.get_updates_since( - current_token, cmd.token - ) + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. + missing_updates = cmd.token != current_token + while missing_updates: + ( + updates, + current_token, + missing_updates, + ) = await stream.get_updates_since(current_token, cmd.token) + if updates: await self.on_rdata( cmd.stream_name, @@ -255,13 +279,6 @@ class ReplicationCommandHandler: # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) - # Handle any RDATA that came in while we were catching up. 
- rows = self._pending_batches.pop(cmd.stream_name, []) - if rows: - await self._replication_data_handler.on_rdata( - cmd.stream_name, rows[-1].token, rows - ) - self._streams_connected.add(cmd.stream_name) async def on_SYNC(self, cmd: SyncCommand): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index c14dff6c64..f56a0fd4b5 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -168,12 +168,13 @@ def make_http_update_function( async def update_function( from_token: int, upto_token: int, limit: int ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - return await client( + result = await client( stream_name=stream_name, from_token=from_token, upto_token=upto_token, limit=limit, ) + return result["updates"], result["upto_token"], result["limited"] return update_function diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 46f9bda773..b3faafa0a4 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -334,6 +334,26 @@ class PushRulesWorkerStore( results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled return results + def get_all_push_rule_updates(self, last_id, current_id, limit): + """Get all the push rules changes that have happend on the server""" + if last_id == current_id: + return defer.succeed([]) + + def get_all_push_rule_updates_txn(txn): + sql = ( + "SELECT stream_id, event_stream_ordering, user_id, rule_id," + " op, priority_class, priority, conditions, actions" + " FROM push_rules_stream" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_push_rule_updates", get_all_push_rule_updates_txn + ) + class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks @@ -685,26 +705,6 @@ class PushRuleStore(PushRulesWorkerStore): self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) - def get_all_push_rule_updates(self, last_id, current_id, limit): - """Get all the push rules changes that have happend on the server""" - if last_id == current_id: - return defer.succeed([]) - - def get_all_push_rule_updates_txn(txn): - sql = ( - "SELECT stream_id, event_stream_ordering, user_id, rule_id," - " op, priority_class, priority, conditions, actions" - " FROM push_rules_stream" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_push_rule_updates", get_all_push_rule_updates_txn - ) - def get_push_rules_stream_token(self): """Get the position of the push rules stream. Returns a pair of a stream id for the push_rules stream and the -- cgit 1.4.1 From 2e105c156be036ebd408b8fbb87b5c218574726e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Apr 2020 15:19:19 +0100 Subject: Remove sent outbound device list pokes from the database (#7192) They just get in the way. 
--- changelog.d/7192.misc | 1 + synapse/storage/data_stores/main/devices.py | 4 ++-- .../schema/delta/57/remove_sent_outbound_pokes.sql | 21 +++++++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7192.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql (limited to 'synapse') diff --git a/changelog.d/7192.misc b/changelog.d/7192.misc new file mode 100644 index 0000000000..e401e36399 --- /dev/null +++ b/changelog.d/7192.misc @@ -0,0 +1 @@ +Remove sent outbound device list pokes from the database. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index dd3561e9b2..4c5bea4a5c 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -227,11 +227,11 @@ class DeviceWorkerStore(SQLBaseStore): # get the list of device updates that need to be sent sql = """ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes - WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? + WHERE destination = ? AND ? < stream_id AND stream_id <= ? ORDER BY stream_id LIMIT ? """ - txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit)) + txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) return list(txn) diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql new file mode 100644 index 0000000000..133d80af35 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql @@ -0,0 +1,21 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- we no longer keep sent outbound device pokes in the db; clear them out +-- so that we don't have to worry about them. +-- +-- This is a sequence scan, but it doesn't take too long. + +DELETE FROM device_lists_outbound_pokes WHERE sent; -- cgit 1.4.1 From ec5ac8e2b129f645a06f441143c2dcd2fb1c7037 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 7 Apr 2020 18:31:50 +0200 Subject: Fix typo in the login fallback javascript (#7235) * Fix typo in the login fallback javascript * Changelog --- changelog.d/7235.bugfix | 1 + synapse/static/client/login/js/login.js | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7235.bugfix (limited to 'synapse') diff --git a/changelog.d/7235.bugfix b/changelog.d/7235.bugfix new file mode 100644 index 0000000000..d185efe537 --- /dev/null +++ b/changelog.d/7235.bugfix @@ -0,0 +1 @@ +Fix a bug causing the login fallback to not display the SSO login form. 
diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index debe464371..5ca0317755 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -62,7 +62,7 @@ var show_login = function(inhibit_redirect) { } // Otherwise, show the SSO form - $("#sso_form").show(); + $("#sso_flow").show(); } if (matrixLogin.serverAcceptsPassword) { -- cgit 1.4.1 From 6a519a0ca06f42c35d5c88f3fa5f62cfd1553905 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 7 Apr 2020 16:59:40 +0100 Subject: Remove vestigal references to SYNC replication command We've ripped pretty much all of this out: let's remove the remains. --- synapse/replication/tcp/commands.py | 10 ---------- synapse/replication/tcp/handler.py | 4 ---- 2 files changed, 14 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index e4eec643f7..a07a01278a 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -289,14 +289,6 @@ class FederationAckCommand(Command): return str(self.token) -class SyncCommand(Command): - """Used for testing. The client protocol implementation allows waiting - on a SYNC command with a specified data. - """ - - NAME = "SYNC" - - class RemovePusherCommand(Command): """Sent by the client to request the master remove the given pusher. @@ -419,7 +411,6 @@ _COMMANDS = ( ReplicateCommand, UserSyncCommand, FederationAckCommand, - SyncCommand, RemovePusherCommand, InvalidateCacheCommand, UserIpCommand, @@ -437,7 +428,6 @@ VALID_SERVER_COMMANDS = ( PositionCommand.NAME, ErrorCommand.NAME, PingCommand.NAME, - SyncCommand.NAME, RemoteServerUpCommand.NAME, ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index dd71d1bc34..2f5a299141 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -31,7 +31,6 @@ from synapse.replication.tcp.commands import ( RemoteServerUpCommand, RemovePusherCommand, ReplicateCommand, - SyncCommand, UserIpCommand, UserSyncCommand, ) @@ -281,9 +280,6 @@ class ReplicationCommandHandler: self._streams_connected.add(cmd.stream_name) - async def on_SYNC(self, cmd: SyncCommand): - pass - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) -- cgit 1.4.1 From c3e4b4edb270ad31321207125007ff41344510ed Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 7 Apr 2020 17:40:22 +0100 Subject: Fix warnings about not calling superclass constructor Separate `SimpleCommand` from `Command`, so that things which don't want to use the `data` property don't have to, and thus fix the warnings PyCharm was giving me about not calling `__init__` in the base class. --- synapse/replication/tcp/commands.py | 39 +++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 15 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index a07a01278a..5ec89d0fb8 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -17,7 +17,7 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are allowed to be sent by which side. 
""" - +import abc import logging import platform from typing import Tuple, Type @@ -34,34 +34,29 @@ else: logger = logging.getLogger(__name__) -class Command(object): +class Command(metaclass=abc.ABCMeta): """The base command class. All subclasses must set the NAME variable which equates to the name of the command on the wire. A full command line on the wire is constructed from `NAME + " " + to_line()` - - The default implementation creates a command of form ` ` """ NAME = None # type: str - def __init__(self, data): - self.data = data - @classmethod + @abc.abstractmethod def from_line(cls, line): """Deserialises a line from the wire into this command. `line` does not include the command. """ - return cls(line) - def to_line(self): + @abc.abstractmethod + def to_line(self) -> str: """Serialises the comamnd for the wire. Does not include the command prefix. """ - return self.data def get_logcontext_id(self): """Get a suitable string for the logcontext when processing this command""" @@ -70,7 +65,21 @@ class Command(object): return self.NAME -class ServerCommand(Command): +class _SimpleCommand(Command): + """An implementation of Command whose argument is just a 'data' string.""" + + def __init__(self, data): + self.data = data + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self) -> str: + return self.data + + +class ServerCommand(_SimpleCommand): """Sent by the server on new connection and includes the server_name. Format:: @@ -155,7 +164,7 @@ class PositionCommand(Command): return " ".join((self.stream_name, str(self.token))) -class ErrorCommand(Command): +class ErrorCommand(_SimpleCommand): """Sent by either side if there was an ERROR. The data is a string describing the error. """ @@ -163,14 +172,14 @@ class ErrorCommand(Command): NAME = "ERROR" -class PingCommand(Command): +class PingCommand(_SimpleCommand): """Sent by either side as a keep alive. The data is arbitary (often timestamp) """ NAME = "PING" -class NameCommand(Command): +class NameCommand(_SimpleCommand): """Sent by client to inform the server of the client's identity. The data is the name """ @@ -387,7 +396,7 @@ class UserIpCommand(Command): ) -class RemoteServerUpCommand(Command): +class RemoteServerUpCommand(_SimpleCommand): """Sent when a worker has detected that a remote server is no longer "down" and retry timings should be reset. -- cgit 1.4.1 From e13c6c7a96dc71200eef9f966bee21c27ae54426 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 7 Apr 2020 17:21:13 +0100 Subject: Handle one-word replication commands correctly `REPLICATE` is now a valid command, and it's nice if you can issue it from the console without remembering to call it `REPLICATE ` with a trailing space. 
--- synapse/replication/tcp/protocol.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 9aabb9c586..9276ed2965 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -201,15 +201,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): ) self.send_error("ping timeout") - def lineReceived(self, line): + def lineReceived(self, line: bytes): """Called when we've received a line """ if line.strip() == "": # Ignore blank lines return - line = line.decode("utf-8") - cmd_name, rest_of_line = line.split(" ", 1) + linestr = line.decode("utf-8") + + # split at the first " ", handling one-word commands + idx = linestr.index(" ") + if idx >= 0: + cmd_name = linestr[:idx] + rest_of_line = linestr[idx + 1 :] + else: + cmd_name = linestr + rest_of_line = "" if cmd_name not in self.VALID_INBOUND_COMMANDS: logger.error("[%s] invalid command %s", self.id(), cmd_name) -- cgit 1.4.1 From d78cb31588e01468ab06a36e6120a80fb6fbf097 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 7 Apr 2020 15:03:23 -0400 Subject: Add typing information to federation_server. (#7219) --- changelog.d/7219.misc | 1 + synapse/federation/federation_server.py | 173 ++++++++++++++++++++------------ tox.ini | 1 + 3 files changed, 109 insertions(+), 66 deletions(-) create mode 100644 changelog.d/7219.misc (limited to 'synapse') diff --git a/changelog.d/7219.misc b/changelog.d/7219.misc new file mode 100644 index 0000000000..4af5da8646 --- /dev/null +++ b/changelog.d/7219.misc @@ -0,0 +1 @@ +Add typing information to federation server code. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 89d521bc31..32a8a2ee46 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict +from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union import six from six import iteritems @@ -38,6 +38,7 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction @@ -94,7 +95,9 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) - async def on_backfill_request(self, origin, room_id, versions, limit): + async def on_backfill_request( + self, origin: str, room_id: str, versions: List[str], limit: int + ) -> Tuple[int, Dict[str, Any]]: with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -107,23 +110,25 @@ class FederationServer(FederationBase): return 200, res - async def on_incoming_transaction(self, origin, transaction_data): + async def on_incoming_transaction( + self, origin: str, transaction_data: JsonDict + ) -> Tuple[int, Dict[str, Any]]: # keep this as early as possible to make the calculated origin ts as # accurate as possible. 
request_time = self._clock.time_msec() transaction = Transaction(**transaction_data) - if not transaction.transaction_id: + if not transaction.transaction_id: # type: ignore raise Exception("Transaction missing transaction_id") - logger.debug("[%s] Got transaction", transaction.transaction_id) + logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with ( await self._transaction_linearizer.queue( - (origin, transaction.transaction_id) + (origin, transaction.transaction_id) # type: ignore ) ): result = await self._handle_incoming_transaction( @@ -132,31 +137,33 @@ class FederationServer(FederationBase): return result - async def _handle_incoming_transaction(self, origin, transaction, request_time): + async def _handle_incoming_transaction( + self, origin: str, transaction: Transaction, request_time: int + ) -> Tuple[int, Dict[str, Any]]: """ Process an incoming transaction and return the HTTP response Args: - origin (unicode): the server making the request - transaction (Transaction): incoming transaction - request_time (int): timestamp that the HTTP request arrived at + origin: the server making the request + transaction: incoming transaction + request_time: timestamp that the HTTP request arrived at Returns: - Deferred[(int, object)]: http response code and body + HTTP response code and body """ response = await self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id, + transaction.transaction_id, # type: ignore ) return response - logger.debug("[%s] Transaction is new", transaction.transaction_id) + logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore # Reject if PDU count > 50 or EDU count > 100 - if len(transaction.pdus) > 50 or ( - hasattr(transaction, "edus") and len(transaction.edus) > 100 + if len(transaction.pdus) > 50 or ( # type: ignore + hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore ): logger.info("Transaction PDU or EDU count too large. Returning 400") @@ -204,13 +211,13 @@ class FederationServer(FederationBase): report back to the sending server. """ - received_pdus_counter.inc(len(transaction.pdus)) + received_pdus_counter.inc(len(transaction.pdus)) # type: ignore origin_host, _ = parse_server_name(origin) - pdus_by_room = {} + pdus_by_room = {} # type: Dict[str, List[EventBase]] - for p in transaction.pdus: + for p in transaction.pdus: # type: ignore if "unsigned" in p: unsigned = p["unsigned"] if "age" in unsigned: @@ -254,7 +261,7 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. 
- async def process_pdus_for_room(room_id): + async def process_pdus_for_room(room_id: str): logger.debug("Processing PDUs for %s", room_id) try: await self.check_server_matches_acl(origin_host, room_id) @@ -310,7 +317,9 @@ class FederationServer(FederationBase): TRANSACTION_CONCURRENCY_LIMIT, ) - async def on_context_state_request(self, origin, room_id, event_id): + async def on_context_state_request( + self, origin: str, room_id: str, event_id: str + ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -338,7 +347,9 @@ class FederationServer(FederationBase): return 200, resp - async def on_state_ids_request(self, origin, room_id, event_id): + async def on_state_ids_request( + self, origin: str, room_id: str, event_id: str + ) -> Tuple[int, Dict[str, Any]]: if not event_id: raise NotImplementedError("Specify an event") @@ -354,7 +365,9 @@ class FederationServer(FederationBase): return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} - async def _on_context_state_request_compute(self, room_id, event_id): + async def _on_context_state_request_compute( + self, room_id: str, event_id: str + ) -> Dict[str, list]: if event_id: pdus = await self.handler.get_state_for_pdu(room_id, event_id) else: @@ -367,7 +380,9 @@ class FederationServer(FederationBase): "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], } - async def on_pdu_request(self, origin, event_id): + async def on_pdu_request( + self, origin: str, event_id: str + ) -> Tuple[int, Union[JsonDict, str]]: pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: @@ -375,12 +390,16 @@ class FederationServer(FederationBase): else: return 404, "" - async def on_query_request(self, query_type, args): + async def on_query_request( + self, query_type: str, args: Dict[str, str] + ) -> Tuple[int, Dict[str, Any]]: received_queries_counter.labels(query_type).inc() resp = await self.registry.on_query(query_type, args) return 200, resp - async def on_make_join_request(self, origin, room_id, user_id, supported_versions): + async def on_make_join_request( + self, origin: str, room_id: str, user_id: str, supported_versions: List[str] + ) -> Dict[str, Any]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -397,7 +416,7 @@ class FederationServer(FederationBase): async def on_invite_request( self, origin: str, content: JsonDict, room_version_id: str - ): + ) -> Dict[str, Any]: room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) if not room_version: raise SynapseError( @@ -414,7 +433,9 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} - async def on_send_join_request(self, origin, content, room_id): + async def on_send_join_request( + self, origin: str, content: JsonDict, room_id: str + ) -> Dict[str, Any]: logger.debug("on_send_join_request: content: %s", content) room_version = await self.store.get_room_version(room_id) @@ -434,7 +455,9 @@ class FederationServer(FederationBase): "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], } - async def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request( + self, origin: str, room_id: str, user_id: str + ) -> Dict[str, Any]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) @@ -444,7 +467,9 
@@ class FederationServer(FederationBase): time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - async def on_send_leave_request(self, origin, content, room_id): + async def on_send_leave_request( + self, origin: str, content: JsonDict, room_id: str + ) -> dict: logger.debug("on_send_leave_request: content: %s", content) room_version = await self.store.get_room_version(room_id) @@ -460,7 +485,9 @@ class FederationServer(FederationBase): await self.handler.on_send_leave_request(origin, pdu) return {} - async def on_event_auth(self, origin, room_id, event_id): + async def on_event_auth( + self, origin: str, room_id: str, event_id: str + ) -> Tuple[int, Dict[str, Any]]: with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -471,15 +498,21 @@ class FederationServer(FederationBase): return 200, res @log_function - def on_query_client_keys(self, origin, content): - return self.on_query_request("client_keys", content) - - async def on_query_user_devices(self, origin: str, user_id: str): + async def on_query_client_keys( + self, origin: str, content: Dict[str, str] + ) -> Tuple[int, Dict[str, Any]]: + return await self.on_query_request("client_keys", content) + + async def on_query_user_devices( + self, origin: str, user_id: str + ) -> Tuple[int, Dict[str, Any]]: keys = await self.device_handler.on_federation_query_user_devices(user_id) return 200, keys @trace - async def on_claim_client_keys(self, origin, content): + async def on_claim_client_keys( + self, origin: str, content: JsonDict + ) -> Dict[str, Any]: query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): @@ -488,7 +521,7 @@ class FederationServer(FederationBase): log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) results = await self.store.claim_e2e_one_time_keys(query) - json_result = {} + json_result = {} # type: Dict[str, Dict[str, dict]] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_bytes in keys.items(): @@ -511,8 +544,13 @@ class FederationServer(FederationBase): return {"one_time_keys": json_result} async def on_get_missing_events( - self, origin, room_id, earliest_events, latest_events, limit - ): + self, + origin: str, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> Dict[str, list]: with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -541,11 +579,11 @@ class FederationServer(FederationBase): return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} @log_function - def on_openid_userinfo(self, token): + async def on_openid_userinfo(self, token: str) -> Optional[str]: ts_now_ms = self._clock.time_msec() - return self.store.get_user_id_for_open_id_token(token, ts_now_ms) + return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) - def _transaction_from_pdus(self, pdu_list): + def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: """Returns a new Transaction containing the given PDUs suitable for transmission. 
""" @@ -558,7 +596,7 @@ class FederationServer(FederationBase): destination=None, ) - async def _handle_received_pdu(self, origin, pdu): + async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: """ Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. @@ -579,10 +617,8 @@ class FederationServer(FederationBase): until we try to backfill across the discontinuity. Args: - origin (str): server which sent the pdu - pdu (FrozenEvent): received pdu - - Returns (Deferred): completes with None + origin: server which sent the pdu + pdu: received pdu Raises: FederationError if the signatures / hash do not match, or if the event was unacceptable for any other reason (eg, too large, @@ -625,25 +661,27 @@ class FederationServer(FederationBase): return "" % self.server_name async def exchange_third_party_invite( - self, sender_user_id, target_user_id, room_id, signed + self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict ): ret = await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) return ret - async def on_exchange_third_party_invite_request(self, room_id, event_dict): + async def on_exchange_third_party_invite_request( + self, room_id: str, event_dict: Dict + ): ret = await self.handler.on_exchange_third_party_invite_request( room_id, event_dict ) return ret - async def check_server_matches_acl(self, server_name, room_id): + async def check_server_matches_acl(self, server_name: str, room_id: str): """Check if the given server is allowed by the server ACLs in the room Args: - server_name (str): name of server, *without any port part* - room_id (str): ID of the room to check + server_name: name of server, *without any port part* + room_id: ID of the room to check Raises: AuthError if the server does not match the ACL @@ -661,15 +699,15 @@ class FederationServer(FederationBase): raise AuthError(code=403, msg="Server is banned from room") -def server_matches_acl_event(server_name, acl_event): +def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: """Check if the given server is allowed by the ACL event Args: - server_name (str): name of server, without any port part - acl_event (EventBase): m.room.server_acl event + server_name: name of server, without any port part + acl_event: m.room.server_acl event Returns: - bool: True if this server is allowed by the ACLs + True if this server is allowed by the ACLs """ logger.debug("Checking %s against acl %s", server_name, acl_event.content) @@ -713,7 +751,7 @@ def server_matches_acl_event(server_name, acl_event): return False -def _acl_entry_matches(server_name, acl_entry): +def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: if not isinstance(acl_entry, six.string_types): logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) @@ -732,13 +770,13 @@ class FederationHandlerRegistry(object): self.edu_handlers = {} self.query_handlers = {} - def register_edu_handler(self, edu_type, handler): + def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]): """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. 
Args: - edu_type (str): The type of the incoming EDU to register handler for - handler (Callable[[str, dict]]): A callable invoked on incoming EDU + edu_type: The type of the incoming EDU to register handler for + handler: A callable invoked on incoming EDU of the given type. The arguments are the origin server name and the EDU contents. """ @@ -749,14 +787,16 @@ class FederationHandlerRegistry(object): self.edu_handlers[edu_type] = handler - def register_query_handler(self, query_type, handler): + def register_query_handler( + self, query_type: str, handler: Callable[[dict], defer.Deferred] + ): """Sets the handler callable that will be used to handle an incoming federation query of the given type. Args: - query_type (str): Category name of the query, which should match + query_type: Category name of the query, which should match the string used by make_query. - handler (Callable[[dict], Deferred[dict]]): Invoked to handle + handler: Invoked to handle incoming queries of this type. The return will be yielded on and the result used as the response to the query request. """ @@ -767,10 +807,11 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler - async def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type: str, origin: str, content: dict): handler = self.edu_handlers.get(edu_type) if not handler: logger.warning("No handler registered for EDU type %s", edu_type) + return with start_active_span_from_edu(content, "handle_edu"): try: @@ -780,7 +821,7 @@ class FederationHandlerRegistry(object): except Exception: logger.exception("Failed to handle edu %r", edu_type) - def on_query(self, query_type, args): + def on_query(self, query_type: str, args: dict) -> defer.Deferred: handler = self.query_handlers.get(query_type) if not handler: logger.warning("No handler registered for query type %s", query_type) @@ -807,7 +848,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): super(ReplicationFederationHandlerRegistry, self).__init__() - async def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type: str, origin: str, content: dict): """Overrides FederationHandlerRegistry """ if not self.config.use_presence and edu_type == "m.presence": @@ -821,7 +862,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): return await self._send_edu(edu_type=edu_type, origin=origin, content=content) - async def on_query(self, query_type, args): + async def on_query(self, query_type: str, args: dict): """Overrides FederationHandlerRegistry """ handler = self.query_handlers.get(query_type) diff --git a/tox.ini b/tox.ini index a79fc93b57..763c8463d9 100644 --- a/tox.ini +++ b/tox.ini @@ -183,6 +183,7 @@ commands = mypy \ synapse/events/spamcheck.py \ synapse/federation/federation_base.py \ synapse/federation/federation_client.py \ + synapse/federation/federation_server.py \ synapse/federation/sender \ synapse/federation/transport \ synapse/handlers/auth.py \ -- cgit 1.4.1 From 1722b8a527b8caa0f76706bf4acaf240e167daf4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 7 Apr 2020 16:56:34 -0400 Subject: Convert delete_url_cache_media to async/await. 
(#7241) --- changelog.d/7241.misc | 1 + synapse/storage/data_stores/main/media_repository.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7241.misc (limited to 'synapse') diff --git a/changelog.d/7241.misc b/changelog.d/7241.misc new file mode 100644 index 0000000000..fac5bc0403 --- /dev/null +++ b/changelog.d/7241.misc @@ -0,0 +1 @@ +Convert some of synapse.rest.media to async/await. diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index cf195f8aa6..8aecd414c2 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -367,7 +367,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_url_cache_media_before", _get_url_cache_media_before_txn ) - def delete_url_cache_media(self, media_ids): + async def delete_url_cache_media(self, media_ids): if len(media_ids) == 0: return @@ -380,6 +380,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.db.runInteraction( + return await self.db.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) -- cgit 1.4.1 From f31e65a749f84f8b3278c91784509d908d4fb342 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Apr 2020 23:06:39 +0100 Subject: bg update to clear out duplicate outbound_device_list_pokes (#7193) We seem to have some duplicates, which could do with being cleared out. --- changelog.d/7193.misc | 1 + synapse/storage/data_stores/main/client_ips.py | 16 ++--- synapse/storage/data_stores/main/devices.py | 73 ++++++++++++++++++- .../delta/58/02remove_dup_outbound_pokes.sql | 22 ++++++ synapse/storage/database.py | 83 +++++++++++++++++++++- tests/storage/test_database.py | 52 ++++++++++++++ 6 files changed, 234 insertions(+), 13 deletions(-) create mode 100644 changelog.d/7193.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql create mode 100644 tests/storage/test_database.py (limited to 'synapse') diff --git a/changelog.d/7193.misc b/changelog.d/7193.misc new file mode 100644 index 0000000000..383a738e64 --- /dev/null +++ b/changelog.d/7193.misc @@ -0,0 +1 @@ +Add a background database update job to clear out duplicate `device_lists_outbound_pokes`. diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index e1ccb27142..92bc06919b 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import Database, make_tuple_comparison_clause from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -303,16 +303,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # we'll just end up updating the same device row multiple # times, which is fine. - if self.database_engine.supports_tuple_comparison: - where_clause = "(user_id, device_id) > (?, ?)" - where_args = [last_user_id, last_device_id] - else: - # We explicitly do a `user_id >= ? AND (...)` here to ensure - # that an index is used, as doing `user_id > ? OR (user_id = ? 
AND ...)` - # makes it hard for query optimiser to tell that it can use the - # index on user_id - where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)" - where_args = [last_user_id, last_user_id, last_device_id] + where_clause, where_args = make_tuple_comparison_clause( + self.database_engine, + [("user_id", last_user_id), ("device_id", last_device_id)], + ) sql = """ SELECT diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 4c5bea4a5c..ee3a2ab031 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -32,7 +32,11 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.database import ( + Database, + LoggingTransaction, + make_tuple_comparison_clause, +) from synapse.types import Collection, get_verify_key_from_cross_signing_key from synapse.util.caches.descriptors import ( Cache, @@ -49,6 +53,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( "drop_device_list_streams_non_unique_indexes" ) +BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" + class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): @@ -714,6 +720,11 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): self._drop_device_list_streams_non_unique_indexes, ) + # clear out duplicate device list outbound pokes + self.db.updates.register_background_update_handler( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, + ) + @defer.inlineCallbacks def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): @@ -728,6 +739,66 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ) return 1 + async def _remove_duplicate_outbound_pokes(self, progress, batch_size): + # for some reason, we have accumulated duplicate entries in + # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + # efficient. + # + # For each duplicate, we delete all the existing rows and put one back. + + KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] + last_row = progress.get( + "last_row", + {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, + ) + + def _txn(txn): + clause, args = make_tuple_comparison_clause( + self.db.engine, [(x, last_row[x]) for x in KEY_COLS] + ) + sql = """ + SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts + FROM device_lists_outbound_pokes + WHERE %s + GROUP BY %s + HAVING count(*) > 1 + ORDER BY %s + LIMIT ? 
+ """ % ( + clause, # WHERE + ",".join(KEY_COLS), # GROUP BY + ",".join(KEY_COLS), # ORDER BY + ) + txn.execute(sql, args + [batch_size]) + rows = self.db.cursor_to_dict(txn) + + row = None + for row in rows: + self.db.simple_delete_txn( + txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, + ) + + row["sent"] = False + self.db.simple_insert_txn( + txn, "device_lists_outbound_pokes", row, + ) + + if row: + self.db.updates._background_update_progress_txn( + txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, + ) + + return len(rows) + + rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn) + + if not rows: + await self.db.updates._end_background_update( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES + ) + + return rows + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql new file mode 100644 index 0000000000..fdc39e9ba5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* for some reason, we have accumulated duplicate entries in + * device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + * efficient. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) + VALUES (5800, 'remove_dup_outbound_pokes', '{}'); diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 715c0346dd..a7cd97b0b0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -17,7 +17,17 @@ import logging import time from time import monotonic as monotonic_time -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, +) from six import iteritems, iterkeys, itervalues from six.moves import intern, range @@ -1557,3 +1567,74 @@ def make_in_list_sql_clause( return "%s = ANY(?)" % (column,), [list(iterable)] else: return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) + + +KV = TypeVar("KV") + + +def make_tuple_comparison_clause( + database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]] +) -> Tuple[str, List[KV]]: + """Returns a tuple comparison SQL clause + + Depending what the SQL engine supports, builds a SQL clause that looks like either + "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)". + + Args: + database_engine + keys: A set of (column, value) pairs to be compared. + + Returns: + A tuple of SQL query and the args + """ + if database_engine.supports_tuple_comparison: + return ( + "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" 
for _ in keys)), + [k[1] for k in keys], + ) + + # we want to build a clause + # (a > ?) OR + # (a == ? AND b > ?) OR + # (a == ? AND b == ? AND c > ?) + # ... + # (a == ? AND b == ? AND ... AND z > ?) + # + # or, equivalently: + # + # (a > ? OR (a == ? AND + # (b > ? OR (b == ? AND + # ... + # (y > ? OR (y == ? AND + # z > ? + # )) + # ... + # )) + # )) + # + # which itself is equivalent to (and apparently easier for the query optimiser): + # + # (a >= ? AND (a > ? OR + # (b >= ? AND (b > ? OR + # ... + # (y >= ? AND (y > ? OR + # z > ? + # )) + # ... + # )) + # )) + # + # + + clause = "" + args = [] # type: List[KV] + for k, v in keys[:-1]: + clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k) + args.extend([v, v]) + + (k, v) = keys[-1] + clause += "%s > ?" % (k,) + args.append(v) + + clause += "))" * (len(keys) - 1) + return clause, args diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py new file mode 100644 index 0000000000..5a77c84962 --- /dev/null +++ b/tests/storage/test_database.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.database import make_tuple_comparison_clause +from synapse.storage.engines import BaseDatabaseEngine + +from tests import unittest + + +def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: + # returns a DatabaseEngine, circumventing the abc mechanism + # any kwargs are set as attributes on the class before instantiating it + t = type( + "TestBaseDatabaseEngine", + (BaseDatabaseEngine,), + dict(BaseDatabaseEngine.__dict__), + ) + # defeat the abc mechanism + t.__abstractmethods__ = set() + for k, v in kwargs.items(): + setattr(t, k, v) + return t(None, None) + + +class TupleComparisonClauseTestCase(unittest.TestCase): + def test_native_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=True) + clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)]) + self.assertEqual(clause, "(a,b) > (?,?)") + self.assertEqual(args, [1, 2]) + + def test_emulated_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=False) + clause, args = make_tuple_comparison_clause( + db_engine, [("a", 1), ("b", 2), ("c", 3)] + ) + self.assertEqual( + clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? 
OR c > ?))))" + ) + self.assertEqual(args, [1, 1, 2, 2, 3]) -- cgit 1.4.1 From 29b7e22b939c473649c8619fdfbecec0cee6b029 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 8 Apr 2020 00:46:50 +0100 Subject: Add documentation to password_providers config option (#7238) --- changelog.d/7238.doc | 1 + docs/password_auth_providers.md | 5 ++++- docs/sample_config.yaml | 14 +++++++++++++- synapse/config/password_auth_providers.py | 16 ++++++++++++++-- 4 files changed, 32 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7238.doc (limited to 'synapse') diff --git a/changelog.d/7238.doc b/changelog.d/7238.doc new file mode 100644 index 0000000000..0e3b4be428 --- /dev/null +++ b/changelog.d/7238.doc @@ -0,0 +1 @@ +Add documentation to the `password_providers` config option. Add known password provider implementations to docs. \ No newline at end of file diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index 0db1a3804a..96f9841b7a 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -9,7 +9,10 @@ into Synapse, and provides a number of methods by which it can integrate with the authentication system. This document serves as a reference for those looking to implement their -own password auth providers. +own password auth providers. Additionally, here is a list of known +password auth provider module implementations: + +* [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/) ## Required methods diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index be742969cc..3417813750 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1657,7 +1657,19 @@ email: #template_dir: "res/templates" -#password_providers: +# Password providers allow homeserver administrators to integrate +# their Synapse installation with existing authentication methods +# ex. LDAP, external tokens, etc. +# +# For more information and known implementations, please see +# https://github.com/matrix-org/synapse/blob/master/docs/password_auth_providers.md +# +# Note: instances wishing to use SAML or CAS authentication should +# instead use the `saml2_config` or `cas_config` options, +# respectively. +# +password_providers: +# # Example config for an LDAP auth provider # - module: "ldap_auth_provider.LdapAuthProvider" # config: # enabled: true diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 9746bbc681..4fda8ae987 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -35,7 +35,7 @@ class PasswordAuthProviderConfig(Config): if ldap_config.get("enabled", False): providers.append({"module": LDAP_PROVIDER, "config": ldap_config}) - providers.extend(config.get("password_providers", [])) + providers.extend(config.get("password_providers") or []) for provider in providers: mod_name = provider["module"] @@ -52,7 +52,19 @@ class PasswordAuthProviderConfig(Config): def generate_config_section(self, **kwargs): return """\ - #password_providers: + # Password providers allow homeserver administrators to integrate + # their Synapse installation with existing authentication methods + # ex. LDAP, external tokens, etc. 
+ # + # For more information and known implementations, please see + # https://github.com/matrix-org/synapse/blob/master/docs/password_auth_providers.md + # + # Note: instances wishing to use SAML or CAS authentication should + # instead use the `saml2_config` or `cas_config` options, + # respectively. + # + password_providers: + # # Example config for an LDAP auth provider # - module: "ldap_auth_provider.LdapAuthProvider" # config: # enabled: true -- cgit 1.4.1 From 55d46da59a9d0db58888243137b69ee342921b11 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 9 Apr 2020 12:23:30 +0100 Subject: Upgrade jQuery to 3.x on fallback login/registration screens (#7236) --- changelog.d/7236.misc | 1 + synapse/static/client/login/index.html | 3 ++- synapse/static/client/login/js/jquery-2.1.3.min.js | 4 ---- synapse/static/client/login/js/jquery-3.4.1.min.js | 2 ++ synapse/static/client/login/js/login.js | 6 +++--- synapse/static/client/register/index.html | 3 ++- synapse/static/client/register/js/jquery-2.1.3.min.js | 4 ---- synapse/static/client/register/js/jquery-3.4.1.min.js | 2 ++ synapse/static/client/register/js/register.js | 6 +++--- 9 files changed, 15 insertions(+), 16 deletions(-) create mode 100644 changelog.d/7236.misc delete mode 100644 synapse/static/client/login/js/jquery-2.1.3.min.js create mode 100644 synapse/static/client/login/js/jquery-3.4.1.min.js delete mode 100644 synapse/static/client/register/js/jquery-2.1.3.min.js create mode 100644 synapse/static/client/register/js/jquery-3.4.1.min.js (limited to 'synapse') diff --git a/changelog.d/7236.misc b/changelog.d/7236.misc new file mode 100644 index 0000000000..e4a2702b54 --- /dev/null +++ b/changelog.d/7236.misc @@ -0,0 +1 @@ +Upgrade jQuery to v3.4.1 on fallback login/registration pages. \ No newline at end of file diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html index 712b0e3980..6fefdaaff7 100644 --- a/synapse/static/client/login/index.html +++ b/synapse/static/client/login/index.html @@ -1,9 +1,10 @@ + Login - + diff --git a/synapse/static/client/login/js/jquery-2.1.3.min.js b/synapse/static/client/login/js/jquery-2.1.3.min.js deleted file mode 100644 index 25714ed29a..0000000000 --- a/synapse/static/client/login/js/jquery-2.1.3.min.js +++ /dev/null @@ -1,4 +0,0 @@ -/*! jQuery v2.1.3 | (c) 2005, 2014 jQuery Foundation, Inc. 
| jquery.org/license */ -!function(a,b){"object"==typeof module&&"object"==typeof module.exports?module.exports=a.document?b(a,!0):function(a){if(!a.document)throw new Error("jQuery requires a window with a document");return b(a)}:b(a)}("undefined"!=typeof window?window:this,function(a,b){var c=[],d=c.slice,e=c.concat,f=c.push,g=c.indexOf,h={},i=h.toString,j=h.hasOwnProperty,k={},l=a.document,m="2.1.3",n=function(a,b){return new n.fn.init(a,b)},o=/^[\s\uFEFF\xA0]+|[\s\uFEFF\xA0]+$/g,p=/^-ms-/,q=/-([\da-z])/gi,r=function(a,b){return b.toUpperCase()};n.fn=n.prototype={jquery:m,constructor:n,selector:"",length:0,toArray:function(){return d.call(this)},get:function(a){return null!=a?0>a?this[a+this.length]:this[a]:d.call(this)},pushStack:function(a){var b=n.merge(this.constructor(),a);return b.prevObject=this,b.context=this.context,b},each:function(a,b){return n.each(this,a,b)},map:function(a){return this.pushStack(n.map(this,function(b,c){return a.call(b,c,b)}))},slice:function(){return this.pushStack(d.apply(this,arguments))},first:function(){return this.eq(0)},last:function(){return this.eq(-1)},eq:function(a){var b=this.length,c=+a+(0>a?b:0);return this.pushStack(c>=0&&b>c?[this[c]]:[])},end:function(){return this.prevObject||this.constructor(null)},push:f,sort:c.sort,splice:c.splice},n.extend=n.fn.extend=function(){var a,b,c,d,e,f,g=arguments[0]||{},h=1,i=arguments.length,j=!1;for("boolean"==typeof g&&(j=g,g=arguments[h]||{},h++),"object"==typeof g||n.isFunction(g)||(g={}),h===i&&(g=this,h--);i>h;h++)if(null!=(a=arguments[h]))for(b in a)c=g[b],d=a[b],g!==d&&(j&&d&&(n.isPlainObject(d)||(e=n.isArray(d)))?(e?(e=!1,f=c&&n.isArray(c)?c:[]):f=c&&n.isPlainObject(c)?c:{},g[b]=n.extend(j,f,d)):void 0!==d&&(g[b]=d));return g},n.extend({expando:"jQuery"+(m+Math.random()).replace(/\D/g,""),isReady:!0,error:function(a){throw new Error(a)},noop:function(){},isFunction:function(a){return"function"===n.type(a)},isArray:Array.isArray,isWindow:function(a){return null!=a&&a===a.window},isNumeric:function(a){return!n.isArray(a)&&a-parseFloat(a)+1>=0},isPlainObject:function(a){return"object"!==n.type(a)||a.nodeType||n.isWindow(a)?!1:a.constructor&&!j.call(a.constructor.prototype,"isPrototypeOf")?!1:!0},isEmptyObject:function(a){var b;for(b in a)return!1;return!0},type:function(a){return null==a?a+"":"object"==typeof a||"function"==typeof a?h[i.call(a)]||"object":typeof a},globalEval:function(a){var b,c=eval;a=n.trim(a),a&&(1===a.indexOf("use strict")?(b=l.createElement("script"),b.text=a,l.head.appendChild(b).parentNode.removeChild(b)):c(a))},camelCase:function(a){return a.replace(p,"ms-").replace(q,r)},nodeName:function(a,b){return a.nodeName&&a.nodeName.toLowerCase()===b.toLowerCase()},each:function(a,b,c){var d,e=0,f=a.length,g=s(a);if(c){if(g){for(;f>e;e++)if(d=b.apply(a[e],c),d===!1)break}else for(e in a)if(d=b.apply(a[e],c),d===!1)break}else if(g){for(;f>e;e++)if(d=b.call(a[e],e,a[e]),d===!1)break}else for(e in a)if(d=b.call(a[e],e,a[e]),d===!1)break;return a},trim:function(a){return null==a?"":(a+"").replace(o,"")},makeArray:function(a,b){var c=b||[];return null!=a&&(s(Object(a))?n.merge(c,"string"==typeof a?[a]:a):f.call(c,a)),c},inArray:function(a,b,c){return null==b?-1:g.call(b,a,c)},merge:function(a,b){for(var c=+b.length,d=0,e=a.length;c>d;d++)a[e++]=b[d];return a.length=e,a},grep:function(a,b,c){for(var d,e=[],f=0,g=a.length,h=!c;g>f;f++)d=!b(a[f],f),d!==h&&e.push(a[f]);return e},map:function(a,b,c){var d,f=0,g=a.length,h=s(a),i=[];if(h)for(;g>f;f++)d=b(a[f],f,c),null!=d&&i.push(d);else for(f in 
a)d=b(a[f],f,c),null!=d&&i.push(d);return e.apply([],i)},guid:1,proxy:function(a,b){var c,e,f;return"string"==typeof b&&(c=a[b],b=a,a=c),n.isFunction(a)?(e=d.call(arguments,2),f=function(){return a.apply(b||this,e.concat(d.call(arguments)))},f.guid=a.guid=a.guid||n.guid++,f):void 0},now:Date.now,support:k}),n.each("Boolean Number String Function Array Date RegExp Object Error".split(" "),function(a,b){h["[object "+b+"]"]=b.toLowerCase()});function s(a){var b=a.length,c=n.type(a);return"function"===c||n.isWindow(a)?!1:1===a.nodeType&&b?!0:"array"===c||0===b||"number"==typeof b&&b>0&&b-1 in a}var t=function(a){var b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u="sizzle"+1*new Date,v=a.document,w=0,x=0,y=hb(),z=hb(),A=hb(),B=function(a,b){return a===b&&(l=!0),0},C=1<<31,D={}.hasOwnProperty,E=[],F=E.pop,G=E.push,H=E.push,I=E.slice,J=function(a,b){for(var c=0,d=a.length;d>c;c++)if(a[c]===b)return c;return-1},K="checked|selected|async|autofocus|autoplay|controls|defer|disabled|hidden|ismap|loop|multiple|open|readonly|required|scoped",L="[\\x20\\t\\r\\n\\f]",M="(?:\\\\.|[\\w-]|[^\\x00-\\xa0])+",N=M.replace("w","w#"),O="\\["+L+"*("+M+")(?:"+L+"*([*^$|!~]?=)"+L+"*(?:'((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\"|("+N+"))|)"+L+"*\\]",P=":("+M+")(?:\\((('((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\")|((?:\\\\.|[^\\\\()[\\]]|"+O+")*)|.*)\\)|)",Q=new RegExp(L+"+","g"),R=new RegExp("^"+L+"+|((?:^|[^\\\\])(?:\\\\.)*)"+L+"+$","g"),S=new RegExp("^"+L+"*,"+L+"*"),T=new RegExp("^"+L+"*([>+~]|"+L+")"+L+"*"),U=new RegExp("="+L+"*([^\\]'\"]*?)"+L+"*\\]","g"),V=new RegExp(P),W=new RegExp("^"+N+"$"),X={ID:new RegExp("^#("+M+")"),CLASS:new RegExp("^\\.("+M+")"),TAG:new RegExp("^("+M.replace("w","w*")+")"),ATTR:new RegExp("^"+O),PSEUDO:new RegExp("^"+P),CHILD:new RegExp("^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\("+L+"*(even|odd|(([+-]|)(\\d*)n|)"+L+"*(?:([+-]|)"+L+"*(\\d+)|))"+L+"*\\)|)","i"),bool:new RegExp("^(?:"+K+")$","i"),needsContext:new RegExp("^"+L+"*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\("+L+"*((?:-\\d)?\\d*)"+L+"*\\)|)(?=[^-]|$)","i")},Y=/^(?:input|select|textarea|button)$/i,Z=/^h\d$/i,$=/^[^{]+\{\s*\[native \w/,_=/^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/,ab=/[+~]/,bb=/'|\\/g,cb=new RegExp("\\\\([\\da-f]{1,6}"+L+"?|("+L+")|.)","ig"),db=function(a,b,c){var d="0x"+b-65536;return d!==d||c?b:0>d?String.fromCharCode(d+65536):String.fromCharCode(d>>10|55296,1023&d|56320)},eb=function(){m()};try{H.apply(E=I.call(v.childNodes),v.childNodes),E[v.childNodes.length].nodeType}catch(fb){H={apply:E.length?function(a,b){G.apply(a,I.call(b))}:function(a,b){var c=a.length,d=0;while(a[c++]=b[d++]);a.length=c-1}}}function gb(a,b,d,e){var f,h,j,k,l,o,r,s,w,x;if((b?b.ownerDocument||b:v)!==n&&m(b),b=b||n,d=d||[],k=b.nodeType,"string"!=typeof a||!a||1!==k&&9!==k&&11!==k)return d;if(!e&&p){if(11!==k&&(f=_.exec(a)))if(j=f[1]){if(9===k){if(h=b.getElementById(j),!h||!h.parentNode)return d;if(h.id===j)return d.push(h),d}else if(b.ownerDocument&&(h=b.ownerDocument.getElementById(j))&&t(b,h)&&h.id===j)return d.push(h),d}else{if(f[2])return H.apply(d,b.getElementsByTagName(a)),d;if((j=f[3])&&c.getElementsByClassName)return H.apply(d,b.getElementsByClassName(j)),d}if(c.qsa&&(!q||!q.test(a))){if(s=r=u,w=b,x=1!==k&&a,1===k&&"object"!==b.nodeName.toLowerCase()){o=g(a),(r=b.getAttribute("id"))?s=r.replace(bb,"\\$&"):b.setAttribute("id",s),s="[id='"+s+"'] ",l=o.length;while(l--)o[l]=s+rb(o[l]);w=ab.test(a)&&pb(b.parentNode)||b,x=o.join(",")}if(x)try{return 
H.apply(d,w.querySelectorAll(x)),d}catch(y){}finally{r||b.removeAttribute("id")}}}return i(a.replace(R,"$1"),b,d,e)}function hb(){var a=[];function b(c,e){return a.push(c+" ")>d.cacheLength&&delete b[a.shift()],b[c+" "]=e}return b}function ib(a){return a[u]=!0,a}function jb(a){var b=n.createElement("div");try{return!!a(b)}catch(c){return!1}finally{b.parentNode&&b.parentNode.removeChild(b),b=null}}function kb(a,b){var c=a.split("|"),e=a.length;while(e--)d.attrHandle[c[e]]=b}function lb(a,b){var c=b&&a,d=c&&1===a.nodeType&&1===b.nodeType&&(~b.sourceIndex||C)-(~a.sourceIndex||C);if(d)return d;if(c)while(c=c.nextSibling)if(c===b)return-1;return a?1:-1}function mb(a){return function(b){var c=b.nodeName.toLowerCase();return"input"===c&&b.type===a}}function nb(a){return function(b){var c=b.nodeName.toLowerCase();return("input"===c||"button"===c)&&b.type===a}}function ob(a){return ib(function(b){return b=+b,ib(function(c,d){var e,f=a([],c.length,b),g=f.length;while(g--)c[e=f[g]]&&(c[e]=!(d[e]=c[e]))})})}function pb(a){return a&&"undefined"!=typeof a.getElementsByTagName&&a}c=gb.support={},f=gb.isXML=function(a){var b=a&&(a.ownerDocument||a).documentElement;return b?"HTML"!==b.nodeName:!1},m=gb.setDocument=function(a){var b,e,g=a?a.ownerDocument||a:v;return g!==n&&9===g.nodeType&&g.documentElement?(n=g,o=g.documentElement,e=g.defaultView,e&&e!==e.top&&(e.addEventListener?e.addEventListener("unload",eb,!1):e.attachEvent&&e.attachEvent("onunload",eb)),p=!f(g),c.attributes=jb(function(a){return a.className="i",!a.getAttribute("className")}),c.getElementsByTagName=jb(function(a){return a.appendChild(g.createComment("")),!a.getElementsByTagName("*").length}),c.getElementsByClassName=$.test(g.getElementsByClassName),c.getById=jb(function(a){return o.appendChild(a).id=u,!g.getElementsByName||!g.getElementsByName(u).length}),c.getById?(d.find.ID=function(a,b){if("undefined"!=typeof b.getElementById&&p){var c=b.getElementById(a);return c&&c.parentNode?[c]:[]}},d.filter.ID=function(a){var b=a.replace(cb,db);return function(a){return a.getAttribute("id")===b}}):(delete d.find.ID,d.filter.ID=function(a){var b=a.replace(cb,db);return function(a){var c="undefined"!=typeof a.getAttributeNode&&a.getAttributeNode("id");return c&&c.value===b}}),d.find.TAG=c.getElementsByTagName?function(a,b){return"undefined"!=typeof b.getElementsByTagName?b.getElementsByTagName(a):c.qsa?b.querySelectorAll(a):void 0}:function(a,b){var c,d=[],e=0,f=b.getElementsByTagName(a);if("*"===a){while(c=f[e++])1===c.nodeType&&d.push(c);return d}return f},d.find.CLASS=c.getElementsByClassName&&function(a,b){return p?b.getElementsByClassName(a):void 0},r=[],q=[],(c.qsa=$.test(g.querySelectorAll))&&(jb(function(a){o.appendChild(a).innerHTML="",a.querySelectorAll("[msallowcapture^='']").length&&q.push("[*^$]="+L+"*(?:''|\"\")"),a.querySelectorAll("[selected]").length||q.push("\\["+L+"*(?:value|"+K+")"),a.querySelectorAll("[id~="+u+"-]").length||q.push("~="),a.querySelectorAll(":checked").length||q.push(":checked"),a.querySelectorAll("a#"+u+"+*").length||q.push(".#.+[+~]")}),jb(function(a){var 
b=g.createElement("input");b.setAttribute("type","hidden"),a.appendChild(b).setAttribute("name","D"),a.querySelectorAll("[name=d]").length&&q.push("name"+L+"*[*^$|!~]?="),a.querySelectorAll(":enabled").length||q.push(":enabled",":disabled"),a.querySelectorAll("*,:x"),q.push(",.*:")})),(c.matchesSelector=$.test(s=o.matches||o.webkitMatchesSelector||o.mozMatchesSelector||o.oMatchesSelector||o.msMatchesSelector))&&jb(function(a){c.disconnectedMatch=s.call(a,"div"),s.call(a,"[s!='']:x"),r.push("!=",P)}),q=q.length&&new RegExp(q.join("|")),r=r.length&&new RegExp(r.join("|")),b=$.test(o.compareDocumentPosition),t=b||$.test(o.contains)?function(a,b){var c=9===a.nodeType?a.documentElement:a,d=b&&b.parentNode;return a===d||!(!d||1!==d.nodeType||!(c.contains?c.contains(d):a.compareDocumentPosition&&16&a.compareDocumentPosition(d)))}:function(a,b){if(b)while(b=b.parentNode)if(b===a)return!0;return!1},B=b?function(a,b){if(a===b)return l=!0,0;var d=!a.compareDocumentPosition-!b.compareDocumentPosition;return d?d:(d=(a.ownerDocument||a)===(b.ownerDocument||b)?a.compareDocumentPosition(b):1,1&d||!c.sortDetached&&b.compareDocumentPosition(a)===d?a===g||a.ownerDocument===v&&t(v,a)?-1:b===g||b.ownerDocument===v&&t(v,b)?1:k?J(k,a)-J(k,b):0:4&d?-1:1)}:function(a,b){if(a===b)return l=!0,0;var c,d=0,e=a.parentNode,f=b.parentNode,h=[a],i=[b];if(!e||!f)return a===g?-1:b===g?1:e?-1:f?1:k?J(k,a)-J(k,b):0;if(e===f)return lb(a,b);c=a;while(c=c.parentNode)h.unshift(c);c=b;while(c=c.parentNode)i.unshift(c);while(h[d]===i[d])d++;return d?lb(h[d],i[d]):h[d]===v?-1:i[d]===v?1:0},g):n},gb.matches=function(a,b){return gb(a,null,null,b)},gb.matchesSelector=function(a,b){if((a.ownerDocument||a)!==n&&m(a),b=b.replace(U,"='$1']"),!(!c.matchesSelector||!p||r&&r.test(b)||q&&q.test(b)))try{var d=s.call(a,b);if(d||c.disconnectedMatch||a.document&&11!==a.document.nodeType)return d}catch(e){}return gb(b,n,null,[a]).length>0},gb.contains=function(a,b){return(a.ownerDocument||a)!==n&&m(a),t(a,b)},gb.attr=function(a,b){(a.ownerDocument||a)!==n&&m(a);var e=d.attrHandle[b.toLowerCase()],f=e&&D.call(d.attrHandle,b.toLowerCase())?e(a,b,!p):void 0;return void 0!==f?f:c.attributes||!p?a.getAttribute(b):(f=a.getAttributeNode(b))&&f.specified?f.value:null},gb.error=function(a){throw new Error("Syntax error, unrecognized expression: "+a)},gb.uniqueSort=function(a){var b,d=[],e=0,f=0;if(l=!c.detectDuplicates,k=!c.sortStable&&a.slice(0),a.sort(B),l){while(b=a[f++])b===a[f]&&(e=d.push(f));while(e--)a.splice(d[e],1)}return k=null,a},e=gb.getText=function(a){var b,c="",d=0,f=a.nodeType;if(f){if(1===f||9===f||11===f){if("string"==typeof a.textContent)return a.textContent;for(a=a.firstChild;a;a=a.nextSibling)c+=e(a)}else if(3===f||4===f)return a.nodeValue}else while(b=a[d++])c+=e(b);return c},d=gb.selectors={cacheLength:50,createPseudo:ib,match:X,attrHandle:{},find:{},relative:{">":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(a){return a[1]=a[1].replace(cb,db),a[3]=(a[3]||a[4]||a[5]||"").replace(cb,db),"~="===a[2]&&(a[3]=" "+a[3]+" "),a.slice(0,4)},CHILD:function(a){return a[1]=a[1].toLowerCase(),"nth"===a[1].slice(0,3)?(a[3]||gb.error(a[0]),a[4]=+(a[4]?a[5]+(a[6]||1):2*("even"===a[3]||"odd"===a[3])),a[5]=+(a[7]+a[8]||"odd"===a[3])):a[3]&&gb.error(a[0]),a},PSEUDO:function(a){var b,c=!a[6]&&a[2];return 
X.CHILD.test(a[0])?null:(a[3]?a[2]=a[4]||a[5]||"":c&&V.test(c)&&(b=g(c,!0))&&(b=c.indexOf(")",c.length-b)-c.length)&&(a[0]=a[0].slice(0,b),a[2]=c.slice(0,b)),a.slice(0,3))}},filter:{TAG:function(a){var b=a.replace(cb,db).toLowerCase();return"*"===a?function(){return!0}:function(a){return a.nodeName&&a.nodeName.toLowerCase()===b}},CLASS:function(a){var b=y[a+" "];return b||(b=new RegExp("(^|"+L+")"+a+"("+L+"|$)"))&&y(a,function(a){return b.test("string"==typeof a.className&&a.className||"undefined"!=typeof a.getAttribute&&a.getAttribute("class")||"")})},ATTR:function(a,b,c){return function(d){var e=gb.attr(d,a);return null==e?"!="===b:b?(e+="","="===b?e===c:"!="===b?e!==c:"^="===b?c&&0===e.indexOf(c):"*="===b?c&&e.indexOf(c)>-1:"$="===b?c&&e.slice(-c.length)===c:"~="===b?(" "+e.replace(Q," ")+" ").indexOf(c)>-1:"|="===b?e===c||e.slice(0,c.length+1)===c+"-":!1):!0}},CHILD:function(a,b,c,d,e){var f="nth"!==a.slice(0,3),g="last"!==a.slice(-4),h="of-type"===b;return 1===d&&0===e?function(a){return!!a.parentNode}:function(b,c,i){var j,k,l,m,n,o,p=f!==g?"nextSibling":"previousSibling",q=b.parentNode,r=h&&b.nodeName.toLowerCase(),s=!i&&!h;if(q){if(f){while(p){l=b;while(l=l[p])if(h?l.nodeName.toLowerCase()===r:1===l.nodeType)return!1;o=p="only"===a&&!o&&"nextSibling"}return!0}if(o=[g?q.firstChild:q.lastChild],g&&s){k=q[u]||(q[u]={}),j=k[a]||[],n=j[0]===w&&j[1],m=j[0]===w&&j[2],l=n&&q.childNodes[n];while(l=++n&&l&&l[p]||(m=n=0)||o.pop())if(1===l.nodeType&&++m&&l===b){k[a]=[w,n,m];break}}else if(s&&(j=(b[u]||(b[u]={}))[a])&&j[0]===w)m=j[1];else while(l=++n&&l&&l[p]||(m=n=0)||o.pop())if((h?l.nodeName.toLowerCase()===r:1===l.nodeType)&&++m&&(s&&((l[u]||(l[u]={}))[a]=[w,m]),l===b))break;return m-=e,m===d||m%d===0&&m/d>=0}}},PSEUDO:function(a,b){var c,e=d.pseudos[a]||d.setFilters[a.toLowerCase()]||gb.error("unsupported pseudo: "+a);return e[u]?e(b):e.length>1?(c=[a,a,"",b],d.setFilters.hasOwnProperty(a.toLowerCase())?ib(function(a,c){var d,f=e(a,b),g=f.length;while(g--)d=J(a,f[g]),a[d]=!(c[d]=f[g])}):function(a){return e(a,0,c)}):e}},pseudos:{not:ib(function(a){var b=[],c=[],d=h(a.replace(R,"$1"));return d[u]?ib(function(a,b,c,e){var f,g=d(a,null,e,[]),h=a.length;while(h--)(f=g[h])&&(a[h]=!(b[h]=f))}):function(a,e,f){return b[0]=a,d(b,null,f,c),b[0]=null,!c.pop()}}),has:ib(function(a){return function(b){return gb(a,b).length>0}}),contains:ib(function(a){return a=a.replace(cb,db),function(b){return(b.textContent||b.innerText||e(b)).indexOf(a)>-1}}),lang:ib(function(a){return W.test(a||"")||gb.error("unsupported lang: "+a),a=a.replace(cb,db).toLowerCase(),function(b){var c;do if(c=p?b.lang:b.getAttribute("xml:lang")||b.getAttribute("lang"))return c=c.toLowerCase(),c===a||0===c.indexOf(a+"-");while((b=b.parentNode)&&1===b.nodeType);return!1}}),target:function(b){var c=a.location&&a.location.hash;return c&&c.slice(1)===b.id},root:function(a){return a===o},focus:function(a){return a===n.activeElement&&(!n.hasFocus||n.hasFocus())&&!!(a.type||a.href||~a.tabIndex)},enabled:function(a){return a.disabled===!1},disabled:function(a){return a.disabled===!0},checked:function(a){var b=a.nodeName.toLowerCase();return"input"===b&&!!a.checked||"option"===b&&!!a.selected},selected:function(a){return a.parentNode&&a.parentNode.selectedIndex,a.selected===!0},empty:function(a){for(a=a.firstChild;a;a=a.nextSibling)if(a.nodeType<6)return!1;return!0},parent:function(a){return!d.pseudos.empty(a)},header:function(a){return Z.test(a.nodeName)},input:function(a){return Y.test(a.nodeName)},button:function(a){var 
b=a.nodeName.toLowerCase();return"input"===b&&"button"===a.type||"button"===b},text:function(a){var b;return"input"===a.nodeName.toLowerCase()&&"text"===a.type&&(null==(b=a.getAttribute("type"))||"text"===b.toLowerCase())},first:ob(function(){return[0]}),last:ob(function(a,b){return[b-1]}),eq:ob(function(a,b,c){return[0>c?c+b:c]}),even:ob(function(a,b){for(var c=0;b>c;c+=2)a.push(c);return a}),odd:ob(function(a,b){for(var c=1;b>c;c+=2)a.push(c);return a}),lt:ob(function(a,b,c){for(var d=0>c?c+b:c;--d>=0;)a.push(d);return a}),gt:ob(function(a,b,c){for(var d=0>c?c+b:c;++db;b++)d+=a[b].value;return d}function sb(a,b,c){var d=b.dir,e=c&&"parentNode"===d,f=x++;return b.first?function(b,c,f){while(b=b[d])if(1===b.nodeType||e)return a(b,c,f)}:function(b,c,g){var h,i,j=[w,f];if(g){while(b=b[d])if((1===b.nodeType||e)&&a(b,c,g))return!0}else while(b=b[d])if(1===b.nodeType||e){if(i=b[u]||(b[u]={}),(h=i[d])&&h[0]===w&&h[1]===f)return j[2]=h[2];if(i[d]=j,j[2]=a(b,c,g))return!0}}}function tb(a){return a.length>1?function(b,c,d){var e=a.length;while(e--)if(!a[e](b,c,d))return!1;return!0}:a[0]}function ub(a,b,c){for(var d=0,e=b.length;e>d;d++)gb(a,b[d],c);return c}function vb(a,b,c,d,e){for(var f,g=[],h=0,i=a.length,j=null!=b;i>h;h++)(f=a[h])&&(!c||c(f,d,e))&&(g.push(f),j&&b.push(h));return g}function wb(a,b,c,d,e,f){return d&&!d[u]&&(d=wb(d)),e&&!e[u]&&(e=wb(e,f)),ib(function(f,g,h,i){var j,k,l,m=[],n=[],o=g.length,p=f||ub(b||"*",h.nodeType?[h]:h,[]),q=!a||!f&&b?p:vb(p,m,a,h,i),r=c?e||(f?a:o||d)?[]:g:q;if(c&&c(q,r,h,i),d){j=vb(r,n),d(j,[],h,i),k=j.length;while(k--)(l=j[k])&&(r[n[k]]=!(q[n[k]]=l))}if(f){if(e||a){if(e){j=[],k=r.length;while(k--)(l=r[k])&&j.push(q[k]=l);e(null,r=[],j,i)}k=r.length;while(k--)(l=r[k])&&(j=e?J(f,l):m[k])>-1&&(f[j]=!(g[j]=l))}}else r=vb(r===g?r.splice(o,r.length):r),e?e(null,g,r,i):H.apply(g,r)})}function xb(a){for(var b,c,e,f=a.length,g=d.relative[a[0].type],h=g||d.relative[" "],i=g?1:0,k=sb(function(a){return a===b},h,!0),l=sb(function(a){return J(b,a)>-1},h,!0),m=[function(a,c,d){var e=!g&&(d||c!==j)||((b=c).nodeType?k(a,c,d):l(a,c,d));return b=null,e}];f>i;i++)if(c=d.relative[a[i].type])m=[sb(tb(m),c)];else{if(c=d.filter[a[i].type].apply(null,a[i].matches),c[u]){for(e=++i;f>e;e++)if(d.relative[a[e].type])break;return wb(i>1&&tb(m),i>1&&rb(a.slice(0,i-1).concat({value:" "===a[i-2].type?"*":""})).replace(R,"$1"),c,e>i&&xb(a.slice(i,e)),f>e&&xb(a=a.slice(e)),f>e&&rb(a))}m.push(c)}return tb(m)}function yb(a,b){var c=b.length>0,e=a.length>0,f=function(f,g,h,i,k){var l,m,o,p=0,q="0",r=f&&[],s=[],t=j,u=f||e&&d.find.TAG("*",k),v=w+=null==t?1:Math.random()||.1,x=u.length;for(k&&(j=g!==n&&g);q!==x&&null!=(l=u[q]);q++){if(e&&l){m=0;while(o=a[m++])if(o(l,g,h)){i.push(l);break}k&&(w=v)}c&&((l=!o&&l)&&p--,f&&r.push(l))}if(p+=q,c&&q!==p){m=0;while(o=b[m++])o(r,s,g,h);if(f){if(p>0)while(q--)r[q]||s[q]||(s[q]=F.call(i));s=vb(s)}H.apply(i,s),k&&!f&&s.length>0&&p+b.length>1&&gb.uniqueSort(i)}return k&&(w=v,j=t),r};return c?ib(f):f}return h=gb.compile=function(a,b){var c,d=[],e=[],f=A[a+" "];if(!f){b||(b=g(a)),c=b.length;while(c--)f=xb(b[c]),f[u]?d.push(f):e.push(f);f=A(a,yb(e,d)),f.selector=a}return f},i=gb.select=function(a,b,e,f){var i,j,k,l,m,n="function"==typeof a&&a,o=!f&&g(a=n.selector||a);if(e=e||[],1===o.length){if(j=o[0]=o[0].slice(0),j.length>2&&"ID"===(k=j[0]).type&&c.getById&&9===b.nodeType&&p&&d.relative[j[1].type]){if(b=(d.find.ID(k.matches[0].replace(cb,db),b)||[])[0],!b)return 
e;n&&(b=b.parentNode),a=a.slice(j.shift().value.length)}i=X.needsContext.test(a)?0:j.length;while(i--){if(k=j[i],d.relative[l=k.type])break;if((m=d.find[l])&&(f=m(k.matches[0].replace(cb,db),ab.test(j[0].type)&&pb(b.parentNode)||b))){if(j.splice(i,1),a=f.length&&rb(j),!a)return H.apply(e,f),e;break}}}return(n||h(a,o))(f,b,!p,e,ab.test(a)&&pb(b.parentNode)||b),e},c.sortStable=u.split("").sort(B).join("")===u,c.detectDuplicates=!!l,m(),c.sortDetached=jb(function(a){return 1&a.compareDocumentPosition(n.createElement("div"))}),jb(function(a){return a.innerHTML="","#"===a.firstChild.getAttribute("href")})||kb("type|href|height|width",function(a,b,c){return c?void 0:a.getAttribute(b,"type"===b.toLowerCase()?1:2)}),c.attributes&&jb(function(a){return a.innerHTML="",a.firstChild.setAttribute("value",""),""===a.firstChild.getAttribute("value")})||kb("value",function(a,b,c){return c||"input"!==a.nodeName.toLowerCase()?void 0:a.defaultValue}),jb(function(a){return null==a.getAttribute("disabled")})||kb(K,function(a,b,c){var d;return c?void 0:a[b]===!0?b.toLowerCase():(d=a.getAttributeNode(b))&&d.specified?d.value:null}),gb}(a);n.find=t,n.expr=t.selectors,n.expr[":"]=n.expr.pseudos,n.unique=t.uniqueSort,n.text=t.getText,n.isXMLDoc=t.isXML,n.contains=t.contains;var u=n.expr.match.needsContext,v=/^<(\w+)\s*\/?>(?:<\/\1>|)$/,w=/^.[^:#\[\.,]*$/;function x(a,b,c){if(n.isFunction(b))return n.grep(a,function(a,d){return!!b.call(a,d,a)!==c});if(b.nodeType)return n.grep(a,function(a){return a===b!==c});if("string"==typeof b){if(w.test(b))return n.filter(b,a,c);b=n.filter(b,a)}return n.grep(a,function(a){return g.call(b,a)>=0!==c})}n.filter=function(a,b,c){var d=b[0];return c&&(a=":not("+a+")"),1===b.length&&1===d.nodeType?n.find.matchesSelector(d,a)?[d]:[]:n.find.matches(a,n.grep(b,function(a){return 1===a.nodeType}))},n.fn.extend({find:function(a){var b,c=this.length,d=[],e=this;if("string"!=typeof a)return this.pushStack(n(a).filter(function(){for(b=0;c>b;b++)if(n.contains(e[b],this))return!0}));for(b=0;c>b;b++)n.find(a,e[b],d);return d=this.pushStack(c>1?n.unique(d):d),d.selector=this.selector?this.selector+" "+a:a,d},filter:function(a){return this.pushStack(x(this,a||[],!1))},not:function(a){return this.pushStack(x(this,a||[],!0))},is:function(a){return!!x(this,"string"==typeof a&&u.test(a)?n(a):a||[],!1).length}});var y,z=/^(?:\s*(<[\w\W]+>)[^>]*|#([\w-]*))$/,A=n.fn.init=function(a,b){var c,d;if(!a)return this;if("string"==typeof a){if(c="<"===a[0]&&">"===a[a.length-1]&&a.length>=3?[null,a,null]:z.exec(a),!c||!c[1]&&b)return!b||b.jquery?(b||y).find(a):this.constructor(b).find(a);if(c[1]){if(b=b instanceof n?b[0]:b,n.merge(this,n.parseHTML(c[1],b&&b.nodeType?b.ownerDocument||b:l,!0)),v.test(c[1])&&n.isPlainObject(b))for(c in b)n.isFunction(this[c])?this[c](b[c]):this.attr(c,b[c]);return this}return d=l.getElementById(c[2]),d&&d.parentNode&&(this.length=1,this[0]=d),this.context=l,this.selector=a,this}return a.nodeType?(this.context=this[0]=a,this.length=1,this):n.isFunction(a)?"undefined"!=typeof y.ready?y.ready(a):a(n):(void 0!==a.selector&&(this.selector=a.selector,this.context=a.context),n.makeArray(a,this))};A.prototype=n.fn,y=n(l);var B=/^(?:parents|prev(?:Until|All))/,C={children:!0,contents:!0,next:!0,prev:!0};n.extend({dir:function(a,b,c){var d=[],e=void 0!==c;while((a=a[b])&&9!==a.nodeType)if(1===a.nodeType){if(e&&n(a).is(c))break;d.push(a)}return d},sibling:function(a,b){for(var c=[];a;a=a.nextSibling)1===a.nodeType&&a!==b&&c.push(a);return c}}),n.fn.extend({has:function(a){var 
b=n(a,this),c=b.length;return this.filter(function(){for(var a=0;c>a;a++)if(n.contains(this,b[a]))return!0})},closest:function(a,b){for(var c,d=0,e=this.length,f=[],g=u.test(a)||"string"!=typeof a?n(a,b||this.context):0;e>d;d++)for(c=this[d];c&&c!==b;c=c.parentNode)if(c.nodeType<11&&(g?g.index(c)>-1:1===c.nodeType&&n.find.matchesSelector(c,a))){f.push(c);break}return this.pushStack(f.length>1?n.unique(f):f)},index:function(a){return a?"string"==typeof a?g.call(n(a),this[0]):g.call(this,a.jquery?a[0]:a):this[0]&&this[0].parentNode?this.first().prevAll().length:-1},add:function(a,b){return this.pushStack(n.unique(n.merge(this.get(),n(a,b))))},addBack:function(a){return this.add(null==a?this.prevObject:this.prevObject.filter(a))}});function D(a,b){while((a=a[b])&&1!==a.nodeType);return a}n.each({parent:function(a){var b=a.parentNode;return b&&11!==b.nodeType?b:null},parents:function(a){return n.dir(a,"parentNode")},parentsUntil:function(a,b,c){return n.dir(a,"parentNode",c)},next:function(a){return D(a,"nextSibling")},prev:function(a){return D(a,"previousSibling")},nextAll:function(a){return n.dir(a,"nextSibling")},prevAll:function(a){return n.dir(a,"previousSibling")},nextUntil:function(a,b,c){return n.dir(a,"nextSibling",c)},prevUntil:function(a,b,c){return n.dir(a,"previousSibling",c)},siblings:function(a){return n.sibling((a.parentNode||{}).firstChild,a)},children:function(a){return n.sibling(a.firstChild)},contents:function(a){return a.contentDocument||n.merge([],a.childNodes)}},function(a,b){n.fn[a]=function(c,d){var e=n.map(this,b,c);return"Until"!==a.slice(-5)&&(d=c),d&&"string"==typeof d&&(e=n.filter(d,e)),this.length>1&&(C[a]||n.unique(e),B.test(a)&&e.reverse()),this.pushStack(e)}});var E=/\S+/g,F={};function G(a){var b=F[a]={};return n.each(a.match(E)||[],function(a,c){b[c]=!0}),b}n.Callbacks=function(a){a="string"==typeof a?F[a]||G(a):n.extend({},a);var b,c,d,e,f,g,h=[],i=!a.once&&[],j=function(l){for(b=a.memory&&l,c=!0,g=e||0,e=0,f=h.length,d=!0;h&&f>g;g++)if(h[g].apply(l[0],l[1])===!1&&a.stopOnFalse){b=!1;break}d=!1,h&&(i?i.length&&j(i.shift()):b?h=[]:k.disable())},k={add:function(){if(h){var c=h.length;!function g(b){n.each(b,function(b,c){var d=n.type(c);"function"===d?a.unique&&k.has(c)||h.push(c):c&&c.length&&"string"!==d&&g(c)})}(arguments),d?f=h.length:b&&(e=c,j(b))}return this},remove:function(){return h&&n.each(arguments,function(a,b){var c;while((c=n.inArray(b,h,c))>-1)h.splice(c,1),d&&(f>=c&&f--,g>=c&&g--)}),this},has:function(a){return a?n.inArray(a,h)>-1:!(!h||!h.length)},empty:function(){return h=[],f=0,this},disable:function(){return h=i=b=void 0,this},disabled:function(){return!h},lock:function(){return i=void 0,b||k.disable(),this},locked:function(){return!i},fireWith:function(a,b){return!h||c&&!i||(b=b||[],b=[a,b.slice?b.slice():b],d?i.push(b):j(b)),this},fire:function(){return k.fireWith(this,arguments),this},fired:function(){return!!c}};return k},n.extend({Deferred:function(a){var b=[["resolve","done",n.Callbacks("once memory"),"resolved"],["reject","fail",n.Callbacks("once memory"),"rejected"],["notify","progress",n.Callbacks("memory")]],c="pending",d={state:function(){return c},always:function(){return e.done(arguments).fail(arguments),this},then:function(){var a=arguments;return n.Deferred(function(c){n.each(b,function(b,f){var g=n.isFunction(a[b])&&a[b];e[f[1]](function(){var 
a=g&&g.apply(this,arguments);a&&n.isFunction(a.promise)?a.promise().done(c.resolve).fail(c.reject).progress(c.notify):c[f[0]+"With"](this===d?c.promise():this,g?[a]:arguments)})}),a=null}).promise()},promise:function(a){return null!=a?n.extend(a,d):d}},e={};return d.pipe=d.then,n.each(b,function(a,f){var g=f[2],h=f[3];d[f[1]]=g.add,h&&g.add(function(){c=h},b[1^a][2].disable,b[2][2].lock),e[f[0]]=function(){return e[f[0]+"With"](this===e?d:this,arguments),this},e[f[0]+"With"]=g.fireWith}),d.promise(e),a&&a.call(e,e),e},when:function(a){var b=0,c=d.call(arguments),e=c.length,f=1!==e||a&&n.isFunction(a.promise)?e:0,g=1===f?a:n.Deferred(),h=function(a,b,c){return function(e){b[a]=this,c[a]=arguments.length>1?d.call(arguments):e,c===i?g.notifyWith(b,c):--f||g.resolveWith(b,c)}},i,j,k;if(e>1)for(i=new Array(e),j=new Array(e),k=new Array(e);e>b;b++)c[b]&&n.isFunction(c[b].promise)?c[b].promise().done(h(b,k,c)).fail(g.reject).progress(h(b,j,i)):--f;return f||g.resolveWith(k,c),g.promise()}});var H;n.fn.ready=function(a){return n.ready.promise().done(a),this},n.extend({isReady:!1,readyWait:1,holdReady:function(a){a?n.readyWait++:n.ready(!0)},ready:function(a){(a===!0?--n.readyWait:n.isReady)||(n.isReady=!0,a!==!0&&--n.readyWait>0||(H.resolveWith(l,[n]),n.fn.triggerHandler&&(n(l).triggerHandler("ready"),n(l).off("ready"))))}});function I(){l.removeEventListener("DOMContentLoaded",I,!1),a.removeEventListener("load",I,!1),n.ready()}n.ready.promise=function(b){return H||(H=n.Deferred(),"complete"===l.readyState?setTimeout(n.ready):(l.addEventListener("DOMContentLoaded",I,!1),a.addEventListener("load",I,!1))),H.promise(b)},n.ready.promise();var J=n.access=function(a,b,c,d,e,f,g){var h=0,i=a.length,j=null==c;if("object"===n.type(c)){e=!0;for(h in c)n.access(a,b,h,c[h],!0,f,g)}else if(void 0!==d&&(e=!0,n.isFunction(d)||(g=!0),j&&(g?(b.call(a,d),b=null):(j=b,b=function(a,b,c){return j.call(n(a),c)})),b))for(;i>h;h++)b(a[h],c,g?d:d.call(a[h],h,b(a[h],c)));return e?a:j?b.call(a):i?b(a[0],c):f};n.acceptData=function(a){return 1===a.nodeType||9===a.nodeType||!+a.nodeType};function K(){Object.defineProperty(this.cache={},0,{get:function(){return{}}}),this.expando=n.expando+K.uid++}K.uid=1,K.accepts=n.acceptData,K.prototype={key:function(a){if(!K.accepts(a))return 0;var b={},c=a[this.expando];if(!c){c=K.uid++;try{b[this.expando]={value:c},Object.defineProperties(a,b)}catch(d){b[this.expando]=c,n.extend(a,b)}}return this.cache[c]||(this.cache[c]={}),c},set:function(a,b,c){var d,e=this.key(a),f=this.cache[e];if("string"==typeof b)f[b]=c;else if(n.isEmptyObject(f))n.extend(this.cache[e],b);else for(d in b)f[d]=b[d];return f},get:function(a,b){var c=this.cache[this.key(a)];return void 0===b?c:c[b]},access:function(a,b,c){var d;return void 0===b||b&&"string"==typeof b&&void 0===c?(d=this.get(a,b),void 0!==d?d:this.get(a,n.camelCase(b))):(this.set(a,b,c),void 0!==c?c:b)},remove:function(a,b){var c,d,e,f=this.key(a),g=this.cache[f];if(void 0===b)this.cache[f]={};else{n.isArray(b)?d=b.concat(b.map(n.camelCase)):(e=n.camelCase(b),b in g?d=[b,e]:(d=e,d=d in g?[d]:d.match(E)||[])),c=d.length;while(c--)delete g[d[c]]}},hasData:function(a){return!n.isEmptyObject(this.cache[a[this.expando]]||{})},discard:function(a){a[this.expando]&&delete this.cache[a[this.expando]]}};var L=new K,M=new K,N=/^(?:\{[\w\W]*\}|\[[\w\W]*\])$/,O=/([A-Z])/g;function P(a,b,c){var d;if(void 0===c&&1===a.nodeType)if(d="data-"+b.replace(O,"-$1").toLowerCase(),c=a.getAttribute(d),"string"==typeof 
c){try{c="true"===c?!0:"false"===c?!1:"null"===c?null:+c+""===c?+c:N.test(c)?n.parseJSON(c):c}catch(e){}M.set(a,b,c)}else c=void 0;return c}n.extend({hasData:function(a){return M.hasData(a)||L.hasData(a)},data:function(a,b,c){return M.access(a,b,c) -},removeData:function(a,b){M.remove(a,b)},_data:function(a,b,c){return L.access(a,b,c)},_removeData:function(a,b){L.remove(a,b)}}),n.fn.extend({data:function(a,b){var c,d,e,f=this[0],g=f&&f.attributes;if(void 0===a){if(this.length&&(e=M.get(f),1===f.nodeType&&!L.get(f,"hasDataAttrs"))){c=g.length;while(c--)g[c]&&(d=g[c].name,0===d.indexOf("data-")&&(d=n.camelCase(d.slice(5)),P(f,d,e[d])));L.set(f,"hasDataAttrs",!0)}return e}return"object"==typeof a?this.each(function(){M.set(this,a)}):J(this,function(b){var c,d=n.camelCase(a);if(f&&void 0===b){if(c=M.get(f,a),void 0!==c)return c;if(c=M.get(f,d),void 0!==c)return c;if(c=P(f,d,void 0),void 0!==c)return c}else this.each(function(){var c=M.get(this,d);M.set(this,d,b),-1!==a.indexOf("-")&&void 0!==c&&M.set(this,a,b)})},null,b,arguments.length>1,null,!0)},removeData:function(a){return this.each(function(){M.remove(this,a)})}}),n.extend({queue:function(a,b,c){var d;return a?(b=(b||"fx")+"queue",d=L.get(a,b),c&&(!d||n.isArray(c)?d=L.access(a,b,n.makeArray(c)):d.push(c)),d||[]):void 0},dequeue:function(a,b){b=b||"fx";var c=n.queue(a,b),d=c.length,e=c.shift(),f=n._queueHooks(a,b),g=function(){n.dequeue(a,b)};"inprogress"===e&&(e=c.shift(),d--),e&&("fx"===b&&c.unshift("inprogress"),delete f.stop,e.call(a,g,f)),!d&&f&&f.empty.fire()},_queueHooks:function(a,b){var c=b+"queueHooks";return L.get(a,c)||L.access(a,c,{empty:n.Callbacks("once memory").add(function(){L.remove(a,[b+"queue",c])})})}}),n.fn.extend({queue:function(a,b){var c=2;return"string"!=typeof a&&(b=a,a="fx",c--),arguments.lengthx",k.noCloneChecked=!!b.cloneNode(!0).lastChild.defaultValue}();var U="undefined";k.focusinBubbles="onfocusin"in a;var V=/^key/,W=/^(?:mouse|pointer|contextmenu)|click/,X=/^(?:focusinfocus|focusoutblur)$/,Y=/^([^.]*)(?:\.(.+)|)$/;function Z(){return!0}function $(){return!1}function _(){try{return l.activeElement}catch(a){}}n.event={global:{},add:function(a,b,c,d,e){var f,g,h,i,j,k,l,m,o,p,q,r=L.get(a);if(r){c.handler&&(f=c,c=f.handler,e=f.selector),c.guid||(c.guid=n.guid++),(i=r.events)||(i=r.events={}),(g=r.handle)||(g=r.handle=function(b){return typeof n!==U&&n.event.triggered!==b.type?n.event.dispatch.apply(a,arguments):void 0}),b=(b||"").match(E)||[""],j=b.length;while(j--)h=Y.exec(b[j])||[],o=q=h[1],p=(h[2]||"").split(".").sort(),o&&(l=n.event.special[o]||{},o=(e?l.delegateType:l.bindType)||o,l=n.event.special[o]||{},k=n.extend({type:o,origType:q,data:d,handler:c,guid:c.guid,selector:e,needsContext:e&&n.expr.match.needsContext.test(e),namespace:p.join(".")},f),(m=i[o])||(m=i[o]=[],m.delegateCount=0,l.setup&&l.setup.call(a,d,p,g)!==!1||a.addEventListener&&a.addEventListener(o,g,!1)),l.add&&(l.add.call(a,k),k.handler.guid||(k.handler.guid=c.guid)),e?m.splice(m.delegateCount++,0,k):m.push(k),n.event.global[o]=!0)}},remove:function(a,b,c,d,e){var f,g,h,i,j,k,l,m,o,p,q,r=L.hasData(a)&&L.get(a);if(r&&(i=r.events)){b=(b||"").match(E)||[""],j=b.length;while(j--)if(h=Y.exec(b[j])||[],o=q=h[1],p=(h[2]||"").split(".").sort(),o){l=n.event.special[o]||{},o=(d?l.delegateType:l.bindType)||o,m=i[o]||[],h=h[2]&&new 
RegExp("(^|\\.)"+p.join("\\.(?:.*\\.|)")+"(\\.|$)"),g=f=m.length;while(f--)k=m[f],!e&&q!==k.origType||c&&c.guid!==k.guid||h&&!h.test(k.namespace)||d&&d!==k.selector&&("**"!==d||!k.selector)||(m.splice(f,1),k.selector&&m.delegateCount--,l.remove&&l.remove.call(a,k));g&&!m.length&&(l.teardown&&l.teardown.call(a,p,r.handle)!==!1||n.removeEvent(a,o,r.handle),delete i[o])}else for(o in i)n.event.remove(a,o+b[j],c,d,!0);n.isEmptyObject(i)&&(delete r.handle,L.remove(a,"events"))}},trigger:function(b,c,d,e){var f,g,h,i,k,m,o,p=[d||l],q=j.call(b,"type")?b.type:b,r=j.call(b,"namespace")?b.namespace.split("."):[];if(g=h=d=d||l,3!==d.nodeType&&8!==d.nodeType&&!X.test(q+n.event.triggered)&&(q.indexOf(".")>=0&&(r=q.split("."),q=r.shift(),r.sort()),k=q.indexOf(":")<0&&"on"+q,b=b[n.expando]?b:new n.Event(q,"object"==typeof b&&b),b.isTrigger=e?2:3,b.namespace=r.join("."),b.namespace_re=b.namespace?new RegExp("(^|\\.)"+r.join("\\.(?:.*\\.|)")+"(\\.|$)"):null,b.result=void 0,b.target||(b.target=d),c=null==c?[b]:n.makeArray(c,[b]),o=n.event.special[q]||{},e||!o.trigger||o.trigger.apply(d,c)!==!1)){if(!e&&!o.noBubble&&!n.isWindow(d)){for(i=o.delegateType||q,X.test(i+q)||(g=g.parentNode);g;g=g.parentNode)p.push(g),h=g;h===(d.ownerDocument||l)&&p.push(h.defaultView||h.parentWindow||a)}f=0;while((g=p[f++])&&!b.isPropagationStopped())b.type=f>1?i:o.bindType||q,m=(L.get(g,"events")||{})[b.type]&&L.get(g,"handle"),m&&m.apply(g,c),m=k&&g[k],m&&m.apply&&n.acceptData(g)&&(b.result=m.apply(g,c),b.result===!1&&b.preventDefault());return b.type=q,e||b.isDefaultPrevented()||o._default&&o._default.apply(p.pop(),c)!==!1||!n.acceptData(d)||k&&n.isFunction(d[q])&&!n.isWindow(d)&&(h=d[k],h&&(d[k]=null),n.event.triggered=q,d[q](),n.event.triggered=void 0,h&&(d[k]=h)),b.result}},dispatch:function(a){a=n.event.fix(a);var b,c,e,f,g,h=[],i=d.call(arguments),j=(L.get(this,"events")||{})[a.type]||[],k=n.event.special[a.type]||{};if(i[0]=a,a.delegateTarget=this,!k.preDispatch||k.preDispatch.call(this,a)!==!1){h=n.event.handlers.call(this,a,j),b=0;while((f=h[b++])&&!a.isPropagationStopped()){a.currentTarget=f.elem,c=0;while((g=f.handlers[c++])&&!a.isImmediatePropagationStopped())(!a.namespace_re||a.namespace_re.test(g.namespace))&&(a.handleObj=g,a.data=g.data,e=((n.event.special[g.origType]||{}).handle||g.handler).apply(f.elem,i),void 0!==e&&(a.result=e)===!1&&(a.preventDefault(),a.stopPropagation()))}return k.postDispatch&&k.postDispatch.call(this,a),a.result}},handlers:function(a,b){var c,d,e,f,g=[],h=b.delegateCount,i=a.target;if(h&&i.nodeType&&(!a.button||"click"!==a.type))for(;i!==this;i=i.parentNode||this)if(i.disabled!==!0||"click"!==a.type){for(d=[],c=0;h>c;c++)f=b[c],e=f.selector+" ",void 0===d[e]&&(d[e]=f.needsContext?n(e,this).index(i)>=0:n.find(e,this,null,[i]).length),d[e]&&d.push(f);d.length&&g.push({elem:i,handlers:d})}return h]*)\/>/gi,bb=/<([\w:]+)/,cb=/<|&#?\w+;/,db=/<(?:script|style|link)/i,eb=/checked\s*(?:[^=]|=\s*.checked.)/i,fb=/^$|\/(?:java|ecma)script/i,gb=/^true\/(.*)/,hb=/^\s*\s*$/g,ib={option:[1,""],thead:[1,"","
"],col:[2,"","
"],tr:[2,"","
"],td:[3,"","
"],_default:[0,"",""]};ib.optgroup=ib.option,ib.tbody=ib.tfoot=ib.colgroup=ib.caption=ib.thead,ib.th=ib.td;function jb(a,b){return n.nodeName(a,"table")&&n.nodeName(11!==b.nodeType?b:b.firstChild,"tr")?a.getElementsByTagName("tbody")[0]||a.appendChild(a.ownerDocument.createElement("tbody")):a}function kb(a){return a.type=(null!==a.getAttribute("type"))+"/"+a.type,a}function lb(a){var b=gb.exec(a.type);return b?a.type=b[1]:a.removeAttribute("type"),a}function mb(a,b){for(var c=0,d=a.length;d>c;c++)L.set(a[c],"globalEval",!b||L.get(b[c],"globalEval"))}function nb(a,b){var c,d,e,f,g,h,i,j;if(1===b.nodeType){if(L.hasData(a)&&(f=L.access(a),g=L.set(b,f),j=f.events)){delete g.handle,g.events={};for(e in j)for(c=0,d=j[e].length;d>c;c++)n.event.add(b,e,j[e][c])}M.hasData(a)&&(h=M.access(a),i=n.extend({},h),M.set(b,i))}}function ob(a,b){var c=a.getElementsByTagName?a.getElementsByTagName(b||"*"):a.querySelectorAll?a.querySelectorAll(b||"*"):[];return void 0===b||b&&n.nodeName(a,b)?n.merge([a],c):c}function pb(a,b){var c=b.nodeName.toLowerCase();"input"===c&&T.test(a.type)?b.checked=a.checked:("input"===c||"textarea"===c)&&(b.defaultValue=a.defaultValue)}n.extend({clone:function(a,b,c){var d,e,f,g,h=a.cloneNode(!0),i=n.contains(a.ownerDocument,a);if(!(k.noCloneChecked||1!==a.nodeType&&11!==a.nodeType||n.isXMLDoc(a)))for(g=ob(h),f=ob(a),d=0,e=f.length;e>d;d++)pb(f[d],g[d]);if(b)if(c)for(f=f||ob(a),g=g||ob(h),d=0,e=f.length;e>d;d++)nb(f[d],g[d]);else nb(a,h);return g=ob(h,"script"),g.length>0&&mb(g,!i&&ob(a,"script")),h},buildFragment:function(a,b,c,d){for(var e,f,g,h,i,j,k=b.createDocumentFragment(),l=[],m=0,o=a.length;o>m;m++)if(e=a[m],e||0===e)if("object"===n.type(e))n.merge(l,e.nodeType?[e]:e);else if(cb.test(e)){f=f||k.appendChild(b.createElement("div")),g=(bb.exec(e)||["",""])[1].toLowerCase(),h=ib[g]||ib._default,f.innerHTML=h[1]+e.replace(ab,"<$1>")+h[2],j=h[0];while(j--)f=f.lastChild;n.merge(l,f.childNodes),f=k.firstChild,f.textContent=""}else l.push(b.createTextNode(e));k.textContent="",m=0;while(e=l[m++])if((!d||-1===n.inArray(e,d))&&(i=n.contains(e.ownerDocument,e),f=ob(k.appendChild(e),"script"),i&&mb(f),c)){j=0;while(e=f[j++])fb.test(e.type||"")&&c.push(e)}return k},cleanData:function(a){for(var b,c,d,e,f=n.event.special,g=0;void 0!==(c=a[g]);g++){if(n.acceptData(c)&&(e=c[L.expando],e&&(b=L.cache[e]))){if(b.events)for(d in b.events)f[d]?n.event.remove(c,d):n.removeEvent(c,d,b.handle);L.cache[e]&&delete L.cache[e]}delete M.cache[c[M.expando]]}}}),n.fn.extend({text:function(a){return J(this,function(a){return void 0===a?n.text(this):this.empty().each(function(){(1===this.nodeType||11===this.nodeType||9===this.nodeType)&&(this.textContent=a)})},null,a,arguments.length)},append:function(){return this.domManip(arguments,function(a){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var b=jb(this,a);b.appendChild(a)}})},prepend:function(){return this.domManip(arguments,function(a){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var b=jb(this,a);b.insertBefore(a,b.firstChild)}})},before:function(){return this.domManip(arguments,function(a){this.parentNode&&this.parentNode.insertBefore(a,this)})},after:function(){return this.domManip(arguments,function(a){this.parentNode&&this.parentNode.insertBefore(a,this.nextSibling)})},remove:function(a,b){for(var c,d=a?n.filter(a,this):this,e=0;null!=(c=d[e]);e++)b||1!==c.nodeType||n.cleanData(ob(c)),c.parentNode&&(b&&n.contains(c.ownerDocument,c)&&mb(ob(c,"script")),c.parentNode.removeChild(c));return this},empty:function(){for(var 
a,b=0;null!=(a=this[b]);b++)1===a.nodeType&&(n.cleanData(ob(a,!1)),a.textContent="");return this},clone:function(a,b){return a=null==a?!1:a,b=null==b?a:b,this.map(function(){return n.clone(this,a,b)})},html:function(a){return J(this,function(a){var b=this[0]||{},c=0,d=this.length;if(void 0===a&&1===b.nodeType)return b.innerHTML;if("string"==typeof a&&!db.test(a)&&!ib[(bb.exec(a)||["",""])[1].toLowerCase()]){a=a.replace(ab,"<$1>");try{for(;d>c;c++)b=this[c]||{},1===b.nodeType&&(n.cleanData(ob(b,!1)),b.innerHTML=a);b=0}catch(e){}}b&&this.empty().append(a)},null,a,arguments.length)},replaceWith:function(){var a=arguments[0];return this.domManip(arguments,function(b){a=this.parentNode,n.cleanData(ob(this)),a&&a.replaceChild(b,this)}),a&&(a.length||a.nodeType)?this:this.remove()},detach:function(a){return this.remove(a,!0)},domManip:function(a,b){a=e.apply([],a);var c,d,f,g,h,i,j=0,l=this.length,m=this,o=l-1,p=a[0],q=n.isFunction(p);if(q||l>1&&"string"==typeof p&&!k.checkClone&&eb.test(p))return this.each(function(c){var d=m.eq(c);q&&(a[0]=p.call(this,c,d.html())),d.domManip(a,b)});if(l&&(c=n.buildFragment(a,this[0].ownerDocument,!1,this),d=c.firstChild,1===c.childNodes.length&&(c=d),d)){for(f=n.map(ob(c,"script"),kb),g=f.length;l>j;j++)h=c,j!==o&&(h=n.clone(h,!0,!0),g&&n.merge(f,ob(h,"script"))),b.call(this[j],h,j);if(g)for(i=f[f.length-1].ownerDocument,n.map(f,lb),j=0;g>j;j++)h=f[j],fb.test(h.type||"")&&!L.access(h,"globalEval")&&n.contains(i,h)&&(h.src?n._evalUrl&&n._evalUrl(h.src):n.globalEval(h.textContent.replace(hb,"")))}return this}}),n.each({appendTo:"append",prependTo:"prepend",insertBefore:"before",insertAfter:"after",replaceAll:"replaceWith"},function(a,b){n.fn[a]=function(a){for(var c,d=[],e=n(a),g=e.length-1,h=0;g>=h;h++)c=h===g?this:this.clone(!0),n(e[h])[b](c),f.apply(d,c.get());return this.pushStack(d)}});var qb,rb={};function sb(b,c){var d,e=n(c.createElement(b)).appendTo(c.body),f=a.getDefaultComputedStyle&&(d=a.getDefaultComputedStyle(e[0]))?d.display:n.css(e[0],"display");return e.detach(),f}function tb(a){var b=l,c=rb[a];return c||(c=sb(a,b),"none"!==c&&c||(qb=(qb||n("