From 8d3542a64e2689a00ed87f9bd58fe3e1d3b10ed8 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 22 May 2019 16:42:00 -0400 Subject: implement federation parts of cross-signing --- synapse/federation/sender/per_destination_queue.py | 4 +- synapse/handlers/device.py | 13 ++- synapse/handlers/e2e_keys.py | 116 ++++++++++++++++++++- synapse/storage/devices.py | 56 +++++++++- 4 files changed, 179 insertions(+), 10 deletions(-) (limited to 'synapse') diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index fad980b893..0486af2dbf 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -366,10 +366,10 @@ class PerDestinationQueue(object): Edu( origin=self._server_name, destination=self._destination, - edu_type="m.device_list_update", + edu_type=edu_type, content=content, ) - for content in results + for (edu_type, content) in results ] assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5f23ee4488..cd6eb52316 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -458,7 +458,18 @@ class DeviceHandler(DeviceWorkerHandler): @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - return {"user_id": user_id, "stream_id": stream_id, "devices": devices} + master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = yield self.store.get_e2e_cross_signing_key( + user_id, "self_signing" + ) + + return { + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + "master_key": master_key, + "self_signing_key": self_signing_key + } @defer.inlineCallbacks def user_left_room(self, user, room_id): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5ea54f60be..849ee04f93 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -36,6 +36,8 @@ from synapse.types import ( get_verify_key_from_cross_signing_key, ) from synapse.util import unwrapFirstError +from synapse.util.async_helpers import Linearizer +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -49,10 +51,17 @@ class E2eKeysHandler(object): self.is_mine = hs.is_mine self.clock = hs.get_clock() + self._edu_updater = SigningKeyEduUpdater(hs, self) + + federation_registry = hs.get_federation_registry() + + federation_registry.register_edu_handler( + "m.signing_key_update", self._edu_updater.incoming_signing_key_update, + ) # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. - hs.get_federation_registry().register_query_handler( + federation_registry.register_query_handler( "client_keys", self.on_federation_query_client_keys ) @@ -343,7 +352,15 @@ class E2eKeysHandler(object): """ device_keys_query = query_body.get("device_keys", {}) res = yield self.query_local_devices(device_keys_query) - return {"device_keys": res} + ret = {"device_keys": res} + + # add in the cross-signing keys + cross_signing_keys = yield self.query_cross_signing_keys(device_keys_query) + + for key, value in iteritems(cross_signing_keys): + ret[key + "_keys"] = value + + return ret @trace @defer.inlineCallbacks @@ -1047,3 +1064,98 @@ class SignatureListItem: target_user_id = attr.ib() target_device_id = attr.ib() signature = attr.ib() + + +class SigningKeyEduUpdater(object): + "Handles incoming signing key updates from federation and updates the DB" + + def __init__(self, hs, e2e_keys_handler): + self.store = hs.get_datastore() + self.federation = hs.get_federation_client() + self.clock = hs.get_clock() + self.e2e_keys_handler = e2e_keys_handler + + self._remote_edu_linearizer = Linearizer(name="remote_signing_key") + + # user_id -> list of updates waiting to be handled. + self._pending_updates = {} + + # Recently seen stream ids. We don't bother keeping these in the DB, + # but they're useful to have them about to reduce the number of spurious + # resyncs. + self._seen_updates = ExpiringCache( + cache_name="signing_key_update_edu", + clock=self.clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + iterable=True, + ) + + @defer.inlineCallbacks + def incoming_signing_key_update(self, origin, edu_content): + """Called on incoming signing key update from federation. Responsible for + parsing the EDU and adding to pending updates list. + + Args: + origin (string): the server that sent the EDU + edu_content (dict): the contents of the EDU + """ + + user_id = edu_content.pop("user_id") + master_key = edu_content.pop("master_key", None) + self_signing_key = edu_content.pop("self_signing_key", None) + + if get_domain_from_id(user_id) != origin: + # TODO: Raise? + logger.warning("Got signing key update edu for %r from %r", user_id, origin) + return + + room_ids = yield self.store.get_rooms_for_user(user_id) + if not room_ids: + # We don't share any rooms with this user. Ignore update, as we + # probably won't get any further updates. + return + + self._pending_updates.setdefault(user_id, []).append( + (master_key, self_signing_key, edu_content) + ) + + yield self._handle_signing_key_updates(user_id) + + @defer.inlineCallbacks + def _handle_signing_key_updates(self, user_id): + """Actually handle pending updates. + + Args: + user_id (string): the user whose updates we are processing + """ + + device_handler = self.e2e_keys_handler.device_handler + + with (yield self._remote_edu_linearizer.queue(user_id)): + pending_updates = self._pending_updates.pop(user_id, []) + if not pending_updates: + # This can happen since we batch updates + return + + device_ids = [] + + logger.info("pending updates: %r", pending_updates) + + for master_key, self_signing_key, edu_content in pending_updates: + if master_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "master", master_key + ) + device_id = \ + get_verify_key_from_cross_signing_key(master_key)[1].version + device_ids.append(device_id) + if self_signing_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + device_id = \ + get_verify_key_from_cross_signing_key(self_signing_key)[1].version + device_ids.append(device_id) + + yield device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index f7a3542348..182e95fa21 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -94,9 +94,10 @@ class DeviceWorkerStore(SQLBaseStore): """Get stream of updates to send to remote servers Returns: - Deferred[tuple[int, list[dict]]]: + Deferred[tuple[int, list[tuple[string,dict]]]]: current stream id (ie, the stream id of the last update included in the - response), and the list of updates + response), and the list of updates, where each update is a pair of EDU + type and EDU contents """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -129,6 +130,25 @@ class DeviceWorkerStore(SQLBaseStore): if not updates: return now_stream_id, [] + # get the cross-signing keys of the users the list + users = set(r[0] for r in updates) + master_key_by_user = {} + self_signing_key_by_user = {} + for user in users: + cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + key_id, verify_key = get_verify_key_from_cross_signing_key(cross_signing_key) + master_key_by_user[user] = { + "key_info": cross_signing_key, + "pubkey": verify_key.version + } + + cross_signing_key = yield self.get_e2e_cross_signing_key(user, "self_signing") + key_id, verify_key = get_verify_key_from_cross_signing_key(cross_signing_key) + self_signing_key_by_user[user] = { + "key_info": cross_signing_key, + "pubkey": verify_key.version + } + # if we have exceeded the limit, we need to exclude any results with the # same stream_id as the last row. if len(updates) > limit: @@ -158,6 +178,10 @@ class DeviceWorkerStore(SQLBaseStore): # Stop processing updates break + if update[1] == master_key_by_user[update[0]]["pubkey"] or \ + update[1] == self_signing_key_by_user[update[0]]["pubkey"]: + continue + key = (update[0], update[1]) update_context = update[3] @@ -172,16 +196,37 @@ class DeviceWorkerStore(SQLBaseStore): # means that there are more than limit updates all of which have the same # steam_id. + # figure out which cross-signing keys were changed by intersecting the + # update list with the master/self-signing key by user maps + cross_signing_keys_by_user = {} + for user_id, device_id, stream in updates: + if device_id == master_key_by_user[user_id]["pubkey"]: + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["master_key"] = \ + master_key_by_user[user_id]["key_info"] + elif device_id == self_signing_key_by_user[user_id]["pubkey"]: + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["self_signing_key"] = \ + self_signing_key_by_user[user_id]["key_info"] + + cross_signing_results = [] + + # add the updated cross-signing keys to the results list + for user_id, result in iteritems(cross_signing_keys_by_user): + result["user_id"] = user_id + cross_signing_results.append(("m.signing_key_update", result)) + # That should only happen if a client is spamming the server with new # devices, in which case E2E isn't going to work well anyway. We'll just # skip that stream_id and return an empty list, and continue with the next # stream_id next time. - if not query_map: + if not query_map and not cross_signing_results: return stream_id_cutoff, [] results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) + results.extend(cross_signing_results) return now_stream_id, results @@ -200,6 +245,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: List: List of device updates """ + # get the list of device updates that need to be sent sql = """ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? @@ -231,7 +277,7 @@ class DeviceWorkerStore(SQLBaseStore): query_map.keys(), include_all_devices=True, include_deleted_devices=True, - ) + ) if query_map else {} results = [] for user_id, user_devices in iteritems(devices): @@ -262,7 +308,7 @@ class DeviceWorkerStore(SQLBaseStore): else: result["deleted"] = True - results.append(result) + results.append(("m.device_list_update", result)) return results -- cgit 1.4.1 From a1aaf3eea6fabb3d61a28e3f75afdb33b41304ee Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 22 May 2019 21:24:21 -0400 Subject: don't crash if the user doesn't have cross-signing keys --- synapse/storage/devices.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 182e95fa21..f46978c9c3 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -136,18 +136,24 @@ class DeviceWorkerStore(SQLBaseStore): self_signing_key_by_user = {} for user in users: cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") - key_id, verify_key = get_verify_key_from_cross_signing_key(cross_signing_key) - master_key_by_user[user] = { - "key_info": cross_signing_key, - "pubkey": verify_key.version - } + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + master_key_by_user[user] = { + "key_info": cross_signing_key, + "pubkey": verify_key.version + } cross_signing_key = yield self.get_e2e_cross_signing_key(user, "self_signing") - key_id, verify_key = get_verify_key_from_cross_signing_key(cross_signing_key) - self_signing_key_by_user[user] = { - "key_info": cross_signing_key, - "pubkey": verify_key.version - } + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + self_signing_key_by_user[user] = { + "key_info": cross_signing_key, + "pubkey": verify_key.version + } # if we have exceeded the limit, we need to exclude any results with the # same stream_id as the last row. @@ -178,8 +184,11 @@ class DeviceWorkerStore(SQLBaseStore): # Stop processing updates break - if update[1] == master_key_by_user[update[0]]["pubkey"] or \ - update[1] == self_signing_key_by_user[update[0]]["pubkey"]: + # skip over cross-signing keys + if (update[0] in master_key_by_user + and update[1] == master_key_by_user[update[0]]["pubkey"]) \ + or (update[0] in master_key_by_user + and update[1] == self_signing_key_by_user[update[0]]["pubkey"]): continue key = (update[0], update[1]) @@ -200,11 +209,13 @@ class DeviceWorkerStore(SQLBaseStore): # update list with the master/self-signing key by user maps cross_signing_keys_by_user = {} for user_id, device_id, stream in updates: - if device_id == master_key_by_user[user_id]["pubkey"]: + if device_id == master_key_by_user.get(user_id, {}) \ + .get("pubkey", None): result = cross_signing_keys_by_user.setdefault(user_id, {}) result["master_key"] = \ master_key_by_user[user_id]["key_info"] - elif device_id == self_signing_key_by_user[user_id]["pubkey"]: + elif device_id == self_signing_key_by_user.get(user_id, {}) \ + .get("pubkey", None): result = cross_signing_keys_by_user.setdefault(user_id, {}) result["self_signing_key"] = \ self_signing_key_by_user[user_id]["key_info"] -- cgit 1.4.1 From cfdb84422dba2ca28eacb65aca960aecc5598658 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 22 Jul 2019 13:04:55 -0400 Subject: make black happy --- synapse/handlers/e2e_keys.py | 12 ++++++---- synapse/storage/devices.py | 54 ++++++++++++++++++++++++++------------------ 2 files changed, 39 insertions(+), 27 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 849ee04f93..f3cfba0bd3 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -56,7 +56,7 @@ class E2eKeysHandler(object): federation_registry = hs.get_federation_registry() federation_registry.register_edu_handler( - "m.signing_key_update", self._edu_updater.incoming_signing_key_update, + "m.signing_key_update", self._edu_updater.incoming_signing_key_update ) # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the @@ -1147,15 +1147,17 @@ class SigningKeyEduUpdater(object): yield self.store.set_e2e_cross_signing_key( user_id, "master", master_key ) - device_id = \ - get_verify_key_from_cross_signing_key(master_key)[1].version + device_id = get_verify_key_from_cross_signing_key(master_key)[ + 1 + ].version device_ids.append(device_id) if self_signing_key: yield self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) - device_id = \ - get_verify_key_from_cross_signing_key(self_signing_key)[1].version + device_id = get_verify_key_from_cross_signing_key(self_signing_key)[ + 1 + ].version device_ids.append(device_id) yield device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index f46978c9c3..60bf6d68ec 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -37,6 +37,7 @@ from synapse.storage._base import ( make_in_list_sql_clause, ) from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList @@ -142,17 +143,19 @@ class DeviceWorkerStore(SQLBaseStore): ) master_key_by_user[user] = { "key_info": cross_signing_key, - "pubkey": verify_key.version + "pubkey": verify_key.version, } - cross_signing_key = yield self.get_e2e_cross_signing_key(user, "self_signing") + cross_signing_key = yield self.get_e2e_cross_signing_key( + user, "self_signing" + ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key ) self_signing_key_by_user[user] = { "key_info": cross_signing_key, - "pubkey": verify_key.version + "pubkey": verify_key.version, } # if we have exceeded the limit, we need to exclude any results with the @@ -185,10 +188,13 @@ class DeviceWorkerStore(SQLBaseStore): break # skip over cross-signing keys - if (update[0] in master_key_by_user - and update[1] == master_key_by_user[update[0]]["pubkey"]) \ - or (update[0] in master_key_by_user - and update[1] == self_signing_key_by_user[update[0]]["pubkey"]): + if ( + update[0] in master_key_by_user + and update[1] == master_key_by_user[update[0]]["pubkey"] + ) or ( + update[0] in master_key_by_user + and update[1] == self_signing_key_by_user[update[0]]["pubkey"] + ): continue key = (update[0], update[1]) @@ -209,16 +215,16 @@ class DeviceWorkerStore(SQLBaseStore): # update list with the master/self-signing key by user maps cross_signing_keys_by_user = {} for user_id, device_id, stream in updates: - if device_id == master_key_by_user.get(user_id, {}) \ - .get("pubkey", None): + if device_id == master_key_by_user.get(user_id, {}).get("pubkey", None): result = cross_signing_keys_by_user.setdefault(user_id, {}) - result["master_key"] = \ - master_key_by_user[user_id]["key_info"] - elif device_id == self_signing_key_by_user.get(user_id, {}) \ - .get("pubkey", None): + result["master_key"] = master_key_by_user[user_id]["key_info"] + elif device_id == self_signing_key_by_user.get(user_id, {}).get( + "pubkey", None + ): result = cross_signing_keys_by_user.setdefault(user_id, {}) - result["self_signing_key"] = \ - self_signing_key_by_user[user_id]["key_info"] + result["self_signing_key"] = self_signing_key_by_user[user_id][ + "key_info" + ] cross_signing_results = [] @@ -282,13 +288,17 @@ class DeviceWorkerStore(SQLBaseStore): List[Dict]: List of objects representing an device update EDU """ - devices = yield self.runInteraction( - "_get_e2e_device_keys_txn", - self._get_e2e_device_keys_txn, - query_map.keys(), - include_all_devices=True, - include_deleted_devices=True, - ) if query_map else {} + devices = ( + yield self.runInteraction( + "_get_e2e_device_keys_txn", + self._get_e2e_device_keys_txn, + query_map.keys(), + include_all_devices=True, + include_deleted_devices=True, + ) + if query_map + else {} + ) results = [] for user_id, user_devices in iteritems(devices): -- cgit 1.4.1 From 41ad35b5235ad9ed8a1b8889287ae840ee3373bd Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 2 Aug 2019 18:03:23 -0400 Subject: add missing param --- synapse/handlers/e2e_keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index f3cfba0bd3..6b65f47d18 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -355,7 +355,7 @@ class E2eKeysHandler(object): ret = {"device_keys": res} # add in the cross-signing keys - cross_signing_keys = yield self.query_cross_signing_keys(device_keys_query) + cross_signing_keys = yield self.query_cross_signing_keys(device_keys_query, None) for key, value in iteritems(cross_signing_keys): ret[key + "_keys"] = value -- cgit 1.4.1 From 1fabf82d50f3db25ce0e4a93f349d90eb2d30a16 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 22 Oct 2019 21:44:58 -0400 Subject: update to work with newer code, and fix formatting --- synapse/handlers/device.py | 2 +- synapse/handlers/e2e_keys.py | 9 +++++---- synapse/storage/devices.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index cd6eb52316..fd8d14b680 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -468,7 +468,7 @@ class DeviceHandler(DeviceWorkerHandler): "stream_id": stream_id, "devices": devices, "master_key": master_key, - "self_signing_key": self_signing_key + "self_signing_key": self_signing_key, } @defer.inlineCallbacks diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 6b65f47d18..73572f4614 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -260,7 +260,7 @@ class E2eKeysHandler(object): Returns: defer.Deferred[dict[str, dict[str, dict]]]: map from - (master|self_signing|user_signing) -> user_id -> key + (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key """ master_keys = {} self_signing_keys = {} @@ -355,10 +355,11 @@ class E2eKeysHandler(object): ret = {"device_keys": res} # add in the cross-signing keys - cross_signing_keys = yield self.query_cross_signing_keys(device_keys_query, None) + cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + device_keys_query, None + ) - for key, value in iteritems(cross_signing_keys): - ret[key + "_keys"] = value + ret.update(cross_signing_keys) return ret diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 60bf6d68ec..1aaef1fbb0 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -214,7 +214,7 @@ class DeviceWorkerStore(SQLBaseStore): # figure out which cross-signing keys were changed by intersecting the # update list with the master/self-signing key by user maps cross_signing_keys_by_user = {} - for user_id, device_id, stream in updates: + for user_id, device_id, stream, _opentracing_context in updates: if device_id == master_key_by_user.get(user_id, {}).get("pubkey", None): result = cross_signing_keys_by_user.setdefault(user_id, {}) result["master_key"] = master_key_by_user[user_id]["key_info"] -- cgit 1.4.1 From 404e8c85321b4fe9b9b74e07841367c4cf201551 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 22 Oct 2019 22:33:23 -0400 Subject: vendor-prefix the EDU name until MSC1756 is merged into the spec --- synapse/handlers/e2e_keys.py | 3 ++- synapse/storage/devices.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 73572f4614..d25af42f5f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -55,8 +55,9 @@ class E2eKeysHandler(object): federation_registry = hs.get_federation_registry() + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec federation_registry.register_edu_handler( - "m.signing_key_update", self._edu_updater.incoming_signing_key_update + "org.matrix.signing_key_update", self._edu_updater.incoming_signing_key_update ) # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 1aaef1fbb0..6ac165068e 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -231,7 +231,8 @@ class DeviceWorkerStore(SQLBaseStore): # add the updated cross-signing keys to the results list for user_id, result in iteritems(cross_signing_keys_by_user): result["user_id"] = user_id - cross_signing_results.append(("m.signing_key_update", result)) + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + cross_signing_results.append(("org.matrix.signing_key_update", result)) # That should only happen if a client is spamming the server with new # devices, in which case E2E isn't going to work well anyway. We'll just -- cgit 1.4.1 From 480eac30eb543ce6947009fa90b8409f153eb3a4 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 22 Oct 2019 22:37:16 -0400 Subject: black --- synapse/handlers/e2e_keys.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d25af42f5f..f6aa9a940b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -57,7 +57,8 @@ class E2eKeysHandler(object): # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec federation_registry.register_edu_handler( - "org.matrix.signing_key_update", self._edu_updater.incoming_signing_key_update + "org.matrix.signing_key_update", + self._edu_updater.incoming_signing_key_update, ) # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the -- cgit 1.4.1 From 22a984767018eb511b4cc5c7875ff8bea46db19e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Oct 2019 12:00:21 +0100 Subject: Move persist_events out from main data store. This is in preparation for splitting out of state_groups_state from the main store into it own one, as persisting events depends on calculating state. --- synapse/storage/data_stores/main/events.py | 701 +++-------------------------- synapse/storage/persist_events.py | 644 ++++++++++++++++++++++++++ 2 files changed, 711 insertions(+), 634 deletions(-) create mode 100644 synapse/storage/persist_events.py (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 03b5111c5d..6304531cd5 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -17,14 +17,14 @@ import itertools import logging -from collections import Counter as c_counter, OrderedDict, deque, namedtuple +from collections import Counter as c_counter, OrderedDict, namedtuple from functools import wraps from six import iteritems, text_type from six.moves import range from canonicaljson import json -from prometheus_client import Counter, Histogram +from prometheus_client import Counter from twisted.internet import defer @@ -34,11 +34,9 @@ from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event_dict -from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore from synapse.storage._base import make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.event_federation import EventFederationStore @@ -46,10 +44,8 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import batch_iter -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder -from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -60,37 +56,6 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) -# The number of times we are recalculating the current state -state_delta_counter = Counter("synapse_storage_events_state_delta", "") - -# The number of times we are recalculating state when there is only a -# single forward extremity -state_delta_single_event_counter = Counter( - "synapse_storage_events_state_delta_single_event", "" -) - -# The number of times we are reculating state when we could have resonably -# calculated the delta when we calculated the state for an event we were -# persisting. -state_delta_reuse_delta_counter = Counter( - "synapse_storage_events_state_delta_reuse_delta", "" -) - -# The number of forward extremities for each new event. -forward_extremities_counter = Histogram( - "synapse_storage_events_forward_extremities_persisted", - "Number of forward extremities for each new event", - buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - -# The number of stale forward extremities for each new event. Stale extremities -# are those that were in the previous set of extremities as well as the new. -stale_forward_extremities_counter = Histogram( - "synapse_storage_events_stale_forward_extremities_persisted", - "Number of unchanged forward extremities for each new event", - buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - def encode_json(json_object): """ @@ -102,110 +67,6 @@ def encode_json(json_object): return out -class _EventPeristenceQueue(object): - """Queues up events so that they can be persisted in bulk with only one - concurrent transaction per room. - """ - - _EventPersistQueueItem = namedtuple( - "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") - ) - - def __init__(self): - self._event_persist_queues = {} - self._currently_persisting_rooms = set() - - def add_to_queue(self, room_id, events_and_contexts, backfilled): - """Add events to the queue, with the given persist_event options. - - NB: due to the normal usage pattern of this method, it does *not* - follow the synapse logcontext rules, and leaves the logcontext in - place whether or not the returned deferred is ready. - - Args: - room_id (str): - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): - - Returns: - defer.Deferred: a deferred which will resolve once the events are - persisted. Runs its callbacks *without* a logcontext. - """ - queue = self._event_persist_queues.setdefault(room_id, deque()) - if queue: - # if the last item in the queue has the same `backfilled` setting, - # we can just add these new events to that item. - end_item = queue[-1] - if end_item.backfilled == backfilled: - end_item.events_and_contexts.extend(events_and_contexts) - return end_item.deferred.observe() - - deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - - queue.append( - self._EventPersistQueueItem( - events_and_contexts=events_and_contexts, - backfilled=backfilled, - deferred=deferred, - ) - ) - - return deferred.observe() - - def handle_queue(self, room_id, per_item_callback): - """Attempts to handle the queue for a room if not already being handled. - - The given callback will be invoked with for each item in the queue, - of type _EventPersistQueueItem. The per_item_callback will continuously - be called with new items, unless the queue becomnes empty. The return - value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferreds as well. - - This function should therefore be called whenever anything is added - to the queue. - - If another callback is currently handling the queue then it will not be - invoked. - """ - - if room_id in self._currently_persisting_rooms: - return - - self._currently_persisting_rooms.add(room_id) - - @defer.inlineCallbacks - def handle_queue_loop(): - try: - queue = self._get_drainining_queue(room_id) - for item in queue: - try: - ret = yield per_item_callback(item) - except Exception: - with PreserveLoggingContext(): - item.deferred.errback() - else: - with PreserveLoggingContext(): - item.deferred.callback(ret) - finally: - queue = self._event_persist_queues.pop(room_id, None) - if queue: - self._event_persist_queues[room_id] = queue - self._currently_persisting_rooms.discard(room_id) - - # set handle_queue_loop off in the background - run_as_background_process("persist_events", handle_queue_loop) - - def _get_drainining_queue(self, room_id): - queue = self._event_persist_queues.setdefault(room_id, deque()) - - try: - while True: - yield queue.popleft() - except IndexError: - # Queue has been drained. - pass - - _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @@ -241,9 +102,6 @@ class EventsStore( def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) - self._event_persist_queue = _EventPeristenceQueue() - self._state_resolution_handler = hs.get_state_resolution_handler() - # Collect metrics on the number of forward extremities that exist. # Counter of number of extremities to count self._current_forward_extremities_amount = c_counter() @@ -286,80 +144,16 @@ class EventsStore( res = yield self.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) - @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False): - """ - Write events to the database - Args: - events_and_contexts: list of tuples of (event, context) - backfilled (bool): Whether the results are retrieved from federation - via backfill or not. Used to determine if they're "new" events - which might update the current state etc. - - Returns: - Deferred[int]: the stream ordering of the latest persisted event - """ - partitioned = {} - for event, ctx in events_and_contexts: - partitioned.setdefault(event.room_id, []).append((event, ctx)) - - deferreds = [] - for room_id, evs_ctxs in iteritems(partitioned): - d = self._event_persist_queue.add_to_queue( - room_id, evs_ctxs, backfilled=backfilled - ) - deferreds.append(d) - - for room_id in partitioned: - self._maybe_start_persisting(room_id) - - yield make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - - return max_persisted_id - - @defer.inlineCallbacks - @log_function - def persist_event(self, event, context, backfilled=False): - """ - - Args: - event (EventBase): - context (EventContext): - backfilled (bool): - - Returns: - Deferred: resolves to (int, int): the stream ordering of ``event``, - and the stream ordering of the latest persisted event - """ - deferred = self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], backfilled=backfilled - ) - - self._maybe_start_persisting(event.room_id) - - yield make_deferred_yieldable(deferred) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) - - def _maybe_start_persisting(self, room_id): - @defer.inlineCallbacks - def persisting_queue(item): - with Measure(self._clock, "persist_events"): - yield self._persist_events( - item.events_and_contexts, backfilled=item.backfilled - ) - - self._event_persist_queue.handle_queue(room_id, persisting_queue) - @_retry_on_integrity_error @defer.inlineCallbacks def _persist_events( - self, events_and_contexts, backfilled=False, delete_existing=False + self, + events_and_contexts, + current_state_for_room, + state_delta_for_room, + new_forward_extremeties, + backfilled=False, + delete_existing=False, ): """Persist events to db @@ -374,252 +168,73 @@ class EventsStore( if not events_and_contexts: return - chunks = [ - events_and_contexts[x : x + 100] - for x in range(0, len(events_and_contexts), 100) - ] - - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - - # NB: Assumes that we are only persisting events for one room - # at a time. - - # map room_id->list[event_ids] giving the new forward - # extremities in each room - new_forward_extremeties = {} + # We want to calculate the stream orderings as late as possible, as + # we only notify after all events with a lesser stream ordering have + # been persisted. I.e. if we spend 10s inside the with block then + # that will delay all subsequent events from being notified about. + # Hence why we do it down here rather than wrapping the entire + # function. + # + # Its safe to do this after calculating the state deltas etc as we + # only need to protect the *persistence* of the events. This is to + # ensure that queries of the form "fetch events since X" don't + # return events and stream positions after events that are still in + # flight, as otherwise subsequent requests "fetch event since Y" + # will not return those events. + # + # Note: Multiple instances of this function cannot be in flight at + # the same time for the same room. + if backfilled: + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) + else: + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) + ) - # map room_id->(type,state_key)->event_id tracking the full - # state in each room after adding these events. - # This is simply used to prefill the get_current_state_ids - # cache - current_state_for_room = {} + with stream_ordering_manager as stream_orderings: + for (event, context), stream in zip(events_and_contexts, stream_orderings): + event.internal_metadata.stream_ordering = stream - # map room_id->(to_delete, to_insert) where to_delete is a list - # of type/state keys to remove from current state, and to_insert - # is a map (type,key)->event_id giving the state delta in each - # room - state_delta_for_room = {} + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=events_and_contexts, + backfilled=backfilled, + delete_existing=delete_existing, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + ) + persist_event_counter.inc(len(events_and_contexts)) if not backfilled: - with Measure(self._clock, "_calculate_state_and_extrem"): - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in iteritems(events_by_room): - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremities( - room_id, ev_ctx_rm, latest_event_ids - ) - - latest_event_ids = set(latest_event_ids) - if new_latest_event_ids == latest_event_ids: - # No change in extremities, so no change in state - continue - - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremeties[room_id] = new_latest_event_ids - - len_1 = ( - len(latest_event_ids) == 1 - and len(new_latest_event_ids) == 1 - ) - if len_1: - all_single_prev_not_state = all( - len(event.prev_event_ids()) == 1 - and not event.is_state() - for event, ctx in ev_ctx_rm - ) - # Don't bother calculating state if they're just - # a long chain of single ancestor non-state events. - if all_single_prev_not_state: - continue - - state_delta_counter.inc() - if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() - - # This is a fairly handwavey check to see if we could - # have guessed what the delta would have been when - # processing one of these events. - # What we're interested in is if the latest extremities - # were the same when we created the event as they are - # now. When this server creates a new event (as opposed - # to receiving it over federation) it will use the - # forward extremities as the prev_events, so we can - # guess this by looking at the prev_events and checking - # if they match the current forward extremities. - for ev, _ in ev_ctx_rm: - prev_event_ids = set(ev.prev_event_ids()) - if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() - break - - logger.info("Calculating state delta for room %s", room_id) - with Measure( - self._clock, "persist_events.get_new_state_after_events" - ): - res = yield self._get_new_state_after_events( - room_id, - ev_ctx_rm, - latest_event_ids, - new_latest_event_ids, - ) - current_state, delta_ids = res - - # If either are not None then there has been a change, - # and we need to work out the delta (or use that - # given) - if delta_ids is not None: - # If there is a delta we know that we've - # only added or replaced state, never - # removed keys entirely. - state_delta_for_room[room_id] = ([], delta_ids) - elif current_state is not None: - with Measure( - self._clock, "persist_events.calculate_state_delta" - ): - delta = yield self._calculate_state_delta( - room_id, current_state - ) - state_delta_for_room[room_id] = delta - - # If we have the current_state then lets prefill - # the cache with it. - if current_state is not None: - current_state_for_room[room_id] = current_state - - # We want to calculate the stream orderings as late as possible, as - # we only notify after all events with a lesser stream ordering have - # been persisted. I.e. if we spend 10s inside the with block then - # that will delay all subsequent events from being notified about. - # Hence why we do it down here rather than wrapping the entire - # function. - # - # Its safe to do this after calculating the state deltas etc as we - # only need to protect the *persistence* of the events. This is to - # ensure that queries of the form "fetch events since X" don't - # return events and stream positions after events that are still in - # flight, as otherwise subsequent requests "fetch event since Y" - # will not return those events. - # - # Note: Multiple instances of this function cannot be in flight at - # the same time for the same room. - if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( - len(chunk) - ) - else: - stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk)) - - with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(chunk, stream_orderings): - event.internal_metadata.stream_ordering = stream - - yield self.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=chunk, - backfilled=backfilled, - delete_existing=delete_existing, - state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, + # backfilled events have negative stream orderings, so we don't + # want to set the event_persisted_position to that. + synapse.metrics.event_persisted_position.set( + events_and_contexts[-1][0].internal_metadata.stream_ordering ) - persist_event_counter.inc(len(chunk)) - - if not backfilled: - # backfilled events have negative stream orderings, so we don't - # want to set the event_persisted_position to that. - synapse.metrics.event_persisted_position.set( - chunk[-1][0].internal_metadata.stream_ordering - ) - - for event, context in chunk: - if context.app_service: - origin_type = "local" - origin_entity = context.app_service.id - elif self.hs.is_mine_id(event.sender): - origin_type = "local" - origin_entity = "*client*" - else: - origin_type = "remote" - origin_entity = get_domain_from_id(event.sender) - - event_counter.labels(event.type, origin_type, origin_entity).inc() - - for room_id, new_state in iteritems(current_state_for_room): - self.get_current_state_ids.prefill((room_id,), new_state) - - for room_id, latest_event_ids in iteritems(new_forward_extremeties): - self.get_latest_event_ids_in_room.prefill( - (room_id,), list(latest_event_ids) - ) - - @defer.inlineCallbacks - def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremities for a room given events to - persist. - - Assumes that we are only persisting events for one room at a time. - """ - - # we're only interested in new events which aren't outliers and which aren't - # being rejected. - new_events = [ - event - for event, ctx in event_contexts - if not event.internal_metadata.is_outlier() - and not ctx.rejected - and not event.internal_metadata.is_soft_failed() - ] - latest_event_ids = set(latest_event_ids) - - # start with the existing forward extremities - result = set(latest_event_ids) - - # add all the new events to the list - result.update(event.event_id for event in new_events) - - # Now remove all events which are prev_events of any of the new events - result.difference_update( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - - # Remove any events which are prev_events of any existing events. - existing_prevs = yield self._get_events_which_are_prevs(result) - result.difference_update(existing_prevs) + for event, context in events_and_contexts: + if context.app_service: + origin_type = "local" + origin_entity = context.app_service.id + elif self.hs.is_mine_id(event.sender): + origin_type = "local" + origin_entity = "*client*" + else: + origin_type = "remote" + origin_entity = get_domain_from_id(event.sender) - # Finally handle the case where the new events have soft-failed prev - # events. If they do we need to remove them and their prev events, - # otherwise we end up with dangling extremities. - existing_prevs = yield self._get_prevs_before_rejected( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - result.difference_update(existing_prevs) + event_counter.labels(event.type, origin_type, origin_entity).inc() - # We only update metrics for events that change forward extremities - # (e.g. we ignore backfill/outliers/etc) - if result != latest_event_ids: - forward_extremities_counter.observe(len(result)) - stale = latest_event_ids & result - stale_forward_extremities_counter.observe(len(stale)) + for room_id, new_state in iteritems(current_state_for_room): + self.get_current_state_ids.prefill((room_id,), new_state) - return result + for room_id, latest_event_ids in iteritems(new_forward_extremeties): + self.get_latest_event_ids_in_room.prefill( + (room_id,), list(latest_event_ids) + ) @defer.inlineCallbacks def _get_events_which_are_prevs(self, event_ids): @@ -725,188 +340,6 @@ class EventsStore( return existing_prevs - @defer.inlineCallbacks - def _get_new_state_after_events( - self, room_id, events_context, old_latest_event_ids, new_latest_event_ids - ): - """Calculate the current state dict after adding some new events to - a room - - Args: - room_id (str): - room to which the events are being added. Used for logging etc - - events_context (list[(EventBase, EventContext)]): - events and contexts which are being added to the room - - old_latest_event_ids (iterable[str]): - the old forward extremities for the room. - - new_latest_event_ids (iterable[str]): - the new forward extremities for the room. - - Returns: - Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: - Returns a tuple of two state maps, the first being the full new current - state and the second being the delta to the existing current state. - If both are None then there has been no change. - - If there has been a change then we only return the delta if its - already been calculated. Conversely if we do know the delta then - the new current state is only returned if we've already calculated - it. - """ - # map from state_group to ((type, key) -> event_id) state map - state_groups_map = {} - - # Map from (prev state group, new state group) -> delta state dict - state_group_deltas = {} - - for ev, ctx in events_context: - if ctx.state_group is None: - # This should only happen for outlier events. - if not ev.internal_metadata.is_outlier(): - raise Exception( - "Context for new event %s has no state " - "group" % (ev.event_id,) - ) - continue - - if ctx.state_group in state_groups_map: - continue - - # We're only interested in pulling out state that has already - # been cached in the context. We'll pull stuff out of the DB later - # if necessary. - current_state_ids = ctx.get_cached_current_state_ids() - if current_state_ids is not None: - state_groups_map[ctx.state_group] = current_state_ids - - if ctx.prev_group: - state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids - - # We need to map the event_ids to their state groups. First, let's - # check if the event is one we're persisting, in which case we can - # pull the state group from its context. - # Otherwise we need to pull the state group from the database. - - # Set of events we need to fetch groups for. (We know none of the old - # extremities are going to be in events_context). - missing_event_ids = set(old_latest_event_ids) - - event_id_to_state_group = {} - for event_id in new_latest_event_ids: - # First search in the list of new events we're adding. - for ev, ctx in events_context: - if event_id == ev.event_id and ctx.state_group is not None: - event_id_to_state_group[event_id] = ctx.state_group - break - else: - # If we couldn't find it, then we'll need to pull - # the state from the database - missing_event_ids.add(event_id) - - if missing_event_ids: - # Now pull out the state groups for any missing events from DB - event_to_groups = yield self._get_state_group_for_events(missing_event_ids) - event_id_to_state_group.update(event_to_groups) - - # State groups of old_latest_event_ids - old_state_groups = set( - event_id_to_state_group[evid] for evid in old_latest_event_ids - ) - - # State groups of new_latest_event_ids - new_state_groups = set( - event_id_to_state_group[evid] for evid in new_latest_event_ids - ) - - # If they old and new groups are the same then we don't need to do - # anything. - if old_state_groups == new_state_groups: - return None, None - - if len(new_state_groups) == 1 and len(old_state_groups) == 1: - # If we're going from one state group to another, lets check if - # we have a delta for that transition. If we do then we can just - # return that. - - new_state_group = next(iter(new_state_groups)) - old_state_group = next(iter(old_state_groups)) - - delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) - if delta_ids is not None: - # We have a delta from the existing to new current state, - # so lets just return that. If we happen to already have - # the current state in memory then lets also return that, - # but it doesn't matter if we don't. - new_state = state_groups_map.get(new_state_group) - return new_state, delta_ids - - # Now that we have calculated new_state_groups we need to get - # their state IDs so we can resolve to a single state set. - missing_state = new_state_groups - set(state_groups_map) - if missing_state: - group_to_state = yield self._get_state_for_groups(missing_state) - state_groups_map.update(group_to_state) - - if len(new_state_groups) == 1: - # If there is only one state group, then we know what the current - # state is. - return state_groups_map[new_state_groups.pop()], None - - # Ok, we need to defer to the state handler to resolve our state sets. - - state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} - - events_map = {ev.event_id: ev for ev, _ in events_context} - - # We need to get the room version, which is in the create event. - # Normally that'd be in the database, but its also possible that we're - # currently trying to persist it. - room_version = None - for ev, _ in events_context: - if ev.type == EventTypes.Create and ev.state_key == "": - room_version = ev.content.get("room_version", "1") - break - - if not room_version: - room_version = yield self.get_room_version(room_id) - - logger.debug("calling resolve_state_groups from preserve_events") - res = yield self._state_resolution_handler.resolve_state_groups( - room_id, - room_version, - state_groups, - events_map, - state_res_store=StateResolutionStore(self), - ) - - return res.state, None - - @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, current_state): - """Calculate the new state deltas for a room. - - Assumes that we are only persisting events for one room at a time. - - Returns: - tuple[list, dict] (to_delete, to_insert): where to_delete are the - type/state_keys to remove from current_state_events and `to_insert` - are the updates to current_state_events. - """ - existing_state = yield self.get_current_state_ids(room_id) - - to_delete = [key for key in existing_state if key not in current_state] - - to_insert = { - key: ev_id - for key, ev_id in iteritems(current_state) - if ev_id != existing_state.get(key) - } - - return to_delete, to_insert - @log_function def _persist_events_txn( self, diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py new file mode 100644 index 0000000000..cd445be670 --- /dev/null +++ b/synapse/storage/persist_events.py @@ -0,0 +1,644 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import deque, namedtuple + +from six import iteritems +from six.moves import range + +from prometheus_client import Counter, Histogram + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateResolutionStore +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + +# The number of times we are recalculating the current state +state_delta_counter = Counter("synapse_storage_events_state_delta", "") + +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = Counter( + "synapse_storage_events_state_delta_single_event", "" +) + +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = Counter( + "synapse_storage_events_state_delta_reuse_delta", "" +) + +# The number of forward extremities for each new event. +forward_extremities_counter = Histogram( + "synapse_storage_events_forward_extremities_persisted", + "Number of forward extremities for each new event", + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + +# The number of stale forward extremities for each new event. Stale extremities +# are those that were in the previous set of extremities as well as the new. +stale_forward_extremities_counter = Histogram( + "synapse_storage_events_stale_forward_extremities_persisted", + "Number of unchanged forward extremities for each new event", + buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + + +class _EventPeristenceQueue(object): + """Queues up events so that they can be persisted in bulk with only one + concurrent transaction per room. + """ + + _EventPersistQueueItem = namedtuple( + "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") + ) + + def __init__(self): + self._event_persist_queues = {} + self._currently_persisting_rooms = set() + + def add_to_queue(self, room_id, events_and_contexts, backfilled): + """Add events to the queue, with the given persist_event options. + + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + + Args: + room_id (str): + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. + """ + queue = self._event_persist_queues.setdefault(room_id, deque()) + if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. + end_item = queue[-1] + if end_item.backfilled == backfilled: + end_item.events_and_contexts.extend(events_and_contexts) + return end_item.deferred.observe() + + deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + + queue.append( + self._EventPersistQueueItem( + events_and_contexts=events_and_contexts, + backfilled=backfilled, + deferred=deferred, + ) + ) + + return deferred.observe() + + def handle_queue(self, room_id, per_item_callback): + """Attempts to handle the queue for a room if not already being handled. + + The given callback will be invoked with for each item in the queue, + of type _EventPersistQueueItem. The per_item_callback will continuously + be called with new items, unless the queue becomnes empty. The return + value of the function will be given to the deferreds waiting on the item, + exceptions will be passed to the deferreds as well. + + This function should therefore be called whenever anything is added + to the queue. + + If another callback is currently handling the queue then it will not be + invoked. + """ + + if room_id in self._currently_persisting_rooms: + return + + self._currently_persisting_rooms.add(room_id) + + @defer.inlineCallbacks + def handle_queue_loop(): + try: + queue = self._get_drainining_queue(room_id) + for item in queue: + try: + ret = yield per_item_callback(item) + except Exception: + with PreserveLoggingContext(): + item.deferred.errback() + else: + with PreserveLoggingContext(): + item.deferred.callback(ret) + finally: + queue = self._event_persist_queues.pop(room_id, None) + if queue: + self._event_persist_queues[room_id] = queue + self._currently_persisting_rooms.discard(room_id) + + # set handle_queue_loop off in the background + run_as_background_process("persist_events", handle_queue_loop) + + def _get_drainining_queue(self, room_id): + queue = self._event_persist_queues.setdefault(room_id, deque()) + + try: + while True: + yield queue.popleft() + except IndexError: + # Queue has been drained. + pass + + +class EventsPersistenceStore(object): + def __init__(self, hs): + # We ultimately want to split out the state store from the main store, + # so we use separate variables here even though they point to the same + # store for now. + self.main_store = hs.get_datastore() + self.state_store = hs.get_datastore() + + self._clock = hs.get_clock() + self.is_mine_id = hs.is_mine_id + self._event_persist_queue = _EventPeristenceQueue() + self._state_resolution_handler = hs.get_state_resolution_handler() + + @defer.inlineCallbacks + def persist_events(self, events_and_contexts, backfilled=False): + """ + Write events to the database + Args: + events_and_contexts: list of tuples of (event, context) + backfilled (bool): Whether the results are retrieved from federation + via backfill or not. Used to determine if they're "new" events + which might update the current state etc. + + Returns: + Deferred[int]: the stream ordering of the latest persisted event + """ + partitioned = {} + for event, ctx in events_and_contexts: + partitioned.setdefault(event.room_id, []).append((event, ctx)) + + deferreds = [] + for room_id, evs_ctxs in iteritems(partitioned): + d = self._event_persist_queue.add_to_queue( + room_id, evs_ctxs, backfilled=backfilled + ) + deferreds.append(d) + + for room_id in partitioned: + self._maybe_start_persisting(room_id) + + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) + + max_persisted_id = yield self.main_store.get_current_events_token() + + return max_persisted_id + + @defer.inlineCallbacks + def persist_event(self, event, context, backfilled=False): + """ + + Args: + event (EventBase): + context (EventContext): + backfilled (bool): + + Returns: + Deferred: resolves to (int, int): the stream ordering of ``event``, + and the stream ordering of the latest persisted event + """ + deferred = self._event_persist_queue.add_to_queue( + event.room_id, [(event, context)], backfilled=backfilled + ) + + self._maybe_start_persisting(event.room_id) + + yield make_deferred_yieldable(deferred) + + max_persisted_id = yield self.main_store.get_current_events_token() + return (event.internal_metadata.stream_ordering, max_persisted_id) + + def _maybe_start_persisting(self, room_id): + @defer.inlineCallbacks + def persisting_queue(item): + with Measure(self._clock, "persist_events"): + yield self._persist_events( + item.events_and_contexts, backfilled=item.backfilled + ) + + self._event_persist_queue.handle_queue(room_id, persisting_queue) + + @defer.inlineCallbacks + def _persist_events( + self, events_and_contexts, backfilled=False, delete_existing=False + ): + """Persist events to db + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + delete_existing (bool): + + Returns: + Deferred: resolves when the events have been persisted + """ + if not events_and_contexts: + return + + chunks = [ + events_and_contexts[x : x + 100] + for x in range(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + + # NB: Assumes that we are only persisting events for one room + # at a time. + + # map room_id->list[event_ids] giving the new forward + # extremities in each room + new_forward_extremeties = {} + + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events. + # This is simply used to prefill the get_current_state_ids + # cache + current_state_for_room = {} + + # map room_id->(to_delete, to_insert) where to_delete is a list + # of type/state keys to remove from current state, and to_insert + # is a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} + + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in iteritems(events_by_room): + latest_event_ids = yield self.main_store.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremities( + room_id, ev_ctx_rm, latest_event_ids + ) + + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: + # No change in extremities, so no change in state + continue + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremeties[room_id] = new_latest_event_ids + + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_event_ids()) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(ev.prev_event_ids()) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.info("Calculating state delta for room %s", room_id) + with Measure( + self._clock, "persist_events.get_new_state_after_events" + ): + res = yield self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + current_state, delta_ids = res + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + state_delta_for_room[room_id] = ([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, "persist_events.calculate_state_delta" + ): + delta = yield self._calculate_state_delta( + room_id, current_state + ) + state_delta_for_room[room_id] = delta + + # If we have the current_state then lets prefill + # the cache with it. + if current_state is not None: + current_state_for_room[room_id] = current_state + + yield self.main_store._persist_events( + chunk, + current_state_for_room=current_state_for_room, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + backfilled=backfilled, + delete_existing=delete_existing, + ) + + @defer.inlineCallbacks + def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + """Calculates the new forward extremities for a room given events to + persist. + + Assumes that we are only persisting events for one room at a time. + """ + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event + for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() + and not ctx.rejected + and not event.internal_metadata.is_soft_failed() + ] + + latest_event_ids = set(latest_event_ids) + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update(event.event_id for event in new_events) + + # Now remove all events which are prev_events of any of the new events + result.difference_update( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + + # Remove any events which are prev_events of any existing events. + existing_prevs = yield self.main_store._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) + + # Finally handle the case where the new events have soft-failed prev + # events. If they do we need to remove them and their prev events, + # otherwise we end up with dangling extremities. + existing_prevs = yield self.main_store._get_prevs_before_rejected( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + result.difference_update(existing_prevs) + + # We only update metrics for events that change forward extremities + # (e.g. we ignore backfill/outliers/etc) + if result != latest_event_ids: + forward_extremities_counter.observe(len(result)) + stale = latest_event_ids & result + stale_forward_extremities_counter.observe(len(stale)) + + return result + + @defer.inlineCallbacks + def _get_new_state_after_events( + self, room_id, events_context, old_latest_event_ids, new_latest_event_ids + ): + """Calculate the current state dict after adding some new events to + a room + + Args: + room_id (str): + room to which the events are being added. Used for logging etc + + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + old_latest_event_ids (iterable[str]): + the old forward extremities for the room. + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. + + Returns: + Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: + Returns a tuple of two state maps, the first being the full new current + state and the second being the delta to the existing current state. + If both are None then there has been no change. + + If there has been a change then we only return the delta if its + already been calculated. Conversely if we do know the delta then + the new current state is only returned if we've already calculated + it. + """ + # map from state_group to ((type, key) -> event_id) state map + state_groups_map = {} + + # Map from (prev state group, new state group) -> delta state dict + state_group_deltas = {} + + for ev, ctx in events_context: + if ctx.state_group is None: + # This should only happen for outlier events. + if not ev.internal_metadata.is_outlier(): + raise Exception( + "Context for new event %s has no state " + "group" % (ev.event_id,) + ) + continue + + if ctx.state_group in state_groups_map: + continue + + # We're only interested in pulling out state that has already + # been cached in the context. We'll pull stuff out of the DB later + # if necessary. + current_state_ids = ctx.get_cached_current_state_ids() + if current_state_ids is not None: + state_groups_map[ctx.state_group] = current_state_ids + + if ctx.prev_group: + state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding. + for ev, ctx in events_context: + if event_id == ev.event_id and ctx.state_group is not None: + event_id_to_state_group[event_id] = ctx.state_group + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + missing_event_ids.add(event_id) + + if missing_event_ids: + # Now pull out the state groups for any missing events from DB + event_to_groups = yield self.state_store._get_state_group_for_events( + missing_event_ids + ) + event_id_to_state_group.update(event_to_groups) + + # State groups of old_latest_event_ids + old_state_groups = set( + event_id_to_state_group[evid] for evid in old_latest_event_ids + ) + + # State groups of new_latest_event_ids + new_state_groups = set( + event_id_to_state_group[evid] for evid in new_latest_event_ids + ) + + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: + return None, None + + if len(new_state_groups) == 1 and len(old_state_groups) == 1: + # If we're going from one state group to another, lets check if + # we have a delta for that transition. If we do then we can just + # return that. + + new_state_group = next(iter(new_state_groups)) + old_state_group = next(iter(old_state_groups)) + + delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) + if delta_ids is not None: + # We have a delta from the existing to new current state, + # so lets just return that. If we happen to already have + # the current state in memory then lets also return that, + # but it doesn't matter if we don't. + new_state = state_groups_map.get(new_state_group) + return new_state, delta_ids + + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. + missing_state = new_state_groups - set(state_groups_map) + if missing_state: + group_to_state = yield self.state_store._get_state_for_groups(missing_state) + state_groups_map.update(group_to_state) + + if len(new_state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + return state_groups_map[new_state_groups.pop()], None + + # Ok, we need to defer to the state handler to resolve our state sets. + + state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} + + events_map = {ev.event_id: ev for ev, _ in events_context} + + # We need to get the room version, which is in the create event. + # Normally that'd be in the database, but its also possible that we're + # currently trying to persist it. + room_version = None + for ev, _ in events_context: + if ev.type == EventTypes.Create and ev.state_key == "": + room_version = ev.content.get("room_version", "1") + break + + if not room_version: + room_version = yield self.main_store.get_room_version(room_id) + + logger.debug("calling resolve_state_groups from preserve_events") + res = yield self._state_resolution_handler.resolve_state_groups( + room_id, + room_version, + state_groups, + events_map, + state_res_store=StateResolutionStore(self.main_store), + ) + + return res.state, None + + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, current_state): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + tuple[list, dict] (to_delete, to_insert): where to_delete are the + type/state_keys to remove from current_state_events and `to_insert` + are the updates to current_state_events. + """ + existing_state = yield self.main_store.get_current_state_ids(room_id) + + to_delete = [key for key in existing_state if key not in current_state] + + to_insert = { + key: ev_id + for key, ev_id in iteritems(current_state) + if ev_id != existing_state.get(key) + } + + return to_delete, to_insert -- cgit 1.4.1 From dc2cd6f79d413012688ef89e8999c0c2207e198e Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 23 Oct 2019 09:13:47 -0400 Subject: move get_e2e_cross_signing_key to EndToEndKeyWorkerStore so it works with workers --- synapse/storage/end_to_end_keys.py | 134 ++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 67 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2effe08592..f9bef14992 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -249,6 +249,73 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) + def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being set: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + from_user_id (str): if specified, signatures made by this user on + the key will be included in the result + + Returns: + dict of the key data or None if not found + """ + sql = ( + "SELECT keydata " + " FROM e2e_cross_signing_keys " + " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1" + ) + txn.execute(sql, (user_id, key_type)) + row = txn.fetchone() + if not row: + return None + key = json.loads(row[0]) + + device_id = None + for k in key["keys"].values(): + device_id = k + + if from_user_id is not None: + sql = ( + "SELECT key_id, signature " + " FROM e2e_cross_signing_signatures " + " WHERE user_id = ? " + " AND target_user_id = ? " + " AND target_device_id = ? " + ) + txn.execute(sql, (from_user_id, user_id, device_id)) + row = txn.fetchone() + if row: + key.setdefault("signatures", {}).setdefault(from_user_id, {})[ + row[0] + ] = row[1] + + return key + + def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + user_id (str): the user whose self-signing key is being requested + key_type (str): the type of cross-signing key to get + from_user_id (str): if specified, signatures made by this user on + the self-signing key will be included in the result + + Returns: + dict of the key data or None if not found + """ + return self.runInteraction( + "get_e2e_cross_signing_key", + self._get_e2e_cross_signing_key_txn, + user_id, + key_type, + from_user_id, + ) + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): @@ -427,73 +494,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key, ) - def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): - """Returns a user's cross-signing key. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' - for a master key, 'self_signing' for a self-signing key, or - 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on - the key will be included in the result - - Returns: - dict of the key data or None if not found - """ - sql = ( - "SELECT keydata " - " FROM e2e_cross_signing_keys " - " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1" - ) - txn.execute(sql, (user_id, key_type)) - row = txn.fetchone() - if not row: - return None - key = json.loads(row[0]) - - device_id = None - for k in key["keys"].values(): - device_id = k - - if from_user_id is not None: - sql = ( - "SELECT key_id, signature " - " FROM e2e_cross_signing_signatures " - " WHERE user_id = ? " - " AND target_user_id = ? " - " AND target_device_id = ? " - ) - txn.execute(sql, (from_user_id, user_id, device_id)) - row = txn.fetchone() - if row: - key.setdefault("signatures", {}).setdefault(from_user_id, {})[ - row[0] - ] = row[1] - - return key - - def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): - """Returns a user's cross-signing key. - - Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get - from_user_id (str): if specified, signatures made by this user on - the self-signing key will be included in the result - - Returns: - dict of the key data or None if not found - """ - return self.runInteraction( - "get_e2e_cross_signing_key", - self._get_e2e_cross_signing_key_txn, - user_id, - key_type, - from_user_id, - ) - def store_e2e_cross_signing_signatures(self, user_id, signatures): """Stores cross-signing signatures. -- cgit 1.4.1 From 73cf63784b90ea194eb867aafe3f39203b7ae029 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Oct 2019 16:14:16 +0100 Subject: Add DataStores and Storage classes. --- synapse/storage/__init__.py | 19 ++++++++++++++++++- synapse/storage/data_stores/__init__.py | 12 ++++++++++++ synapse/storage/persist_events.py | 7 ++++--- 3 files changed, 34 insertions(+), 4 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a249ecd219..a58187a76f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -27,7 +27,24 @@ data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from synapse.storage.data_stores.main import DataStore # noqa: F401 +from synapse.storage.data_stores import DataStores +from synapse.storage.data_stores.main import DataStore +from synapse.storage.persist_events import EventsPersistenceStore + +__all__ = ["DataStores", "DataStore"] + + +class Storage(object): + """The high level interfaces for talking to various storage layers. + """ + + def __init__(self, hs, stores: DataStores): + # We include the main data store here mainly so that we don't have to + # rewrite all the existing code to split it into high vs low level + # interfaces. + self.main = stores.main + + self.persistence = EventsPersistenceStore(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 56094078ed..cb184a98cc 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -12,3 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + +class DataStores(object): + """The various data stores. + + These are low level interfaces to physical databases. + """ + + def __init__(self, main_store, db_conn, hs): + # Note we pass in the main store here as workers use a different main + # store. + self.main = main_store diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index cd445be670..9a63953d4d 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateResolutionStore +from synapse.storage.data_stores import DataStores from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -171,12 +172,12 @@ class _EventPeristenceQueue(object): class EventsPersistenceStore(object): - def __init__(self, hs): + def __init__(self, hs, stores: DataStores): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same # store for now. - self.main_store = hs.get_datastore() - self.state_store = hs.get_datastore() + self.main_store = stores.main + self.state_store = stores.main self._clock = hs.get_clock() self.is_mine_id = hs.is_mine_id -- cgit 1.4.1 From 3ca4c7c516a349cbb9dace0ace33bb889955eda1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Oct 2019 12:02:36 +0100 Subject: Use new EventPersistenceStore --- synapse/handlers/federation.py | 3 ++- synapse/handlers/message.py | 3 ++- synapse/server.py | 9 ++++++++- tests/replication/slave/storage/_base.py | 1 + tests/replication/slave/storage/test_events.py | 10 +++++++--- tests/storage/test_redaction.py | 11 ++++++----- tests/storage/test_room.py | 3 ++- tests/storage/test_roommember.py | 3 ++- tests/storage/test_state.py | 3 ++- tests/test_visibility.py | 7 ++++--- tests/utils.py | 10 ++++++++-- 11 files changed, 44 insertions(+), 19 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4b4c6c15f9..7e9c9aee13 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -109,6 +109,7 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -2648,7 +2649,7 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) else: - max_stream_id = yield self.store.persist_events( + max_stream_id = yield self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0f8cce8ffe..7908a2d52c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -234,6 +234,7 @@ class EventCreationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -868,7 +869,7 @@ class EventCreationHandler(object): if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - (event_stream_id, max_stream_id) = yield self.store.persist_event( + event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( event, context=context ) diff --git a/synapse/server.py b/synapse/server.py index 1fcc7375d3..54a7f4aa5f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -95,6 +95,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler +from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -196,6 +197,7 @@ class HomeServer(object): "account_validity_handler", "saml_handler", "event_client_serializer", + "storage", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -225,6 +227,7 @@ class HomeServer(object): self.registration_ratelimiter = Ratelimiter() self.datastore = None + self.datastores = None # Other kwargs are explicit dependencies for depname in kwargs: @@ -234,6 +237,7 @@ class HomeServer(object): logger.info("Setting up.") with self.get_db_conn() as conn: self.datastore = self.DATASTORE_CLASS(conn, self) + self.datastores = DataStores(self.datastore, conn, self) conn.commit() logger.info("Finished setting up.") @@ -266,7 +270,7 @@ class HomeServer(object): return self.clock def get_datastore(self): - return self.datastore + return self.datastores.main def get_config(self): return self.config @@ -537,6 +541,9 @@ class HomeServer(object): def build_event_client_serializer(self): return EventClientSerializer(self) + def build_storage(self) -> Storage: + return Storage(self, self.datastores) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 104349cdbd..4f924ce451 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -41,6 +41,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.master_store = self.hs.get_datastore() + self.storage = hs.get_storage() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index a368117b43..b68e9fe082 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -234,7 +234,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)])) + self.get_success( + self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + ) self.replicate() event_source = RoomEventSource(self.hs) @@ -290,10 +292,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if backfill: self.get_success( - self.master_store.persist_events([(event, context)], backfilled=True) + self.storage.persistence.persist_events( + [(event, context)], backfilled=True + ) ) else: - self.get_success(self.master_store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 427d3c49c5..4561c3e383 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -116,7 +117,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -263,7 +264,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.store.persist_event(event_1, context_1)) + self.get_success(self.storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -282,7 +283,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.store.persist_event(event_2, context_2)) + self.get_success(self.storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1bee45706f..3ddaa151fe 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -62,6 +62,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -72,7 +73,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.store.persist_event( + yield self.storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 447a3c6ffb..9ddd17f73d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -44,6 +44,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -70,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 5c2cf3c2db..d573a3e07b 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -34,6 +34,7 @@ class StateStoreTestCase(tests.unittest.TestCase): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -63,7 +64,7 @@ class StateStoreTestCase(tests.unittest.TestCase): builder ) - yield self.store.persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 18f1a0035d..6ae1ea9b04 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -36,6 +36,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") @@ -137,7 +138,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield self.event_creation_handler.create_new_client_event( builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -159,7 +160,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -180,7 +181,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks diff --git a/tests/utils.py b/tests/utils.py index 0a64f75d04..2808eea933 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -326,10 +326,16 @@ def setup_test_homeserver( if homeserverToUse.__name__ == "TestHomeServer": hs.setup_master() else: + # If we have been given an explicit datastore we probably want to mock + # out the DataStores somehow too. This all feels a bit wrong, but then + # mocking the stores feels wrong too. + datastores = Mock(datastore=datastore) + hs = homeserverToUse( name, db_pool=None, datastore=datastore, + datastores=datastores, config=config, version_string="Synapse/tests", database_engine=db_engine, @@ -647,7 +653,7 @@ def create_room(hs, room_id, creator_id): creator_id (str) """ - store = hs.get_datastore() + persistence_store = hs.get_storage().persistence event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() @@ -664,4 +670,4 @@ def create_room(hs, room_id, creator_id): event, context = yield event_creation_handler.create_new_client_event(builder) - yield store.persist_event(event, context) + yield persistence_store.persist_event(event, context) -- cgit 1.4.1 From 39266a9c9f01515d6db3dc5372bb9463f8e81e4a Mon Sep 17 00:00:00 2001 From: Michael Kaye <1917473+michaelkaye@users.noreply.github.com> Date: Thu, 24 Oct 2019 17:55:53 +0100 Subject: Make user/room stats log line less verbose. --- synapse/storage/data_stores/main/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 5ab639b2ad..4d59b7833f 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -332,7 +332,7 @@ class StatsStore(StateDeltasStore): def _bulk_update_stats_delta_txn(txn): for stats_type, stats_updates in updates.items(): for stats_id, fields in stats_updates.items(): - logger.info( + logger.debug( "Updating %s stats for %s: %s", stats_type, stats_id, fields ) self._update_stats_delta_txn( -- cgit 1.4.1 From f85b9842f07aad66daf09d5ac6d943b430513434 Mon Sep 17 00:00:00 2001 From: Michael Kaye <1917473+michaelkaye@users.noreply.github.com> Date: Thu, 24 Oct 2019 18:08:45 +0100 Subject: Don't encode object as UTF-8 string if not needed. I believe that string formatting ~10-15 sized events will take a proportion of CPU time. --- synapse/crypto/event_signing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 694fb2c816..ccaa8a9920 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -125,9 +125,11 @@ def compute_event_signature(event_dict, signature_name, signing_key): redact_json = prune_event_dict(event_dict) redact_json.pop("age_ts", None) redact_json.pop("unsigned", None) - logger.debug("Signing event: %s", encode_canonical_json(redact_json)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Signing event: %s", encode_canonical_json(redact_json)) redact_json = sign_json(redact_json, signature_name, signing_key) - logger.debug("Signed event: %s", encode_canonical_json(redact_json)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Signed event: %s", encode_canonical_json(redact_json)) return redact_json["signatures"] -- cgit 1.4.1 From 9eebc1e73b014c00f6c2d6e6760dfda809324a08 Mon Sep 17 00:00:00 2001 From: Michael Kaye <1917473+michaelkaye@users.noreply.github.com> Date: Thu, 24 Oct 2019 18:17:33 +0100 Subject: use %r to __repr__ objects This avoids calculating __repr__ unless we are going to log. --- synapse/federation/federation_client.py | 2 +- synapse/federation/transport/client.py | 4 ++-- synapse/storage/data_stores/main/event_federation.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 5b22a39b7f..f5c1632916 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -196,7 +196,7 @@ class FederationClient(FederationBase): dest, room_id, extremities, limit ) - logger.debug("backfill transaction_data=%s", repr(transaction_data)) + logger.debug("backfill transaction_data=%r", transaction_data) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 7b18408144..920fa86853 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -122,10 +122,10 @@ class TransportLayerClient(object): Deferred: Results in a dict received from the remote homeserver. """ logger.debug( - "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s", + "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", destination, room_id, - repr(event_tuples), + event_tuples, str(limit), ) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index a470a48e0f..60bcfb768a 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -365,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug( - "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit + "_get_backfill_events: %s, %r, %s", room_id, event_list, limit ) event_results = set() -- cgit 1.4.1 From 8f4a808d9df79300057ba9f55c9157790a069be3 Mon Sep 17 00:00:00 2001 From: Michael Kaye <1917473+michaelkaye@users.noreply.github.com> Date: Thu, 24 Oct 2019 18:31:53 +0100 Subject: Delay printf until logging is required. Using % will cause the string to be generated even if debugging is off. --- synapse/rest/client/v2_alpha/sync.py | 6 +++--- synapse/rest/media/v1/preview_url_resource.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a883c8adda..eb7351621b 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -112,9 +112,9 @@ class SyncRestServlet(RestServlet): full_state = parse_boolean(request, "full_state", default=False) logger.debug( - "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r, device_id=%r" - % (user, timeout, since, set_presence, filter_id, device_id) + "/sync: user=%r, timeout=%r, since=%r, " + "set_presence=%r, filter_id=%r, device_id=%r", + user, timeout, since, set_presence, filter_id, device_id ) request_key = (user, timeout, since, filter_id, full_state, device_id) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0c68c3aad5..13782d3120 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -117,8 +117,8 @@ class PreviewUrlResource(DirectServeResource): pattern = entry[attrib] value = getattr(url_tuple, attrib) logger.debug( - ("Matching attrib '%s' with value '%s' against" " pattern '%s'") - % (attrib, value, pattern) + "Matching attrib '%s' with value '%s' against" " pattern '%s'", + attrib, value, pattern ) if value is None: @@ -186,7 +186,7 @@ class PreviewUrlResource(DirectServeResource): media_info = yield self._download_url(url, user) - logger.debug("got media_info of '%s'" % media_info) + logger.debug("got media_info of '%s'", media_info) if _is_media(media_info["media_type"]): file_id = media_info["filesystem_id"] @@ -254,7 +254,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s" % og["og:image"]) + logger.warn("Couldn't get dims for %s", og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( self.server_name, @@ -268,7 +268,7 @@ class PreviewUrlResource(DirectServeResource): logger.warn("Failed to find any OG data in %s", url) og = {} - logger.debug("Calculated OG for %s as %s" % (url, og)) + logger.debug("Calculated OG for %s as %s", url, og) jsonog = json.dumps(og) @@ -297,7 +297,7 @@ class PreviewUrlResource(DirectServeResource): with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: - logger.debug("Trying to get url '%s'" % url) + logger.debug("Trying to get url '%s'", url) length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size ) -- cgit 1.4.1 From e4d98188da2bbe9211a6ee1c9479f7f30138ab46 Mon Sep 17 00:00:00 2001 From: Michael Kaye <1917473+michaelkaye@users.noreply.github.com> Date: Thu, 24 Oct 2019 18:43:13 +0100 Subject: Address codestyle concerns --- synapse/rest/client/v2_alpha/sync.py | 7 ++++++- synapse/rest/media/v1/preview_url_resource.py | 4 +++- synapse/storage/data_stores/main/event_federation.py | 4 +--- 3 files changed, 10 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index eb7351621b..541a6b0e10 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -114,7 +114,12 @@ class SyncRestServlet(RestServlet): logger.debug( "/sync: user=%r, timeout=%r, since=%r, " "set_presence=%r, filter_id=%r, device_id=%r", - user, timeout, since, set_presence, filter_id, device_id + user, + timeout, + since, + set_presence, + filter_id, + device_id, ) request_key = (user, timeout, since, filter_id, full_state, device_id) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 13782d3120..094ebad770 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -118,7 +118,9 @@ class PreviewUrlResource(DirectServeResource): value = getattr(url_tuple, attrib) logger.debug( "Matching attrib '%s' with value '%s' against" " pattern '%s'", - attrib, value, pattern + attrib, + value, + pattern, ) if value is None: diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 60bcfb768a..90bef0cd2c 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -364,9 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) def _get_backfill_events(self, txn, room_id, event_list, limit): - logger.debug( - "_get_backfill_events: %s, %r, %s", room_id, event_list, limit - ) + logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) event_results = set() -- cgit 1.4.1 From 848cd388d96ec95b2598f1eaaf8967b8f064c08c Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 24 Oct 2019 21:13:01 -0400 Subject: delete keys when deleting backups --- synapse/storage/data_stores/main/e2e_room_keys.py | 8 +++ .../delta/56/delete_keys_from_deleted_backups.sql | 25 +++++++ tests/storage/test_e2e_room_keys.py | 76 ++++++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql create mode 100644 tests/storage/test_e2e_room_keys.py (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index ef88e79293..1cbbae5b63 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -321,9 +321,17 @@ class EndToEndRoomKeyStore(SQLBaseStore): def _delete_e2e_room_keys_version_txn(txn): if version is None: this_version = self._get_current_version(txn, user_id) + if this_version is None: + raise StoreError(404, "No current backup version") else: this_version = version + self._simple_delete_txn( + txn, + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": this_version}, + ) + return self._simple_update_one_txn( txn, table="e2e_room_keys_versions", diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql new file mode 100644 index 0000000000..1d2ddb1b1a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* delete room keys that belong to deleted room key version, or to room key + * versions that don't exist (anymore) + */ +DELETE FROM e2e_room_keys +WHERE version NOT IN ( + SELECT version + FROM e2e_room_keys_versions + WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id + AND e2e_room_keys_versions.deleted = 0 +); diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py new file mode 100644 index 0000000000..ef4e7ce9d6 --- /dev/null +++ b/tests/storage/test_e2e_room_keys.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from tests import unittest, utils + +# sample room_key data for use in the tests +room_key = { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": False, + "session_data": "SSBBTSBBIEZJU0gK", +} + + +class E2eRoomKeysHandlerTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) + self.hs = None # type: synapse.server.HomeServer + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield utils.setup_test_homeserver(self.addCleanup) + + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_room_keys_version_delete(self): + # test that deleting a room key backup deletes the keys + version1 = yield self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + + yield self.store.set_e2e_room_key( + "user_id", version1, "room", "session", room_key + ) + + version2 = yield self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + + yield self.store.set_e2e_room_key( + "user_id", version2, "room", "session", room_key + ) + + # make sure the keys were stored properly + keys = yield self.store.get_e2e_room_keys("user_id", version1) + self.assertEqual(len(keys["rooms"]), 1) + + keys = yield self.store.get_e2e_room_keys("user_id", version2) + self.assertEqual(len(keys["rooms"]), 1) + + # delete version1 + yield self.store.delete_e2e_room_keys_version("user_id", version1) + + # make sure the key from version1 is gone, and the key from version2 is + # still there + keys = yield self.store.get_e2e_room_keys("user_id", version1) + self.assertEqual(len(keys["rooms"]), 0) + + keys = yield self.store.get_e2e_room_keys("user_id", version2) + self.assertEqual(len(keys["rooms"]), 1) -- cgit 1.4.1 From ff05c9b760ba5736a189b320c2e0d4592d0072a4 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 24 Oct 2019 21:46:11 -0400 Subject: don't error if federation query doesn't have cross-signing keys --- synapse/handlers/e2e_keys.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index f6aa9a940b..4ab75a351e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -218,13 +218,15 @@ class E2eKeysHandler(object): if user_id in destination_query: results[user_id] = keys - for user_id, key in remote_result["master_keys"].items(): - if user_id in destination_query: - cross_signing_keys["master_keys"][user_id] = key - - for user_id, key in remote_result["self_signing_keys"].items(): - if user_id in destination_query: - cross_signing_keys["self_signing_keys"][user_id] = key + if "master_keys" in remote_result: + for user_id, key in remote_result["master_keys"].items(): + if user_id in destination_query: + cross_signing_keys["master_keys"][user_id] = key + + if "self_signing_keys" in remote_result: + for user_id, key in remote_result["self_signing_keys"].items(): + if user_id in destination_query: + cross_signing_keys["self_signing_keys"][user_id] = key except Exception as e: failure = _exception_to_failure(e) -- cgit 1.4.1 From 8ac766c44a18418d566e99b2b20dcec9e1cc5924 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 24 Oct 2019 22:14:58 -0400 Subject: make notification of signatures work with workers --- synapse/replication/slave/storage/devices.py | 1 + synapse/storage/data_stores/main/__init__.py | 5 ++++- synapse/storage/data_stores/main/devices.py | 13 ++++++++++++- 3 files changed, 17 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 61557665a7..c9f99a8405 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -56,6 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def _invalidate_caches_for_devices(self, token, user_id, destination): self._device_list_stream_cache.entity_has_changed(user_id, token) + self._user_signature_stream_cache.entity_has_changed(user_id, token) if destination: self._device_list_federation_stream_cache.entity_has_changed( diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index b185ba0b3e..2d8814f4a7 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -139,7 +139,10 @@ class DataStore( db_conn, "public_room_list_stream", "stream_id" ) self._device_list_id_gen = StreamIdGenerator( - db_conn, "device_lists_stream", "stream_id" + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[("user_signature_stream", "stream_id")], ) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id" diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index f7a3542348..a96f09ea7b 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -543,9 +543,20 @@ class DeviceWorkerStore(SQLBaseStore): LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id, destination + UNION + SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id, NULL AS destination + FROM user_signature_stream + WHERE ? < stream_id AND stream_id <= ? + GROUP BY user_id """ return self._execute( - "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key + "get_all_device_list_changes_for_remotes", + None, + sql, + from_key, + to_key, + from_key, + to_key, ) @cached(max_entries=10000) -- cgit 1.4.1 From 9aee28927b22a16ea0699c3f73fbc58121511630 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 28 Oct 2019 12:29:55 +0000 Subject: Convert EventContext to attrs (#6218) * make EventContext use an attr --- changelog.d/6218.misc | 1 + synapse/events/snapshot.py | 100 ++++++++++++------------------ synapse/storage/data_stores/main/state.py | 7 ++- 3 files changed, 46 insertions(+), 62 deletions(-) create mode 100644 changelog.d/6218.misc (limited to 'synapse') diff --git a/changelog.d/6218.misc b/changelog.d/6218.misc new file mode 100644 index 0000000000..49d10c36cf --- /dev/null +++ b/changelog.d/6218.misc @@ -0,0 +1 @@ +Convert EventContext to an attrs. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index acbcbeeced..27cd8a63ff 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from six import iteritems +import attr from frozendict import frozendict from twisted.internet import defer @@ -22,7 +22,8 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background -class EventContext(object): +@attr.s(slots=True) +class EventContext: """ Attributes: state_group (int|None): state group id, if the state has been stored @@ -31,9 +32,6 @@ class EventContext(object): rejected (bool|str): A rejection reason if the event was rejected, else False - push_actions (list[(str, list[object])]): list of (user_id, actions) - tuples - prev_group (int): Previously persisted state group. ``None`` for an outlier. delta_ids (dict[(str, str), str]): Delta from ``prev_group``. @@ -42,6 +40,8 @@ class EventContext(object): prev_state_events (?): XXX: is this ever set to anything other than the empty list? + app_service: FIXME + _current_state_ids (dict[(str, str), str]|None): The current state map including the current event. None if outlier or we haven't fetched the state from DB yet. @@ -67,49 +67,33 @@ class EventContext(object): Only set when state has not been fetched yet. """ - __slots__ = [ - "state_group", - "rejected", - "prev_group", - "delta_ids", - "prev_state_events", - "app_service", - "_current_state_ids", - "_prev_state_ids", - "_prev_state_id", - "_event_type", - "_event_state_key", - "_fetching_state_deferred", - ] - - def __init__(self): - self.prev_state_events = [] - self.rejected = False - self.app_service = None + state_group = attr.ib(default=None) + rejected = attr.ib(default=False) + prev_group = attr.ib(default=None) + delta_ids = attr.ib(default=None) + prev_state_events = attr.ib(default=attr.Factory(list)) + app_service = attr.ib(default=None) + + _current_state_ids = attr.ib(default=None) + _prev_state_ids = attr.ib(default=None) + _prev_state_id = attr.ib(default=None) + + _event_type = attr.ib(default=None) + _event_state_key = attr.ib(default=None) + _fetching_state_deferred = attr.ib(default=None) @staticmethod def with_state( state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None ): - context = EventContext() - - # The current state including the current event - context._current_state_ids = current_state_ids - # The current state excluding the current event - context._prev_state_ids = prev_state_ids - context.state_group = state_group - - context._prev_state_id = None - context._event_type = None - context._event_state_key = None - context._fetching_state_deferred = defer.succeed(None) - - # A previously persisted state group and a delta between that - # and this state. - context.prev_group = prev_group - context.delta_ids = delta_ids - - return context + return EventContext( + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + state_group=state_group, + fetching_state_deferred=defer.succeed(None), + prev_group=prev_group, + delta_ids=delta_ids, + ) @defer.inlineCallbacks def serialize(self, event, store): @@ -157,24 +141,18 @@ class EventContext(object): Returns: EventContext """ - context = EventContext() - - # We use the state_group and prev_state_id stuff to pull the - # current_state_ids out of the DB and construct prev_state_ids. - context._prev_state_id = input["prev_state_id"] - context._event_type = input["event_type"] - context._event_state_key = input["event_state_key"] - - context._current_state_ids = None - context._prev_state_ids = None - context._fetching_state_deferred = None - - context.state_group = input["state_group"] - context.prev_group = input["prev_group"] - context.delta_ids = _decode_state_dict(input["delta_ids"]) - - context.rejected = input["rejected"] - context.prev_state_events = input["prev_state_events"] + context = EventContext( + # We use the state_group and prev_state_id stuff to pull the + # current_state_ids out of the DB and construct prev_state_ids. + prev_state_id=input["prev_state_id"], + event_type=input["event_type"], + event_state_key=input["event_state_key"], + state_group=input["state_group"], + prev_group=input["prev_group"], + delta_ids=_decode_state_dict(input["delta_ids"]), + rejected=input["rejected"], + prev_state_events=input["prev_state_events"], + ) app_service_id = input["app_service_id"] if app_service_id: diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index d54442e5fa..9b2207075b 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Iterable, Tuple from six import iteritems, itervalues from six.moves import range @@ -23,6 +24,8 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore @@ -1215,7 +1218,9 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): def __init__(self, db_conn, hs): super(StateStore, self).__init__(db_conn, hs) - def _store_event_state_mappings_txn(self, txn, events_and_contexts): + def _store_event_state_mappings_txn( + self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] + ): state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): -- cgit 1.4.1 From d0d8a22c13427cce341dbb7ae1d92d2c0ae709c3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 28 Oct 2019 13:33:04 +0000 Subject: Quick fix to ensure cache descriptors always return deferreds --- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- synapse/util/caches/descriptors.py | 4 ++-- tests/util/caches/test_descriptors.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 22491f3700..2bbdd11941 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -79,7 +79,7 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = self._get_rules_for_room(room_id) + rules_for_room = yield self._get_rules_for_room(room_id) rules_by_user = yield rules_for_room.get_rules(event, context) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index e47ab604dd..bc04bfd7d4 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -720,7 +720,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - cache = self._get_joined_hosts_cache(room_id) + cache = yield self._get_joined_hosts_cache(room_id) joined_hosts = yield cache.get_destinations(state_entry) return joined_hosts diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5ac2530a6a..5a8da449b2 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -438,7 +438,7 @@ class CacheDescriptor(_CacheDescriptorBase): if isinstance(cached_result_d, ObservableDeferred): observer = cached_result_d.observe() else: - observer = cached_result_d + observer = defer.succeed(cached_result_d) except KeyError: ret = defer.maybeDeferred( @@ -618,7 +618,7 @@ class CacheListDescriptor(_CacheDescriptorBase): ) return make_deferred_yieldable(d) else: - return results + return defer.succeed(results) obj.__dict__[self.orig.__name__] = wrapped diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 5713870f48..f907903511 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -325,9 +325,9 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(len(obj.fn.cache.cache), 3) r = obj.fn(1, 2) - self.assertEqual(r, ["spam", "eggs"]) + self.assertEqual(r.result, ["spam", "eggs"]) r = obj.fn(1, 3) - self.assertEqual(r, ["chips"]) + self.assertEqual(r.result, ["chips"]) obj.mock.assert_not_called() def test_cache_iterable_with_sync_exception(self): -- cgit 1.4.1 From 14504ad5736ae230d759d8fadccd8babb42fa548 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 28 Oct 2019 17:45:32 +0000 Subject: Add CI for synapse_port_db (#6140) This adds: * a test sqlite database * a configuration file for the sqlite database * a configuration file for a postgresql database (using the credentials in `.buildkite/docker-compose.pyXX.pgXX.yaml`) as well as a new script named `.buildkite/scripts/test_synapse_port_db.sh` that: 1. installs Synapse 2. updates the test sqlite database to the latest schema and runs background updates on it 3. creates an empty postgresql database 4. run the `synapse_port_db` script to migrate the test sqlite database to the empty postgresql database (with coverage) Step `2` is done via a new script located at `scripts-dev/update_database`. The test sqlite database is extracted from a SyTest run, so that it can be considered as an actual homeserver's database with actual data in it. --- .buildkite/postgres-config.yaml | 19 +++++ .buildkite/scripts/test_synapse_port_db.sh | 29 +++++++ .buildkite/sqlite-config.yaml | 16 ++++ .buildkite/test_db.db | Bin 0 -> 18825216 bytes changelog.d/6140.misc | 1 + scripts-dev/update_database | 125 +++++++++++++++++++++++++++++ synapse/storage/background_updates.py | 9 ++- 7 files changed, 196 insertions(+), 3 deletions(-) create mode 100644 .buildkite/postgres-config.yaml create mode 100755 .buildkite/scripts/test_synapse_port_db.sh create mode 100644 .buildkite/sqlite-config.yaml create mode 100644 .buildkite/test_db.db create mode 100644 changelog.d/6140.misc create mode 100755 scripts-dev/update_database (limited to 'synapse') diff --git a/.buildkite/postgres-config.yaml b/.buildkite/postgres-config.yaml new file mode 100644 index 0000000000..23db43fac9 --- /dev/null +++ b/.buildkite/postgres-config.yaml @@ -0,0 +1,19 @@ +# Configuration file used for testing the 'synapse_port_db' script. +# Tells the script to connect to the postgresql database that will be available in the +# CI's Docker setup at the point where this file is considered. +server_name: "test" + +report_stats: false + +database: + name: "psycopg2" + args: + user: postgres + host: postgres + password: postgres + database: synapse + +# Suppress the key server warning. +trusted_key_servers: + - server_name: "matrix.org" +suppress_key_server_warning: true diff --git a/.buildkite/scripts/test_synapse_port_db.sh b/.buildkite/scripts/test_synapse_port_db.sh new file mode 100755 index 0000000000..7defd47bc6 --- /dev/null +++ b/.buildkite/scripts/test_synapse_port_db.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# +# Test script for 'synapse_port_db', which creates a virtualenv, installs Synapse along +# with additional dependencies needed for the test (such as coverage or the PostgreSQL +# driver), update the schema of the test SQLite database and run background updates on it, +# create an empty test database in PostgreSQL, then run the 'synapse_port_db' script to +# test porting the SQLite database to the PostgreSQL database (with coverage). + +set -xe +cd `dirname $0`/../.. + +# Create a virtualenv and use it. +virtualenv env +source env/bin/activate + +# Install dependencies for this test. +pip install psycopg2 coverage coverage-enable-subprocess + +# Install Synapse itself. This won't update any libraries. +pip install -e . + +# Make sure the SQLite3 database is using the latest schema and has no pending background update. +scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml + +# Create the PostgreSQL database. +PGPASSWORD=postgres createdb -h postgres -U postgres synapse + +# Run the script +coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml diff --git a/.buildkite/sqlite-config.yaml b/.buildkite/sqlite-config.yaml new file mode 100644 index 0000000000..56503cc4ce --- /dev/null +++ b/.buildkite/sqlite-config.yaml @@ -0,0 +1,16 @@ +# Configuration file used for testing the 'synapse_port_db' script. +# Tells the 'update_database' script to connect to the test SQLite database to upgrade its +# schema and run background updates on it. +server_name: "test" + +report_stats: false + +database: + name: "sqlite3" + args: + database: ".buildkite/test_db.db" + +# Suppress the key server warning. +trusted_key_servers: + - server_name: "matrix.org" +suppress_key_server_warning: true diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db new file mode 100644 index 0000000000..f20567ba73 Binary files /dev/null and b/.buildkite/test_db.db differ diff --git a/changelog.d/6140.misc b/changelog.d/6140.misc new file mode 100644 index 0000000000..0feb97ec61 --- /dev/null +++ b/changelog.d/6140.misc @@ -0,0 +1 @@ +Add a CI job to test the `synapse_port_db` script. \ No newline at end of file diff --git a/scripts-dev/update_database b/scripts-dev/update_database new file mode 100755 index 0000000000..10166583e1 --- /dev/null +++ b/scripts-dev/update_database @@ -0,0 +1,125 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import sys + +import yaml + +from twisted.internet import defer, reactor + +from synapse.config.homeserver import HomeServerConfig +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.server import HomeServer +from synapse.storage.engines import create_engine +from synapse.storage import DataStore +from synapse.storage.prepare_database import prepare_database + +logger = logging.getLogger("update_database") + + +class MockHomeserver(HomeServer): + DATASTORE_CLASS = DataStore + + def __init__(self, config, database_engine, db_conn, **kwargs): + super(MockHomeserver, self).__init__( + config.server_name, + reactor=reactor, + config=config, + database_engine=database_engine, + **kwargs + ) + + self.database_engine = database_engine + self.db_conn = db_conn + + def get_db_conn(self): + return self.db_conn + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=( + "Updates a synapse database to the latest schema and runs background updates" + " on it." + ) + ) + parser.add_argument("-v", action='store_true') + parser.add_argument( + "--database-config", + type=argparse.FileType('r'), + required=True, + help="A database config file for either a SQLite3 database or a PostgreSQL one.", + ) + + args = parser.parse_args() + + logging_config = { + "level": logging.DEBUG if args.v else logging.INFO, + "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", + } + + logging.basicConfig(**logging_config) + + # Load, process and sanity-check the config. + hs_config = yaml.safe_load(args.database_config) + + if "database" not in hs_config: + sys.stderr.write("The configuration file must have a 'database' section.\n") + sys.exit(4) + + config = HomeServerConfig() + config.parse_config_dict(hs_config, "", "") + + # Create the database engine and a connection to it. + database_engine = create_engine(config.database_config) + db_conn = database_engine.module.connect( + **{ + k: v + for k, v in config.database_config.get("args", {}).items() + if not k.startswith("cp_") + } + ) + + # Update the database to the latest schema. + prepare_database(db_conn, database_engine, config=config) + db_conn.commit() + + # Instantiate and initialise the homeserver object. + hs = MockHomeserver( + config, + database_engine, + db_conn, + db_config=config.database_config, + ) + # setup instantiates the store within the homeserver object. + hs.setup() + store = hs.get_datastore() + + @defer.inlineCallbacks + def run_background_updates(): + yield store.run_background_updates(sleep=False) + # Stop the reactor to exit the script once every background update is run. + reactor.stop() + + # Apply all background updates on the database. + reactor.callWhenRunning(lambda: run_as_background_process( + "background_updates", run_background_updates + )) + + reactor.run() + diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 80b57a948c..37d469ffd7 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -94,13 +94,16 @@ class BackgroundUpdateStore(SQLBaseStore): self._all_done = False def start_doing_background_updates(self): - run_as_background_process("background_updates", self._run_background_updates) + run_as_background_process("background_updates", self.run_background_updates) @defer.inlineCallbacks - def _run_background_updates(self): + def run_background_updates(self, sleep=True): logger.info("Starting background schema updates") while True: - yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) + if sleep: + yield self.hs.get_clock().sleep( + self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0 + ) try: result = yield self.do_next_background_update( -- cgit 1.4.1 From e6c7e239ef70c851f21d39652ddedddfddf5f251 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Oct 2019 11:48:24 +0000 Subject: Update docstring --- synapse/util/caches/descriptors.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5a8da449b2..0e8da27f53 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -482,9 +482,8 @@ class CacheListDescriptor(_CacheDescriptorBase): Given a list of keys it looks in the cache to find any hits, then passes the list of missing keys to the wrapped function. - Once wrapped, the function returns either a Deferred which resolves to - the list of results, or (if all results were cached), just the list of - results. + Once wrapped, the function returns a Deferred which resolves to the list + of results. """ def __init__( -- cgit 1.4.1 From e577a4b2ad01233ed70b3ab0a9e52aab42e88231 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Oct 2019 13:00:51 +0000 Subject: Port replication http server endpoints to async/await --- synapse/replication/http/_base.py | 6 +++--- synapse/replication/http/federation.py | 24 +++++++++--------------- synapse/replication/http/login.py | 7 ++----- synapse/replication/http/membership.py | 14 +++++--------- synapse/replication/http/register.py | 12 ++++-------- synapse/replication/http/send_event.py | 7 +++---- 6 files changed, 26 insertions(+), 44 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 03560c1f0e..9be37cd998 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -110,14 +110,14 @@ class ReplicationEndpoint(object): return {} @abc.abstractmethod - def _handle_request(self, request, **kwargs): + async def _handle_request(self, request, **kwargs): """Handle incoming request. This is called with the request object and PATH_ARGS. Returns: - Deferred[dict]: A JSON serialisable dict to be used as response - body of request. + tuple[int, dict]: HTTP status code and a JSON serialisable dict + to be used as response body of request. """ pass diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 2f16955954..9af4e7e173 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -82,8 +82,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request): + async def _handle_request(self, request): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) @@ -101,15 +100,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): EventType = event_type_from_format_version(format_ver) event = EventType(event_dict, internal_metadata, rejected_reason) - context = yield EventContext.deserialize( - self.store, event_payload["context"] - ) + context = EventContext.deserialize(self.store, event_payload["context"]) event_and_contexts.append((event, context)) logger.info("Got %d events from federation", len(event_and_contexts)) - yield self.federation_handler.persist_events_and_notify( + await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) @@ -144,8 +141,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def _serialize_payload(edu_type, origin, content): return {"origin": origin, "content": content} - @defer.inlineCallbacks - def _handle_request(self, request, edu_type): + async def _handle_request(self, request, edu_type): with Measure(self.clock, "repl_fed_send_edu_parse"): content = parse_json_object_from_request(request) @@ -154,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): logger.info("Got %r edu from %s", edu_type, origin) - result = yield self.registry.on_edu(edu_type, origin, edu_content) + result = await self.registry.on_edu(edu_type, origin, edu_content) return 200, result @@ -193,8 +189,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): """ return {"args": args} - @defer.inlineCallbacks - def _handle_request(self, request, query_type): + async def _handle_request(self, request, query_type): with Measure(self.clock, "repl_fed_query_parse"): content = parse_json_object_from_request(request) @@ -202,7 +197,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): logger.info("Got %r query", query_type) - result = yield self.registry.on_query(query_type, args) + result = await self.registry.on_query(query_type, args) return 200, result @@ -234,9 +229,8 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): """ return {} - @defer.inlineCallbacks - def _handle_request(self, request, room_id): - yield self.store.clean_room_for_join(room_id) + async def _handle_request(self, request, room_id): + await self.store.clean_room_for_join(room_id) return 200, {} diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 786f5232b2..798b9d3af5 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest ) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index b9ce3477ad..b5f5f13a62 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID @@ -65,8 +63,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -79,7 +76,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - yield self.federation_handler.do_invite_join( + await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) @@ -123,8 +120,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): "remote_room_hosts": remote_room_hosts, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -137,7 +133,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = yield self.federation_handler.do_remotely_reject_invite( + event = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id ) ret = event.get_pdu_json() @@ -150,7 +146,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite(user_id, room_id) + await self.store.locally_reject_invite(user_id, room_id) ret = {} return 200, ret diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 38260256cf..915cfb9430 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -74,11 +72,10 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "address": address, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - yield self.registration_handler.register_with_store( + await self.registration_handler.register_with_store( user_id=user_id, password_hash=content["password_hash"], was_guest=content["was_guest"], @@ -117,14 +114,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): """ return {"auth_result": auth_result, "access_token": access_token} - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) auth_result = content["auth_result"] access_token = content["access_token"] - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=user_id, auth_result=auth_result, access_token=access_token ) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index adb9b2f7f4..9bafd60b14 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -87,8 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request, event_id): + async def _handle_request(self, request, event_id): with Measure(self.clock, "repl_send_event_parse"): content = parse_json_object_from_request(request) @@ -101,7 +100,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = yield EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.store, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] @@ -113,7 +112,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - yield self.event_creation_handler.persist_and_notify_client_event( + await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) -- cgit 1.4.1 From 9be41bc12159c52e4a700950569d24b0d9a3491b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Oct 2019 13:09:24 +0000 Subject: Port room rest handlers to async/await --- synapse/rest/client/v1/room.py | 166 ++++++++++++++++++----------------------- 1 file changed, 72 insertions(+), 94 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9c1d41421c..86bbcc0eea 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet): set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet): def on_PUT_no_state_key(self, request, room_id, event_type): return self.on_PUT(request, room_id, event_type, "") - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_type, state_key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_type, state_key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] ) msg_handler = self.message_handler - data = yield msg_handler.get_room_data( + data = await msg_handler.get_room_data( user_id=requester.user.to_string(), room_id=room_id, event_type=event_type, @@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet): elif format == "content": return 200, data.get_dict()["content"] - @defer.inlineCallbacks - def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + requester = await self.auth.get_user_by_req(request) if txn_id: set_tag("txn_id", txn_id) @@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = yield self.room_member_handler.update_membership( + event = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_type, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, event_type, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) event_dict = { @@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet): PATTERNS = "/join/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_identifier, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet): elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): server = parse_string(request, "server", default=None) try: - yield self.auth.get_user_by_req(request, allow_guest=True) + await self.auth.get_user_by_req(request, allow_guest=True) except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private @@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token ) return 200, data - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) @@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token, @@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): third_party_instance_id=third_party_instance_id, ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token, search_filter=search_filter, @@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): + async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) handler = self.message_handler # request the state as of a given event, as identified by a stream token, @@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet): membership = parse_string(request, "membership") not_membership = parse_string(request, "not_membership") - events = yield handler.get_state_events( + events = await handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, @@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - users_with_profile = yield self.message_handler.get_joined_members( + users_with_profile = await self.message_handler.get_joined_members( requester, room_id ) @@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet): self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) @@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet): as_client_event = False else: event_filter = None - msgs = yield self.pagination_handler.get_messages( + msgs = await self.pagination_handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, @@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room - events = yield self.message_handler.get_state_events( + events = await self.message_handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), is_guest=requester.is_guest, @@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) - content = yield self.initial_sync_handler.room_initial_sync( + content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) return 200, content @@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: - event = yield self.event_handler.get_event( + event = await self.event_handler.get_event( requester.user, room_id, event_id ) except AuthError: @@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) @@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = yield self.room_context_handler.get_event_context( + results = await self.room_context_handler.get_event_context( requester.user, room_id, event_id, limit, event_filter ) @@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - results["events_before"] = yield self._event_serializer.serialize_events( + results["events_before"] = await self._event_serializer.serialize_events( results["events_before"], time_now ) - results["event"] = yield self._event_serializer.serialize_event( + results["event"] = await self._event_serializer.serialize_event( results["event"], time_now ) - results["events_after"] = yield self._event_serializer.serialize_events( + results["events_after"] = await self._event_serializer.serialize_events( results["events_after"], time_now ) - results["state"] = yield self._event_serializer.serialize_events( + results["state"] = await self._event_serializer.serialize_events( results["state"], time_now ) @@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + async def on_POST(self, request, room_id, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget(user=requester.user, room_id=room_id) + await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} @@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, membership_action, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, @@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - yield self.room_member_handler.do_3pid_invite( + await self.room_member_handler.do_3pid_invite( room_id, requester.user, content["medium"], @@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content and membership_action in ["kick", "ban"]: event_content = {"reason": content["reason"]} - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=target, room_id=room_id, @@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id, txn_id=None): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, @@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, user_id): + requester = await self.auth.get_user_by_req(request) room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) if content["typing"]: - yield self.typing_handler.started_typing( + await self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, timeout=timeout, ) else: - yield self.typing_handler.stopped_typing( + await self.typing_handler.stopped_typing( target_user=target_user, auth_user=requester.user, room_id=room_id ) @@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = yield self.handlers.search_handler.search( + results = await self.handlers.search_handler.search( requester.user, content, batch ) @@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) + room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} -- cgit 1.4.1 From 3f33879be4697666a304a472d090d4def0c671d0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Oct 2019 14:05:32 +0000 Subject: Port federation_server to async/await --- synapse/federation/federation_server.py | 205 ++++++++++++++------------------ tests/handlers/test_typing.py | 3 + 2 files changed, 90 insertions(+), 118 deletions(-) (limited to 'synapse') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5fc7c1d67b..15c1fa0a51 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -21,7 +21,6 @@ from six import iteritems from canonicaljson import json from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -86,14 +85,12 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_backfill_request(self, origin, room_id, versions, limit): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - pdus = yield self.handler.on_backfill_request( + pdus = await self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -101,9 +98,7 @@ class FederationServer(FederationBase): return 200, res - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, origin, transaction_data): + async def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -118,18 +113,17 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with ( - yield self._transaction_linearizer.queue( + await self._transaction_linearizer.queue( (origin, transaction.transaction_id) ) ): - result = yield self._handle_incoming_transaction( + result = await self._handle_incoming_transaction( origin, transaction, request_time ) return result - @defer.inlineCallbacks - def _handle_incoming_transaction(self, origin, transaction, request_time): + async def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: @@ -140,7 +134,7 @@ class FederationServer(FederationBase): Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(origin, transaction) + response = await self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -159,7 +153,7 @@ class FederationServer(FederationBase): logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} - yield self.transaction_actions.set_response( + await self.transaction_actions.set_response( origin, transaction, 400, response ) return 400, response @@ -195,7 +189,7 @@ class FederationServer(FederationBase): continue try: - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue @@ -221,11 +215,10 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - @defer.inlineCallbacks - def process_pdus_for_room(room_id): + async def process_pdus_for_room(room_id): logger.debug("Processing PDUs for %s", room_id) try: - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: logger.warn("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: @@ -237,7 +230,7 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu(origin, pdu) + await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) @@ -251,36 +244,33 @@ class FederationServer(FederationBase): exc_info=(f.type, f.value, f.getTracebackObject()), ) - yield concurrently_execute( + await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu(origin, edu.edu_type, edu.content) + await self.received_edu(origin, edu.edu_type, edu.content) response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response(origin, transaction, 200, response) + await self.transaction_actions.set_response(origin, transaction, 200, response) return 200, response - @defer.inlineCallbacks - def received_edu(self, origin, edu_type, content): + async def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - yield self.registry.on_edu(edu_type, origin, content) + await self.registry.on_edu(edu_type, origin, content) - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): + async def on_context_state_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -289,8 +279,8 @@ class FederationServer(FederationBase): # in the cache so we could return it without waiting for the linearizer # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.wrap( + with (await self._server_linearizer.queue((origin, room_id))): + resp = await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, room_id, @@ -299,65 +289,58 @@ class FederationServer(FederationBase): return 200, resp - @defer.inlineCallbacks - def on_state_ids_request(self, origin, room_id, event_id): + async def on_state_ids_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) + state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} - @defer.inlineCallbacks - def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu(room_id, event_id) - auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + async def _on_context_state_request_compute(self, room_id, event_id): + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], } - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self.handler.get_persisted_pdu(origin, event_id) + async def on_pdu_request(self, origin, event_id): + pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: return 200, self._transaction_from_pdus([pdu]).get_dict() else: return 404, "" - @defer.inlineCallbacks - def on_query_request(self, query_type, args): + async def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() - resp = yield self.registry.on_query(query_type, args) + resp = await self.registry.on_query(query_type, args) return 200, resp - @defer.inlineCallbacks - def on_make_join_request(self, origin, room_id, user_id, supported_versions): + async def on_make_join_request(self, origin, room_id, user_id, supported_versions): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) if room_version not in supported_versions: logger.warn("Room version %s not in %s", room_version, supported_versions) raise IncompatibleRoomVersionError(room_version=room_version) - pdu = yield self.handler.on_make_join_request(origin, room_id, user_id) + pdu = await self.handler.on_make_join_request(origin, room_id, user_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_invite_request(self, origin, content, room_version): + async def on_invite_request(self, origin, content, room_version): if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( 400, @@ -369,28 +352,27 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) - pdu = yield self._check_sigs_and_hash(room_version, pdu) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) + await self.check_server_matches_acl(origin_host, pdu.room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) + ret_pdu = await self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} - @defer.inlineCallbacks - def on_send_join_request(self, origin, content, room_id): + async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) + res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() return ( 200, @@ -402,48 +384,44 @@ class FederationServer(FederationBase): }, ) - @defer.inlineCallbacks - def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) - pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id) + await self.check_server_matches_acl(origin_host, room_id) + pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_send_leave_request(self, origin, content, room_id): + async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - yield self.handler.on_send_leave_request(origin, pdu) + await self.handler.on_send_leave_request(origin, pdu) return 200, {} - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_event_auth(self, origin, room_id, event_id): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) + auth_pdus = await self.handler.on_event_auth(event_id) res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} return 200, res - @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, room_id, event_id): + async def on_query_auth_request(self, origin, content, room_id, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. @@ -462,22 +440,22 @@ class FederationServer(FederationBase): Returns: Deferred: Results in `dict` with the same format as `content` """ - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] - signed_auth = yield self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_and_fetch( origin, auth_chain, outlier=True, room_version=room_version ) - ret = yield self.handler.on_query_auth( + ret = await self.handler.on_query_auth( origin, event_id, room_id, @@ -503,16 +481,14 @@ class FederationServer(FederationBase): return self.on_query_request("user_devices", user_id) @trace - @defer.inlineCallbacks - @log_function - def on_claim_client_keys(self, origin, content): + async def on_claim_client_keys(self, origin, content): query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = yield self.store.claim_e2e_one_time_keys(query) + results = await self.store.claim_e2e_one_time_keys(query) json_result = {} for user_id, device_keys in results.items(): @@ -536,14 +512,12 @@ class FederationServer(FederationBase): return {"one_time_keys": json_result} - @defer.inlineCallbacks - @log_function - def on_get_missing_events( + async def on_get_missing_events( self, origin, room_id, earliest_events, latest_events, limit ): - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," @@ -553,7 +527,7 @@ class FederationServer(FederationBase): limit, ) - missing_events = yield self.handler.on_get_missing_events( + missing_events = await self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit ) @@ -586,8 +560,7 @@ class FederationServer(FederationBase): destination=None, ) - @defer.inlineCallbacks - def _handle_received_pdu(self, origin, pdu): + async def _handle_received_pdu(self, origin, pdu): """ Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. @@ -640,37 +613,34 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = yield self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. try: - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "" % self.server_name - @defer.inlineCallbacks - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): - ret = yield self.handler.exchange_third_party_invite( + ret = await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) return ret - @defer.inlineCallbacks - def on_exchange_third_party_invite_request(self, room_id, event_dict): - ret = yield self.handler.on_exchange_third_party_invite_request( + async def on_exchange_third_party_invite_request(self, room_id, event_dict): + ret = await self.handler.on_exchange_third_party_invite_request( room_id, event_dict ) return ret - @defer.inlineCallbacks - def check_server_matches_acl(self, server_name, room_id): + async def check_server_matches_acl(self, server_name, room_id): """Check if the given server is allowed by the server ACLs in the room Args: @@ -680,13 +650,13 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = yield self.store.get_current_state_ids(room_id) + state_ids = await self.store.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: return - acl_event = yield self.store.get_event(acl_event_id) + acl_event = await self.store.get_event(acl_event_id) if server_matches_acl_event(server_name, acl_event): return @@ -799,15 +769,14 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if not handler: logger.warn("No handler registered for EDU type %s", edu_type) with start_active_span_from_edu(content, "handle_edu"): try: - yield handler(origin, content) + await handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 67f1013051..f360c8e965 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None + self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + None + ) def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] -- cgit 1.4.1 From 09a135b0391a7a65c24c8c7b046cb6573dee3947 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Oct 2019 15:02:23 +0000 Subject: Make concurrently_execute work with async/await --- synapse/util/async_helpers.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'synapse') diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 804dbca443..7659eaeb42 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -138,7 +138,7 @@ def concurrently_execute(func, args, limit): the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred. + func (func): Function to execute, should return a deferred or coroutine. args (list): List of arguments to pass to func, each invocation of func gets a signle argument. limit (int): Maximum number of conccurent executions. @@ -148,11 +148,10 @@ def concurrently_execute(func, args, limit): """ it = iter(args) - @defer.inlineCallbacks - def _concurrently_execute_inner(): + async def _concurrently_execute_inner(): try: while True: - yield func(next(it)) + await maybe_awaitable(func(next(it))) except StopIteration: pass -- cgit 1.4.1 From 2c35ffead257171d195f228bafd0d65b917e2165 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Oct 2019 15:08:22 +0000 Subject: Port receipt and read markers to async/wait --- synapse/federation/send_queue.py | 4 +++- synapse/handlers/read_marker.py | 13 ++++------ synapse/handlers/receipts.py | 37 ++++++++++------------------- synapse/rest/client/v2_alpha/read_marker.py | 13 ++++------ synapse/rest/client/v2_alpha/receipts.py | 11 ++++----- synapse/storage/data_stores/main/events.py | 7 +++--- 6 files changed, 32 insertions(+), 53 deletions(-) (limited to 'synapse') diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 454456a52d..ced4925a98 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -36,6 +36,8 @@ from six import iteritems from sortedcontainers import SortedDict +from twisted.internet import defer + from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure @@ -212,7 +214,7 @@ class FederationRemoteSendQueue(object): receipt (synapse.types.ReadReceipt): """ # nothing to do here: the replication listener will handle it. - pass + return defer.succeed(None) def send_presence(self, states): """As per FederationSender diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 3e4d8c93a4..e3b528d271 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler): self.read_marker_linearizer = Linearizer(name="read_marker") self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def received_client_read_marker(self, room_id, user_id, event_id): + async def received_client_read_marker(self, room_id, user_id, event_id): """Updates the read marker for a given user in a given room if the event ID given is ahead in the stream relative to the current read marker. @@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler): the read marker has changed. """ - with (yield self.read_marker_linearizer.queue((room_id, user_id))): - existing_read_marker = yield self.store.get_account_data_for_room_and_type( + with await self.read_marker_linearizer.queue((room_id, user_id)): + existing_read_marker = await self.store.get_account_data_for_room_and_type( user_id, room_id, "m.fully_read" ) @@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler): if existing_read_marker: # Only update if the new marker is ahead in the stream - should_update = yield self.store.is_event_after( + should_update = await self.store.is_event_after( event_id, existing_read_marker["event_id"] ) if should_update: content = {"event_id": event_id} - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 6854c751a6..9283c039e3 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.handlers._base import BaseHandler from synapse.types import ReadReceipt, get_domain_from_id +from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler): self.clock = self.hs.get_clock() self.state = hs.get_state_handler() - @defer.inlineCallbacks - def _received_remote_receipt(self, origin, content): + async def _received_remote_receipt(self, origin, content): """Called when we receive an EDU of type m.receipt from a remote HS. """ receipts = [] @@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler): ) ) - yield self._handle_new_receipts(receipts) + await self._handle_new_receipts(receipts) - @defer.inlineCallbacks - def _handle_new_receipts(self, receipts): + async def _handle_new_receipts(self, receipts): """Takes a list of receipts, stores them and informs the notifier. """ min_batch_id = None max_batch_id = None for receipt in receipts: - res = yield self.store.insert_receipt( + res = await self.store.insert_receipt( receipt.room_id, receipt.receipt_type, receipt.user_id, @@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler): self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. - yield self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids + await maybe_awaitable( + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids + ) ) return True - @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, event_id): + async def received_client_receipt(self, room_id, receipt_type, user_id, event_id): """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler): data={"ts": int(self.clock.time_msec())}, ) - is_new = yield self._handle_new_receipts([receipt]) + is_new = await self._handle_new_receipts([receipt]) if not is_new: return - yield self.federation.send_read_receipt(receipt) - - @defer.inlineCallbacks - def get_receipts_for_room(self, room_id, to_key): - """Gets all receipts for a room, upto the given key. - """ - result = yield self.store.get_linearized_receipts_for_room( - room_id, to_key=to_key - ) - - if not result: - return [] - - return result + await self.federation.send_read_receipt(receipt) class ReceiptEventSource(object): diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index b3bf8567e1..67cbc37312 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) read_event_id = body.get("m.read", None) if read_event_id: - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, "m.read", user_id=requester.user.to_string(), @@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet): read_marker_event_id = body.get("m.fully_read", None) if read_marker_event_id: - yield self.read_marker_handler.received_client_read_marker( + await self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), event_id=read_marker_event_id, diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 0dab03d227..92555bd4a9 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet @@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id, receipt_type, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, receipt_type, event_id): + requester = await self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 03b5111c5d..067e77ae00 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -2439,12 +2439,11 @@ class EventsStore( logger.info("[purge] done") - @defer.inlineCallbacks - def is_event_after(self, event_id1, event_id2): + async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ - to_1, so_1 = yield self._get_event_ordering(event_id1) - to_2, so_2 = yield self._get_event_ordering(event_id2) + to_1, so_1 = await self._get_event_ordering(event_id1) + to_2, so_2 = await self._get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) -- cgit 1.4.1 From a287f1e80475e0c33743ed2990fd9d72572bda68 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Oct 2019 16:36:46 +0000 Subject: Don't return coroutines --- synapse/federation/federation_server.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'synapse') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 15c1fa0a51..7c331753ad 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -809,7 +809,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): super(ReplicationFederationHandlerRegistry, self).__init__() - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): """Overrides FederationHandlerRegistry """ if not self.config.use_presence and edu_type == "m.presence": @@ -817,17 +817,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): handler = self.edu_handlers.get(edu_type) if handler: - return super(ReplicationFederationHandlerRegistry, self).on_edu( + return await super(ReplicationFederationHandlerRegistry, self).on_edu( edu_type, origin, content ) - return self._send_edu(edu_type=edu_type, origin=origin, content=content) + return await self._send_edu(edu_type=edu_type, origin=origin, content=content) - def on_query(self, query_type, args): + async def on_query(self, query_type, args): """Overrides FederationHandlerRegistry """ handler = self.query_handlers.get(query_type) if handler: - return handler(args) + return await handler(args) - return self._get_query_client(query_type=query_type, args=args) + return await self._get_query_client(query_type=query_type, args=args) -- cgit 1.4.1 From 47f767269c51773a1cb24e65f958a891e00c7c06 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 29 Oct 2019 16:56:22 +0000 Subject: Add database table for keeping track of labels on events --- .../main/schema/delta/56/event_labels.sql | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 synapse/storage/data_stores/main/schema/delta/56/event_labels.sql (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql new file mode 100644 index 0000000000..323d797419 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_labels ( + event_id TEXT, + label TEXT, + PRIMARY KEY(event_id, label) +); \ No newline at end of file -- cgit 1.4.1 From b39ca49db167e814d7848c6a4872c8b65ec03ff1 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 30 Oct 2019 11:00:15 +0000 Subject: Handle FileNotFound error in checking git repository version (#6284) --- changelog.d/6284.bugfix | 1 + synapse/util/versionstring.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) create mode 100644 changelog.d/6284.bugfix (limited to 'synapse') diff --git a/changelog.d/6284.bugfix b/changelog.d/6284.bugfix new file mode 100644 index 0000000000..cf15053d2d --- /dev/null +++ b/changelog.d/6284.bugfix @@ -0,0 +1 @@ +Prevent errors from appearing on Synapse startup if `git` is not installed. \ No newline at end of file diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index fa404b9d75..ab7d03af3a 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -42,6 +42,7 @@ def get_version_string(module): try: null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) + try: git_branch = ( subprocess.check_output( @@ -51,7 +52,8 @@ def get_version_string(module): .decode("ascii") ) git_branch = "b=" + git_branch - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError can arise when git is not installed git_branch = "" try: @@ -63,7 +65,7 @@ def get_version_string(module): .decode("ascii") ) git_tag = "t=" + git_tag - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_tag = "" try: @@ -74,7 +76,7 @@ def get_version_string(module): .strip() .decode("ascii") ) - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_commit = "" try: @@ -89,7 +91,7 @@ def get_version_string(module): ) git_dirty = "dirty" if is_dirty else "" - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_dirty = "" if git_branch or git_tag or git_commit or git_dirty: -- cgit 1.4.1 From 46c12918add132d8d0cbb808b499c815e2745f72 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 30 Oct 2019 11:07:42 +0000 Subject: Fix typo in domain name in account_threepid_delegates config option (#6273) --- changelog.d/6273.doc | 1 + docs/sample_config.yaml | 2 +- synapse/config/registration.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6273.doc (limited to 'synapse') diff --git a/changelog.d/6273.doc b/changelog.d/6273.doc new file mode 100644 index 0000000000..21a41d987d --- /dev/null +++ b/changelog.d/6273.doc @@ -0,0 +1 @@ +Fix a small typo in `account_threepid_delegates` configuration option. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 6c81c0db75..d2f4aff826 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -955,7 +955,7 @@ uploads_path: "DATADIR/uploads" # If a delegate is specified, the config option public_baseurl must also be filled out. # account_threepid_delegates: - #email: https://example.com # Delegate email sending to example.org + #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process # Users who register on this homeserver will automatically be joined diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ab41623b2b..1f6dac69da 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -300,7 +300,7 @@ class RegistrationConfig(Config): # If a delegate is specified, the config option public_baseurl must also be filled out. # account_threepid_delegates: - #email: https://example.com # Delegate email sending to example.org + #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process # Users who register on this homeserver will automatically be joined -- cgit 1.4.1 From 7955abeaac54e97332bd42186299e648bb3ace6c Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 30 Oct 2019 11:16:19 +0000 Subject: Fix small typo in comment (#6269) --- changelog.d/6269.misc | 1 + synapse/federation/federation_server.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6269.misc (limited to 'synapse') diff --git a/changelog.d/6269.misc b/changelog.d/6269.misc new file mode 100644 index 0000000000..9fd333cc89 --- /dev/null +++ b/changelog.d/6269.misc @@ -0,0 +1 @@ +Fix incorrect comment regarding the functionality of an `if` statement. \ No newline at end of file diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 7c331753ad..d5a19764d2 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -145,7 +145,7 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) - # Reject if PDU count > 50 and EDU count > 100 + # Reject if PDU count > 50 or EDU count > 100 if len(transaction.pdus) > 50 or ( hasattr(transaction, "edus") and len(transaction.edus) > 100 ): -- cgit 1.4.1 From a2276d4d3ca72896582ef24d01fdff6c01e38689 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 30 Oct 2019 11:28:48 +0000 Subject: Fix log line that was printing undefined value (#6278) --- changelog.d/6278.bugfix | 1 + synapse/handlers/federation.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6278.bugfix (limited to 'synapse') diff --git a/changelog.d/6278.bugfix b/changelog.d/6278.bugfix new file mode 100644 index 0000000000..c107270461 --- /dev/null +++ b/changelog.d/6278.bugfix @@ -0,0 +1 @@ +Fix exception when remote servers attempt to join a room that they're not allowed to join. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 488058fe68..2da520e6e8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1250,7 +1250,7 @@ class FederationHandler(BaseHandler): builder=builder ) except AuthError as e: - logger.warn("Failed to create join %r because %s", event, e) + logger.warn("Failed to create join to %s because %s", room_id, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( -- cgit 1.4.1 From 326b3dace77aeb36e516ea9b04ba1baa171bcb47 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 11:35:46 +0000 Subject: Make ObservableDeferred.observe() always return deferred. This makes it easier to use in an async/await world. Also fixes a bug where cache descriptors would occaisonally return a raw value rather than a deferred. --- synapse/util/async_helpers.py | 7 ++----- tests/storage/test__base.py | 2 +- tests/util/caches/test_descriptors.py | 4 ++-- 3 files changed, 5 insertions(+), 8 deletions(-) (limited to 'synapse') diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 7659eaeb42..fd75ba27ad 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -86,11 +86,8 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) - def observe(self): + def observe(self) -> defer.Deferred: """Observe the underlying deferred. - - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. """ if not self._result: d = defer.Deferred() @@ -105,7 +102,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return res if success else defer.fail(res) + return defer.succeed(res) if success else defer.fail(res) def observers(self): return self._observers diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index dd49a14524..9b81b536f5 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals(a.func("foo"), d.result) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index f907903511..39e360fe24 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -310,14 +310,14 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.return_value = ["spam", "eggs"] r = obj.fn(1, 2) - self.assertEqual(r, ["spam", "eggs"]) + self.assertEqual(r.result, ["spam", "eggs"]) obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = ["chips"] r = obj.fn(1, 3) - self.assertEqual(r, ["chips"]) + self.assertEqual(r.result, ["chips"]) obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() -- cgit 1.4.1 From 6e677403b7c71511374b1532b56e0c411840f87b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 11:52:04 +0000 Subject: Clarify docstring --- synapse/util/async_helpers.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'synapse') diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index fd75ba27ad..b60a604474 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -88,6 +88,10 @@ class ObservableDeferred(object): def observe(self) -> defer.Deferred: """Observe the underlying deferred. + + This returns a brand new deferred that is resolved when the underlying + deferred is resolved. Interacting with the returned deferred does not + effect the underdlying deferred. """ if not self._result: d = defer.Deferred() -- cgit 1.4.1 From a8d16f6c00d5adb204af5fa73ffaea40eea4b632 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 13:33:38 +0000 Subject: Review comments --- synapse/server.py | 5 ++--- synapse/storage/__init__.py | 4 ++-- synapse/storage/data_stores/main/events.py | 19 ++++++++++++++----- synapse/storage/persist_events.py | 13 ++++++++++--- tests/crypto/test_keyring.py | 4 ++-- tests/test_federation.py | 11 +++++++---- 6 files changed, 37 insertions(+), 19 deletions(-) (limited to 'synapse') diff --git a/synapse/server.py b/synapse/server.py index 54a7f4aa5f..0b81af646c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -226,7 +226,6 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() - self.datastore = None self.datastores = None # Other kwargs are explicit dependencies @@ -236,8 +235,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") with self.get_db_conn() as conn: - self.datastore = self.DATASTORE_CLASS(conn, self) - self.datastores = DataStores(self.datastore, conn, self) + datastore = self.DATASTORE_CLASS(conn, self) + self.datastores = DataStores(datastore, conn, self) conn.commit() logger.info("Finished setting up.") diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a58187a76f..a6429d17ed 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -29,7 +29,7 @@ stored in `synapse.storage.schema`. from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main import DataStore -from synapse.storage.persist_events import EventsPersistenceStore +from synapse.storage.persist_events import EventsPersistenceStorage __all__ = ["DataStores", "DataStore"] @@ -44,7 +44,7 @@ class Storage(object): # interfaces. self.main = stores.main - self.persistence = EventsPersistenceStore(hs, stores) + self.persistence = EventsPersistenceStorage(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 6304531cd5..813f34528c 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -146,7 +146,7 @@ class EventsStore( @_retry_on_integrity_error @defer.inlineCallbacks - def _persist_events( + def _persist_events_and_state_updates( self, events_and_contexts, current_state_for_room, @@ -155,18 +155,27 @@ class EventsStore( backfilled=False, delete_existing=False, ): - """Persist events to db + """Persist a set of events alongside updates to the current state and + forward extremities tables. Args: events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): + current_state_for_room (dict[str, dict]): Map from room_id to the + current state of the room based on forward extremities + state_delta_for_room (dict[str, tuple]): Map from room_id to tuple + of `(to_delete, to_insert)` where to_delete is a list + of type/state keys to remove from current state, and to_insert + is a map (type,key)->event_id giving the state delta in each + room. + new_forward_extremities (dict[str, list[str]]): Map from room_id + to list of event IDs that are the new forward extremities of + the room. + backfilled (bool) delete_existing (bool): Returns: Deferred: resolves when the events have been persisted """ - if not events_and_contexts: - return # We want to calculate the stream orderings as late as possible, as # we only notify after all events with a lesser stream ordering have diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 9a63953d4d..cf66225574 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -171,7 +171,13 @@ class _EventPeristenceQueue(object): pass -class EventsPersistenceStore(object): +class EventsPersistenceStorage(object): + """High level interface for handling persisting newly received events. + + Takes care of batching up events by room, and calculating the necessary + current state and forward extremity changes. + """ + def __init__(self, hs, stores: DataStores): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same @@ -257,7 +263,8 @@ class EventsPersistenceStore(object): def _persist_events( self, events_and_contexts, backfilled=False, delete_existing=False ): - """Persist events to db + """Calculates the change to current state and forward extremities, and + persists the given events and with those updates. Args: events_and_contexts (list[(EventBase, EventContext)]): @@ -399,7 +406,7 @@ class EventsPersistenceStore(object): if current_state is not None: current_state_for_room[room_id] = current_state - yield self.main_store._persist_events( + yield self.main_store._persist_events_and_state_updates( chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index c4f0bbd3dd..8efd39c7f7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], diff --git a/tests/test_federation.py b/tests/test_federation.py index a73f18f88e..d1acb16f30 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -36,7 +36,8 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0] @@ -75,7 +76,8 @@ class MessageAcceptTests(unittest.TestCase): self.assertEqual( self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0], "$join:test.serv", @@ -97,7 +99,8 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0] @@ -137,6 +140,6 @@ class MessageAcceptTests(unittest.TestCase): # Make sure the invalid event isn't there extrem = maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id ) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") -- cgit 1.4.1 From d78b1e339dd813214d8a8316c38a3be31ad8f132 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 30 Oct 2019 10:01:53 -0400 Subject: apply changes as a result of PR review --- synapse/handlers/e2e_keys.py | 22 ++++---- synapse/storage/data_stores/main/devices.py | 79 +++++++++++++---------------- 2 files changed, 46 insertions(+), 55 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 4ab75a351e..0f320b3764 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1072,7 +1072,7 @@ class SignatureListItem: class SigningKeyEduUpdater(object): - "Handles incoming signing key updates from federation and updates the DB" + """Handles incoming signing key updates from federation and updates the DB""" def __init__(self, hs, e2e_keys_handler): self.store = hs.get_datastore() @@ -1111,7 +1111,6 @@ class SigningKeyEduUpdater(object): self_signing_key = edu_content.pop("self_signing_key", None) if get_domain_from_id(user_id) != origin: - # TODO: Raise? logger.warning("Got signing key update edu for %r from %r", user_id, origin) return @@ -1122,7 +1121,7 @@ class SigningKeyEduUpdater(object): return self._pending_updates.setdefault(user_id, []).append( - (master_key, self_signing_key, edu_content) + (master_key, self_signing_key) ) yield self._handle_signing_key_updates(user_id) @@ -1147,22 +1146,21 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) - for master_key, self_signing_key, edu_content in pending_updates: + for master_key, self_signing_key in pending_updates: if master_key: yield self.store.set_e2e_cross_signing_key( user_id, "master", master_key ) - device_id = get_verify_key_from_cross_signing_key(master_key)[ - 1 - ].version - device_ids.append(device_id) + _, verify_key = get_verify_key_from_cross_signing_key(master_key) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + device_ids.append(verify_key.version) if self_signing_key: yield self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) - device_id = get_verify_key_from_cross_signing_key(self_signing_key)[ - 1 - ].version - device_ids.append(device_id) + _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) + device_ids.append(verify_key.version) yield device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 6ac165068e..0b12bc58c4 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -92,8 +92,12 @@ class DeviceWorkerStore(SQLBaseStore): @trace @defer.inlineCallbacks def get_devices_by_remote(self, destination, from_stream_id, limit): - """Get stream of updates to send to remote servers + """Get a stream of device updates to send to the given remote server. + Args: + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + limit (int): Maximum number of device updates to return Returns: Deferred[tuple[int, list[tuple[string,dict]]]]: current stream id (ie, the stream id of the last update included in the @@ -131,7 +135,8 @@ class DeviceWorkerStore(SQLBaseStore): if not updates: return now_stream_id, [] - # get the cross-signing keys of the users the list + # get the cross-signing keys of the users in the list, so that we can + # determine which of the device changes were cross-signing keys users = set(r[0] for r in updates) master_key_by_user = {} self_signing_key_by_user = {} @@ -141,9 +146,12 @@ class DeviceWorkerStore(SQLBaseStore): key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key ) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID master_key_by_user[user] = { "key_info": cross_signing_key, - "pubkey": verify_key.version, + "device_id": verify_key.version, } cross_signing_key = yield self.get_e2e_cross_signing_key( @@ -155,7 +163,7 @@ class DeviceWorkerStore(SQLBaseStore): ) self_signing_key_by_user[user] = { "key_info": cross_signing_key, - "pubkey": verify_key.version, + "device_id": verify_key.version, } # if we have exceeded the limit, we need to exclude any results with the @@ -182,69 +190,54 @@ class DeviceWorkerStore(SQLBaseStore): # context which created the Edu. query_map = {} - for update in updates: - if stream_id_cutoff is not None and update[2] >= stream_id_cutoff: + cross_signing_keys_by_user = {} + for user_id, device_id, update_stream_id, update_context in updates: + if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff: # Stop processing updates break - # skip over cross-signing keys if ( - update[0] in master_key_by_user - and update[1] == master_key_by_user[update[0]]["pubkey"] - ) or ( - update[0] in master_key_by_user - and update[1] == self_signing_key_by_user[update[0]]["pubkey"] + user_id in master_key_by_user + and device_id == master_key_by_user[user_id]["device_id"] ): - continue - - key = (update[0], update[1]) - - update_context = update[3] - update_stream_id = update[2] - - previous_update_stream_id, _ = query_map.get(key, (0, None)) - - if update_stream_id > previous_update_stream_id: - query_map[key] = (update_stream_id, update_context) - - # If we didn't find any updates with a stream_id lower than the cutoff, it - # means that there are more than limit updates all of which have the same - # steam_id. - - # figure out which cross-signing keys were changed by intersecting the - # update list with the master/self-signing key by user maps - cross_signing_keys_by_user = {} - for user_id, device_id, stream, _opentracing_context in updates: - if device_id == master_key_by_user.get(user_id, {}).get("pubkey", None): result = cross_signing_keys_by_user.setdefault(user_id, {}) result["master_key"] = master_key_by_user[user_id]["key_info"] - elif device_id == self_signing_key_by_user.get(user_id, {}).get( - "pubkey", None + elif ( + user_id in master_key_by_user + and device_id == self_signing_key_by_user[user_id]["device_id"] ): result = cross_signing_keys_by_user.setdefault(user_id, {}) result["self_signing_key"] = self_signing_key_by_user[user_id][ "key_info" ] + else: + key = (user_id, device_id) - cross_signing_results = [] + previous_update_stream_id, _ = query_map.get(key, (0, None)) - # add the updated cross-signing keys to the results list - for user_id, result in iteritems(cross_signing_keys_by_user): - result["user_id"] = user_id - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec - cross_signing_results.append(("org.matrix.signing_key_update", result)) + if update_stream_id > previous_update_stream_id: + query_map[key] = (update_stream_id, update_context) + + # If we didn't find any updates with a stream_id lower than the cutoff, it + # means that there are more than limit updates all of which have the same + # steam_id. # That should only happen if a client is spamming the server with new # devices, in which case E2E isn't going to work well anyway. We'll just # skip that stream_id and return an empty list, and continue with the next # stream_id next time. - if not query_map and not cross_signing_results: + if not query_map and not cross_signing_keys_by_user: return stream_id_cutoff, [] results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) - results.extend(cross_signing_results) + + # add the updated cross-signing keys to the results list + for user_id, result in iteritems(cross_signing_keys_by_user): + result["user_id"] = user_id + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("org.matrix.signing_key_update", result)) return now_stream_id, results -- cgit 1.4.1 From bc32f102cd1144923581771b0cc84ead4d99cefb Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 30 Oct 2019 10:07:36 -0400 Subject: black --- synapse/handlers/e2e_keys.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 0f320b3764..1ab471b3be 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1160,7 +1160,9 @@ class SigningKeyEduUpdater(object): yield self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) - _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) + _, verify_key = get_verify_key_from_cross_signing_key( + self_signing_key + ) device_ids.append(verify_key.version) yield device_handler.notify_device_update(user_id, device_ids) -- cgit 1.4.1 From fa0dcbc8fa396fa78aabc524728e08fd84b70a0c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 29 Oct 2019 18:35:49 +0000 Subject: Store labels for new events --- synapse/api/constants.py | 3 +++ synapse/storage/data_stores/main/events.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 312196675e..999ec02fd9 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -138,3 +138,6 @@ class LimitBlockingTypes(object): MONTHLY_ACTIVE_USER = "monthly_active_user" HS_DISABLED = "hs_disabled" + + +LabelsField = "org.matrix.labels" diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 03b5111c5d..f80b5f1a3f 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -29,7 +29,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer import synapse.metrics -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, LabelsField from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 @@ -1490,6 +1490,11 @@ class EventsStore( self._handle_event_relations(txn, event) + # Store the labels for this event. + labels = event.content.get(LabelsField) + if labels: + self.insert_labels_for_event_txn(txn, event.event_id, labels) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -2477,6 +2482,19 @@ class EventsStore( get_all_updated_current_state_deltas_txn, ) + def insert_labels_for_event_txn(self, txn, event_id, labels): + return self._simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + } + for label in labels + ], + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", -- cgit 1.4.1 From 5db03535d587f9a3dc460ad9c015ab6a35ff5730 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 14:07:48 +0000 Subject: Add StateGroupStorage interface --- synapse/storage/__init__.py | 2 + synapse/storage/persist_events.py | 2 +- synapse/storage/state.py | 232 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a6429d17ed..0a1a8cc1e5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -30,6 +30,7 @@ stored in `synapse.storage.schema`. from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main import DataStore from synapse.storage.persist_events import EventsPersistenceStorage +from synapse.storage.state import StateGroupStorage __all__ = ["DataStores", "DataStore"] @@ -45,6 +46,7 @@ class Storage(object): self.main = stores.main self.persistence = EventsPersistenceStorage(hs, stores) + self.state = StateGroupStorage(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index cf66225574..61dfca8fdc 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -550,7 +550,7 @@ class EventsPersistenceStorage(object): if missing_event_ids: # Now pull out the state groups for any missing events from DB - event_to_groups = yield self.state_store._get_state_group_for_events( + event_to_groups = yield self.main_store._get_state_group_for_events( missing_event_ids ) event_id_to_state_group.update(event_to_groups) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a2df8fa827..b382a06dcc 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -19,6 +19,8 @@ from six import iteritems, itervalues import attr +from twisted.internet import defer + from synapse.api.constants import EventTypes logger = logging.getLogger(__name__) @@ -322,3 +324,233 @@ class StateFilter(object): ) return member_filter, non_member_filter + + +class StateGroupStorage(object): + """High level interface to fetching state for event. + """ + + def __init__(self, hs, stores): + self.stores = stores + + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + return self.stores.main.get_state_group_delta(state_group) + + @defer.inlineCallbacks + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + if not event_ids: + return {} + + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self.stores.main._get_state_for_groups(groups) + + return group_to_state + + @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the event IDs of all the state in the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + return group_to_state[state_group] + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. + """ + if not event_ids: + return {} + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.stores.main.get_events( + [ + ev_id + for group_ids in itervalues(group_to_ids) + for ev_id in itervalues(group_ids) + ], + get_prev_content=False, + ) + + return { + group: [ + state_event_map[v] + for v in itervalues(event_id_map) + if v in state_event_map + ] + for group, event_id_map in iteritems(group_to_ids) + } + + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + return self.stores.main._get_state_groups_from_groups(groups, state_filter) + + @defer.inlineCallbacks + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. + Args: + event_ids (list[string]) + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + """ + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self.stores.main._get_state_for_groups( + groups, state_filter + ) + + state_event_map = yield self.stores.main.get_events( + [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in iteritems(group_to_state[group]) + if v in state_event_map + } + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + """ + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) + + Args: + event_ids(list(str)): events whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from event_id -> (type, state_key) -> event_id + """ + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self.stores.main._get_state_for_groups( + groups, state_filter + ) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_for_events([event_id], state_filter) + return state_map[event_id] + + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], state_filter) + return state_map[event_id] + + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + return self.stores.main._get_state_for_groups(groups, state_filter) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + return self.stores.main.store_state_group( + event_id, room_id, prev_group, delta_ids, current_state_ids + ) -- cgit 1.4.1 From 69f0054ce675bd9d35104c39af9fae9a908b7f33 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Oct 2019 17:25:54 +0100 Subject: Port to use state storage --- synapse/handlers/admin.py | 7 +- synapse/handlers/device.py | 3 +- synapse/handlers/events.py | 6 +- synapse/handlers/federation.py | 19 ++--- synapse/handlers/initial_sync.py | 14 ++-- synapse/handlers/message.py | 10 +-- synapse/handlers/pagination.py | 6 +- synapse/handlers/room.py | 6 +- synapse/handlers/search.py | 12 ++-- synapse/handlers/sync.py | 20 +++--- synapse/notifier.py | 6 +- synapse/push/httppusher.py | 3 +- synapse/push/mailer.py | 3 +- synapse/push/push_tools.py | 9 +-- synapse/state/__init__.py | 13 ++-- synapse/visibility.py | 30 ++++---- tests/storage/test_state.py | 150 ++++++++++++++++++++++++++------------- tests/test_state.py | 3 + tests/test_visibility.py | 11 ++- 19 files changed, 216 insertions(+), 115 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1a87b58838..6407d56f8e 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,6 +30,9 @@ class AdminHandler(BaseHandler): def __init__(self, hs): super(AdminHandler, self).__init__(hs) + self.storage = hs.get_storage() + self.state_store = self.storage.state + @defer.inlineCallbacks def get_whois(self, user): connections = [] @@ -205,7 +208,7 @@ class AdminHandler(BaseHandler): from_key = events[-1].internal_metadata.after - events = yield filter_events_for_client(self.store, user_id, events) + events = yield filter_events_for_client(self.storage, user_id, events) writer.write_events(room_id, events) @@ -241,7 +244,7 @@ class AdminHandler(BaseHandler): for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = yield self.store.get_state_for_event(event_id) + state = yield self.state_store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5f23ee4488..b3fd7e6249 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self.state_store = hs.get_storage().state self._auth_handler = hs.get_auth_handler() @trace @@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 5e748687e3..45fe13c62f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): + def __init__(self, hs): + super(EventHandler, self).__init__(hs) + self.storage = hs.get_storage() + @defer.inlineCallbacks def get_event(self, user, room_id, event_id): """Retrieve a single specified event. @@ -172,7 +176,7 @@ class EventHandler(BaseHandler): is_peeking = user.to_string() not in users filtered = yield filter_events_for_client( - self.store, user.to_string(), [event], is_peeking=is_peeking + self.storage, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 08276fdebf..4d9e33346d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -110,6 +110,7 @@ class FederationHandler(BaseHandler): self.store = hs.get_datastore() self.storage = hs.get_storage() + self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -325,7 +326,7 @@ class FederationHandler(BaseHandler): event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.store.get_state_groups_ids(room_id, seen) + ours = yield self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -889,7 +890,7 @@ class FederationHandler(BaseHandler): # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = yield filter_events_for_server( - self.store, + self.storage, self.server_name, list(extremities_events.values()), redact=False, @@ -1550,7 +1551,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() @@ -1579,7 +1580,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1607,7 +1608,7 @@ class FederationHandler(BaseHandler): events = yield self.store.get_backfill_events(room_id, pdu_list, limit) - events = yield filter_events_for_server(self.store, origin, events) + events = yield filter_events_for_server(self.storage, origin, events) return events @@ -1637,7 +1638,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server(self.store, origin, [event]) + events = yield filter_events_for_server(self.storage, origin, [event]) event = events[0] return event else: @@ -1903,7 +1904,7 @@ class FederationHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets = yield self.store.get_state_groups( + state_sets = yield self.state_store.get_state_groups( event.room_id, extrem_ids ) state_sets = list(state_sets.values()) @@ -1994,7 +1995,7 @@ class FederationHandler(BaseHandler): ) missing_events = yield filter_events_for_server( - self.store, origin, missing_events + self.storage, origin, missing_events ) return missing_events @@ -2235,7 +2236,7 @@ class FederationHandler(BaseHandler): # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index f991efeee3..49c9e031f9 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state def snapshot_all_rooms( self, @@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler): elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, [event.event_id] + self.state_store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client(self.store, user_id, messages) + messages = yield filter_events_for_client( + self.storage, user_id, messages + ) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler): def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.store.get_state_for_events([member_event_id]) + room_state = yield self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace("room_key", token) @@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7908a2d52c..6e2a360262 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -59,6 +59,8 @@ class MessageHandler(object): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks @@ -82,7 +84,7 @@ class MessageHandler(object): data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -135,12 +137,12 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.store, user_id, last_events + self.storage, user_id, last_events ) event = last_events[0] if visible_events: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -161,7 +163,7 @@ class MessageHandler(object): ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5744f4579d..b7185fe7a0 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -69,6 +69,8 @@ class PaginationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname @@ -255,7 +257,7 @@ class PaginationHandler(object): events = event_filter.filter(events) events = yield filter_events_for_client( - self.store, user_id, events, is_peeking=(member_event_id is None) + self.storage, user_id, events, is_peeking=(member_event_id is None) ) if not events: @@ -274,7 +276,7 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2816bd8f87..84bad39815 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -822,6 +822,8 @@ class RoomContextHandler(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit, event_filter): @@ -848,7 +850,7 @@ class RoomContextHandler(object): def filter_evts(events): return filter_events_for_client( - self.store, user.to_string(), events, is_peeking=is_peeking + self.storage, user.to_string(), events, is_peeking=is_peeking ) event = yield self.store.get_event( @@ -890,7 +892,7 @@ class RoomContextHandler(object): # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = yield self.store.get_state_for_events( + state = yield self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) results["state"] = list(state[last_event_id].values()) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index cd5e90bacb..f4d8a60774 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -35,6 +35,8 @@ class SearchHandler(BaseHandler): def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -221,7 +223,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -271,7 +273,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) room_events.extend(events) @@ -340,11 +342,11 @@ class SearchHandler(BaseHandler): ) res["events_before"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_before"] + self.storage, user.to_string(), res["events_before"] ) res["events_after"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_after"] + self.storage, user.to_string(), res["events_after"] ) res["start"] = now_token.copy_and_replace( @@ -372,7 +374,7 @@ class SearchHandler(BaseHandler): [(EventTypes.Member, sender) for sender in senders] ) - state = yield self.store.get_state_for_event( + state = yield self.state_store.get_state_for_event( last_event_id, state_filter ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d99160e9d7..43a082dcda 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -230,6 +230,8 @@ class SyncHandler(object): self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() self.auth = hs.get_auth() + self.storage = hs.get_storage() + self.state_store = self.storage.state # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( @@ -417,7 +419,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -470,7 +472,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) loaded_recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -509,7 +511,7 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -580,7 +582,7 @@ class SyncHandler(object): return None last_event = last_events[-1] - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -757,11 +759,11 @@ class SyncHandler(object): if full_state: if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) @@ -781,7 +783,7 @@ class SyncHandler(object): ) elif batch.limited: if batch: - state_at_timeline_start = yield self.store.get_state_ids_for_event( + state_at_timeline_start = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: @@ -810,7 +812,7 @@ class SyncHandler(object): ) if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: @@ -841,7 +843,7 @@ class SyncHandler(object): # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( diff --git a/synapse/notifier.py b/synapse/notifier.py index 4e091314e6..af161a81d7 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -159,6 +159,7 @@ class Notifier(object): self.room_to_user_streams = {} self.hs = hs + self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] @@ -425,7 +426,10 @@ class Notifier(object): if name == "room": new_events = yield filter_events_for_client( - self.store, user.to_string(), new_events, is_peeking=is_peeking + self.storage, + user.to_string(), + new_events, + is_peeking=is_peeking, ) elif name == "presence": now = self.clock.time_msec() diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 6299587808..36e26032a1 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -64,6 +64,7 @@ class HttpPusher(object): def __init__(self, hs, pusherdict): self.hs = hs self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() self.clock = self.hs.get_clock() self.state_handler = self.hs.get_state_handler() self.user_id = pusherdict["user_name"] @@ -329,7 +330,7 @@ class HttpPusher(object): return d ctx = yield push_tools.get_context_for_event( - self.store, self.state_handler, event, self.user_id + self.storage, self.state_handler, event, self.user_id ) d = { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 5b16ab4ae8..1d15a06a58 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -119,6 +119,7 @@ class Mailer(object): self.store = self.hs.get_datastore() self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() + self.storage = hs.get_storage() self.app_name = app_name logger.info("Created Mailer for app_name %s" % app_name) @@ -389,7 +390,7 @@ class Mailer(object): } the_events = yield filter_events_for_client( - self.store, user_id, results["events_before"] + self.storage, user_id, results["events_before"] ) the_events.append(notif_event) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a54051a726..de5c101a58 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.push.presentable_names import calculate_room_name, name_from_member_event +from synapse.storage import Storage @defer.inlineCallbacks @@ -43,22 +44,22 @@ def get_badge_count(store, user_id): @defer.inlineCallbacks -def get_context_for_event(store, state_handler, ev, user_id): +def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield store.get_state_ids_for_event(ev.event_id) + room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room name = yield calculate_room_name( - store, room_state_ids, user_id, fallback_to_single_member=False + storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield store.get_event(sender_state_event_id) + sender_state_event = yield storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dc9f5a9008..4e91eb66fe 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -103,6 +103,7 @@ class StateHandler(object): def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() @@ -271,7 +272,7 @@ class StateHandler(object): else: current_state_ids = prev_state_ids - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=None, @@ -321,7 +322,7 @@ class StateHandler(object): delta_ids = dict(entry.delta_ids) delta_ids[key] = event.event_id - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -334,7 +335,7 @@ class StateHandler(object): delta_ids = entry.delta_ids if entry.state_group is None: - entry.state_group = yield self.store.store_state_group( + entry.state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=entry.prev_group, @@ -376,14 +377,16 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) + state_groups_ids = yield self.state_store.get_state_groups_ids( + room_id, event_ids + ) if len(state_groups_ids) == 0: return _StateCacheEntry(state={}, state_group=None) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.store.get_state_group_delta(name) + prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, diff --git a/synapse/visibility.py b/synapse/visibility.py index bf0f1eebd8..8c843febd8 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage import Storage from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id @@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks def filter_events_for_client( - store, user_id, events, is_peeking=False, always_include_ids=frozenset() + storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset() ): """ Check which events a user is allowed to see Args: - store (synapse.storage.DataStore): our datastore (can also be a worker - store) + storage user_id(str): user id to be checked events(list[synapse.events.EventBase]): sequence of events to be checked is_peeking(bool): should be True if: @@ -68,12 +68,12 @@ def filter_events_for_client( events = list(e for e in events if not e.internal_metadata.is_soft_failed()) types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield store.get_state_for_events( + event_id_to_state = yield storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield store.get_global_account_data_by_type_for_user( + ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -84,7 +84,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) def allowed(event): """ @@ -213,13 +213,17 @@ def filter_events_for_client( @defer.inlineCallbacks def filter_events_for_server( - store, server_name, events, redact=True, check_history_visibility_only=False + storage: Storage, + server_name, + events, + redact=True, + check_history_visibility_only=False, ): """Filter a list of events based on whether given server is allowed to see them. Args: - store (DataStore) + storage server_name (str) events (iterable[FrozenEvent]) redact (bool): Whether to return a redacted version of the event, or @@ -274,7 +278,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If thats the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -292,14 +296,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield store.get_events(visibility_ids) + event_map = yield storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in itervalues(event_map) ) if not check_history_visibility_only: - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -328,7 +332,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -358,7 +362,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield store.get_events( + event_map = yield storage.main.get_events( [ e_id for e_id, key in iteritems(event_id_to_state_key) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d573a3e07b..43200654f1 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -35,6 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() self.storage = hs.get_storage() + self.state_datastore = self.store self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -83,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids( + state_group_map = yield self.storage.state.get_state_groups_ids( self.room, [e2.event_id] ) self.assertEqual(len(state_group_map), 1) @@ -102,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) + state_group_map = yield self.storage.state.get_state_groups( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -142,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event(e5.event_id) + state = yield self.storage.state.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -158,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -182,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, @@ -200,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -216,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) + group_ids = yield self.storage.state.get_state_groups_ids( + room_id, [e5.event_id] + ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -238,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -251,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -268,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -288,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -305,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -318,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -332,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### # deliberately remove e2 (room name) from the _state_group_cache - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, True) self.assertEqual(known_absent, set()) @@ -347,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase): ) state_dict_ids.pop((e2.type, e2.state_key)) - self.store._state_group_cache.invalidate(group) - self.store._state_group_cache.update( - sequence=self.store._state_group_cache.sequence, + self.state_datastore._state_group_cache.invalidate(group) + self.state_datastore._state_group_cache.update( + sequence=self.state_datastore._state_group_cache.sequence, key=group, value=state_dict_ids, # list fetched keys so it knows it's partial fetched_keys=((e1.type, e1.state_key),), ) - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) @@ -370,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -382,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -395,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -406,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -425,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -436,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -449,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -460,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False diff --git a/tests/test_state.py b/tests/test_state.py index 610ec9fb46..38246555bd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -158,10 +158,12 @@ class Graph(object): class StateTestCase(unittest.TestCase): def setUp(self): self.store = StateGroupStore() + storage = Mock(main=self.store, state=self.store) hs = Mock( spec_set=[ "config", "get_datastore", + "get_storage", "get_auth", "get_state_handler", "get_clock", @@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) + hs.get_storage.return_value = storage self.state = StateHandler(hs) self.event_id = 0 diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 6ae1ea9b04..f7381b2885 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -14,6 +14,8 @@ # limitations under the License. import logging +from mock import Mock + from twisted.internet import defer from twisted.internet.defer import succeed @@ -63,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) # the result should be 5 redacted events, and 5 unredacted events. @@ -101,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # ... and the filtering happens. filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) for i in range(0, len(events_to_filter)): @@ -258,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): logger.info("Starting filtering") start = time.time() + + storage = Mock() + storage.main = test_store + storage.state = test_store + filtered = yield filter_events_for_server( test_store, "test_server", events_to_filter ) -- cgit 1.4.1 From acd16ad86a8f61ef261fa82960ee3634864db9ed Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 15:56:33 +0000 Subject: Implement filtering --- synapse/api/filtering.py | 13 +++++++++++-- synapse/storage/data_stores/main/stream.py | 9 +++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 9f06556bd2..a27029c678 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -20,6 +20,7 @@ from jsonschema import FormatChecker from twisted.internet import defer +from synapse.api.constants import LabelsField from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState from synapse.types import RoomID, UserID @@ -66,6 +67,8 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, + "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, }, } @@ -259,6 +262,9 @@ class Filter(object): self.contains_url = self.filter_json.get("contains_url", None) + self.labels = self.filter_json.get("org.matrix.labels", None) + self.not_labels = self.filter_json.get("org.matrix.not_labels", []) + def filters_all_types(self): return "*" in self.not_types @@ -282,6 +288,7 @@ class Filter(object): room_id = None ev_type = "m.presence" contains_url = False + labels = [] else: sender = event.get("sender", None) if not sender: @@ -300,10 +307,11 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) + labels = content.get(LabelsField) - return self.check_fields(room_id, sender, ev_type, contains_url) + return self.check_fields(room_id, sender, ev_type, labels, contains_url) - def check_fields(self, room_id, sender, event_type, contains_url): + def check_fields(self, room_id, sender, event_type, labels, contains_url): """Checks whether the filter matches the given event fields. Returns: @@ -313,6 +321,7 @@ class Filter(object): "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(event_type, v), + "labels": lambda v: v in labels, } for name, match_func in literal_keys.items(): diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 263999dfca..907d7f20ba 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -229,6 +229,14 @@ def filter_to_clause(event_filter): clauses.append("contains_url = ?") args.append(event_filter.contains_url) + # We're only applying the "labels" filter on the database query, because applying the + # "not_labels" filter via a SQL query is non-trivial. Instead, we let + # event_filter.check_fields apply it, which is not as efficient but makes the + # implementation simpler. + if event_filter.labels: + clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels)) + args.extend(event_filter.labels) + return " AND ".join(clauses), args @@ -866,6 +874,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): sql = ( "SELECT event_id, topological_ordering, stream_ordering" " FROM events" + " LEFT JOIN event_labels USING (event_id)" " WHERE outlier = ? AND room_id = ? AND %(bounds)s" " ORDER BY topological_ordering %(order)s," " stream_ordering %(order)s LIMIT ?" -- cgit 1.4.1 From 233b14ebe1d96bd7cc3fcca09663794490186ad2 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 15:58:05 +0000 Subject: Add index on label --- synapse/storage/data_stores/main/schema/delta/56/event_labels.sql | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql index 323d797419..9550b0adaa 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql @@ -17,4 +17,6 @@ CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, label TEXT, PRIMARY KEY(event_id, label) -); \ No newline at end of file +); + +CREATE INDEX event_labels_label_idx ON event_labels(label); -- cgit 1.4.1 From e7943f660add8b602ea5225060bd0d74e6440017 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 16:15:04 +0000 Subject: Add unit tests --- synapse/api/filtering.py | 2 +- tests/api/test_filtering.py | 51 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index a27029c678..bd91b9f018 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -307,7 +307,7 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) - labels = content.get(LabelsField) + labels = content.get(LabelsField, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 6ba623de13..66b3c828db 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -19,6 +19,7 @@ import jsonschema from twisted.internet import defer +from synapse.api.constants import LabelsField from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import FrozenEvent @@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase): "types": ["m.room.message"], "not_rooms": ["!726s6s6q:example.com"], "not_senders": ["@spam:example.com"], + "org.matrix.labels": ["#fun"], + "org.matrix.not_labels": ["#work"], }, "ephemeral": { "types": ["m.receipt", "m.typing"], @@ -320,6 +323,54 @@ class FilteringTestCase(unittest.TestCase): ) self.assertFalse(Filter(definition).check(event)) + def test_filter_labels(self): + definition = {"org.matrix.labels": ["#fun"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={ + LabelsField: ["#fun"] + }, + ) + + self.assertTrue(Filter(definition).check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={ + LabelsField: ["#notfun"] + }, + ) + + self.assertFalse(Filter(definition).check(event)) + + def test_filter_not_labels(self): + definition = {"org.matrix.not_labels": ["#fun"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={ + LabelsField: ["#fun"] + }, + ) + + self.assertFalse(Filter(definition).check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={ + LabelsField: ["#notfun"] + }, + ) + + self.assertTrue(Filter(definition).check(event)) + @defer.inlineCallbacks def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} -- cgit 1.4.1 From fe51d6cacf6e1a2da5fc3589d0bc4118342b33dd Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 17:28:41 +0000 Subject: Add more integration testing --- synapse/storage/data_stores/main/stream.py | 2 +- tests/rest/client/v2_alpha/test_sync.py | 45 ++++++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 7 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 907d7f20ba..cfa34ba1e7 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -872,7 +872,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args.append(int(limit)) sql = ( - "SELECT event_id, topological_ordering, stream_ordering" + "SELECT DISTINCT event_id, topological_ordering, stream_ordering" " FROM events" " LEFT JOIN event_labels USING (event_id)" " WHERE outlier = ? AND room_id = ? AND %(bounds)s" diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 0263be010f..a1aa7d87bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -85,6 +85,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): ] def test_sync_filter_labels(self): + """Test that we can filter by a label.""" sync_filter = json.dumps( { "room": { @@ -98,11 +99,12 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): events = self._test_sync_filter_labels(sync_filter) - self.assertEqual(len(events), 2, events) - self.assertEqual(events[0]["content"]["body"], "with label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with label", events[1]) + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) def test_sync_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" sync_filter = json.dumps( { "room": { @@ -116,9 +118,29 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): events = self._test_sync_filter_labels(sync_filter) - self.assertEqual(len(events), 2, events) + self.assertEqual(len(events), 3, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "without label", events[0]) self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2]) + + def test_sync_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) def _test_sync_filter_labels(self, sync_filter): user_id = self.register_user("kermit", "test") @@ -131,7 +153,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): type=EventTypes.Message, content={ "msgtype": "m.text", - "body": "with label", + "body": "with right label", LabelsField: ["#fun"], }, tok=tok, @@ -163,7 +185,18 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): type=EventTypes.Message, content={ "msgtype": "m.text", - "body": "with label", + "body": "with two wrong labels", + LabelsField: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", LabelsField: ["#fun"], }, tok=tok, -- cgit 1.4.1 From dcc069a2e2540862c233a20037e3e59591a42431 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 30 Oct 2019 18:01:56 +0000 Subject: Lint --- synapse/storage/data_stores/main/events.py | 8 +---- tests/api/test_filtering.py | 16 +++------- tests/rest/client/v1/test_rooms.py | 49 +++++++++++++++--------------- tests/rest/client/v2_alpha/test_sync.py | 12 ++++---- 4 files changed, 35 insertions(+), 50 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index f80b5f1a3f..2b900f1ce1 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -2486,13 +2486,7 @@ class EventsStore( return self._simple_insert_many_txn( txn=txn, table="event_labels", - values=[ - { - "event_id": event_id, - "label": label, - } - for label in labels - ], + values=[{"event_id": event_id, "label": label} for label in labels], ) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 66b3c828db..e004ab1ee5 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -329,9 +329,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={ - LabelsField: ["#fun"] - }, + content={LabelsField: ["#fun"]}, ) self.assertTrue(Filter(definition).check(event)) @@ -340,9 +338,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={ - LabelsField: ["#notfun"] - }, + content={LabelsField: ["#notfun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -353,9 +349,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={ - LabelsField: ["#fun"] - }, + content={LabelsField: ["#fun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -364,9 +358,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={ - LabelsField: ["#notfun"] - }, + content={LabelsField: ["#notfun"]}, ) self.assertTrue(Filter(definition).check(event)) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ba2008497e..188f47bd7d 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -813,10 +813,9 @@ class RoomMessageListTestCase(RoomBase): def test_filter_labels(self): """Test that we can filter by a label.""" - message_filter = json.dumps({ - "types": [EventTypes.Message], - "org.matrix.labels": ["#fun"], - }) + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} + ) events = self._test_filter_labels(message_filter) @@ -826,25 +825,28 @@ class RoomMessageListTestCase(RoomBase): def test_filter_not_labels(self): """Test that we can filter by the absence of a label.""" - message_filter = json.dumps({ - "types": [EventTypes.Message], - "org.matrix.not_labels": ["#fun"], - }) + message_filter = json.dumps( + {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + ) events = self._test_filter_labels(message_filter) self.assertEqual(len(events), 3, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "without label", events[0]) self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) def test_filter_labels_not_labels(self): """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps({ - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - }) + sync_filter = json.dumps( + { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + ) events = self._test_filter_labels(sync_filter) @@ -859,16 +861,13 @@ class RoomMessageListTestCase(RoomBase): "msgtype": "m.text", "body": "with right label", LabelsField: ["#fun"], - } + }, ) self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "without label", - } + content={"msgtype": "m.text", "body": "without label"}, ) self.helper.send_event( @@ -878,7 +877,7 @@ class RoomMessageListTestCase(RoomBase): "msgtype": "m.text", "body": "with wrong label", LabelsField: ["#work"], - } + }, ) self.helper.send_event( @@ -888,7 +887,7 @@ class RoomMessageListTestCase(RoomBase): "msgtype": "m.text", "body": "with two wrong labels", LabelsField: ["#work", "#notfun"], - } + }, ) self.helper.send_event( @@ -898,14 +897,14 @@ class RoomMessageListTestCase(RoomBase): "msgtype": "m.text", "body": "with right label", LabelsField: ["#fun"], - } + }, ) token = "s0_0_0_0_0_0_0_0_0" request, channel = self.make_request( - "GET", "/rooms/%s/messages?access_token=x&from=%s&filter=%s" % ( - self.room_id, token, message_filter - ) + "GET", + "/rooms/%s/messages?access_token=x&from=%s&filter=%s" + % (self.room_id, token, message_filter), ) self.render(request) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index a1aa7d87bd..c5c199d412 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import json + from mock import Mock -from synapse.api.constants import EventTypes, LabelsField import synapse.rest.admin +from synapse.api.constants import EventTypes, LabelsField from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync @@ -121,7 +122,9 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.assertEqual(len(events), 3, [event["content"] for event in events]) self.assertEqual(events[0]["content"]["body"], "without label", events[0]) self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) def test_sync_filter_labels_not_labels(self): """Test that we can filter by both a label and the absence of another label.""" @@ -162,10 +165,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): self.helper.send_event( room_id=room_id, type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "without label", - }, + content={"msgtype": "m.text", "body": "without label"}, tok=tok, ) -- cgit 1.4.1 From 0467f335847dd096913dcf404ca839f61c38758f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 30 Oct 2019 18:05:00 +0000 Subject: fix delete_existing for _persist_events (#6300) this is part of _retry_on_integrity_error, so should only be on _persist_events_and_state_updates --- changelog.d/6300.misc | 1 + synapse/storage/data_stores/main/events.py | 2 +- synapse/storage/persist_events.py | 5 +---- 3 files changed, 3 insertions(+), 5 deletions(-) create mode 100644 changelog.d/6300.misc (limited to 'synapse') diff --git a/changelog.d/6300.misc b/changelog.d/6300.misc new file mode 100644 index 0000000000..0b3d7a14a1 --- /dev/null +++ b/changelog.d/6300.misc @@ -0,0 +1 @@ +Move `persist_events` out from main data store. diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 7c3607f308..a4dab86a13 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -82,7 +82,7 @@ def _retry_on_integrity_error(func): @defer.inlineCallbacks def f(self, *args, **kwargs): try: - res = yield func(self, *args, **kwargs) + res = yield func(self, *args, delete_existing=False, **kwargs) except self.database_engine.module.IntegrityError: logger.exception("IntegrityError, retrying.") res = yield func(self, *args, delete_existing=True, **kwargs) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index cf66225574..931dcb6558 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -260,9 +260,7 @@ class EventsPersistenceStorage(object): self._event_persist_queue.handle_queue(room_id, persisting_queue) @defer.inlineCallbacks - def _persist_events( - self, events_and_contexts, backfilled=False, delete_existing=False - ): + def _persist_events(self, events_and_contexts, backfilled=False): """Calculates the change to current state and forward extremities, and persists the given events and with those updates. @@ -412,7 +410,6 @@ class EventsPersistenceStorage(object): state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, backfilled=backfilled, - delete_existing=delete_existing, ) @defer.inlineCallbacks -- cgit 1.4.1 From bb6cec27a5ac6d5d6d5f67df21610a63745ac0a9 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 30 Oct 2019 14:57:34 -0400 Subject: rename get_devices_by_remote to get_device_updates_by_remote --- synapse/federation/sender/per_destination_queue.py | 4 ++-- synapse/storage/data_stores/main/devices.py | 8 ++++---- tests/handlers/test_typing.py | 4 ++-- tests/storage/test_devices.py | 12 ++++++------ 4 files changed, 14 insertions(+), 14 deletions(-) (limited to 'synapse') diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index d5d4a60c88..6e3012cd41 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -359,7 +359,7 @@ class PerDestinationQueue(object): last_device_list = self._last_device_list_stream_id # Retrieve list of new device updates to send to the destination - now_stream_id, results = yield self._store.get_devices_by_remote( + now_stream_id, results = yield self._store.get_device_updates_by_remote( self._destination, last_device_list, limit=limit ) edus = [ @@ -372,7 +372,7 @@ class PerDestinationQueue(object): for (edu_type, content) in results ] - assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" + assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs" return (edus, now_stream_id) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 0b12bc58c4..717eab4159 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -91,7 +91,7 @@ class DeviceWorkerStore(SQLBaseStore): @trace @defer.inlineCallbacks - def get_devices_by_remote(self, destination, from_stream_id, limit): + def get_device_updates_by_remote(self, destination, from_stream_id, limit): """Get a stream of device updates to send to the given remote server. Args: @@ -123,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore): # stream_id; the rationale being that such a large device list update # is likely an error. updates = yield self.runInteraction( - "get_devices_by_remote", - self._get_devices_by_remote_txn, + "get_device_updates_by_remote", + self._get_device_updates_by_remote_txn, destination, from_stream_id, now_stream_id, @@ -241,7 +241,7 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, results - def _get_devices_by_remote_txn( + def _get_device_updates_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit ): """Return device update information for a given remote destination diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f360c8e965..5ec568f4e6 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "get_received_txn_response", "set_received_txn_response", "get_destination_retry_timings", - "get_devices_by_remote", + "get_device_updates_by_remote", # Bits that user_directory needs "get_user_directory_stream_pos", "get_current_state_deltas", @@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_devices_by_remote.return_value = (0, []) + self.datastore.get_device_updates_by_remote.return_value = (0, []) def get_received_txn_response(*args): return defer.succeed(None) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 039cc79357..6f8d990959 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) @defer.inlineCallbacks - def test_get_devices_by_remote(self): + def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id @@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "somehost", -1, limit=100 ) @@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self._check_devices_in_updates(device_ids, device_updates) @defer.inlineCallbacks - def test_get_devices_by_remote_limited(self): + def test_get_device_updates_by_remote_limited(self): # Test breaking the update limit in 1, 101, and 1 device_id segments # first add one device @@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # # first we should get a single update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "someotherhost", -1, limit=100 ) self._check_devices_in_updates(device_ids1, device_updates) # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "someotherhost", now_stream_id, limit=100 ) self.assertEqual(len(device_updates), 0) # The 101 devices should've been cleared, so we should now just get one device # update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "someotherhost", now_stream_id, limit=100 ) self._check_devices_in_updates(device_ids3, device_updates) -- cgit 1.4.1 From 998f7fe7d4ddb2dddf4d46a8a420a6fd7e37577c Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 30 Oct 2019 17:22:52 -0400 Subject: make user signatures a separate stream --- synapse/replication/slave/storage/devices.py | 8 ++++++-- synapse/replication/tcp/streams/__init__.py | 1 + synapse/replication/tcp/streams/_base.py | 18 ++++++++++++++++ synapse/storage/data_stores/main/devices.py | 13 +----------- .../storage/data_stores/main/end_to_end_keys.py | 24 ++++++++++++++++++++++ 5 files changed, 50 insertions(+), 14 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index c9f99a8405..f416d73b2e 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -42,7 +42,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def stream_positions(self): result = super(SlavedDeviceStore, self).stream_positions() - result["device_lists"] = self._device_list_id_gen.get_current_token() + result["user_signature"] = result[ + "device_lists" + ] = self._device_list_id_gen.get_current_token() return result def process_replication_rows(self, stream_name, token, rows): @@ -50,13 +52,15 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self._device_list_id_gen.advance(token) for row in rows: self._invalidate_caches_for_devices(token, row.user_id, row.destination) + elif stream_name == "user_signature": + for row in rows: + self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) def _invalidate_caches_for_devices(self, token, user_id, destination): self._device_list_stream_cache.entity_has_changed(user_id, token) - self._user_signature_stream_cache.entity_has_changed(user_id, token) if destination: self._device_list_federation_stream_cache.entity_has_changed( diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 634f636dc9..5f52264e84 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -45,5 +45,6 @@ STREAMS_MAP = { _base.TagAccountDataStream, _base.AccountDataStream, _base.GroupServerStream, + _base.UserSignatureStream, ) } diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f03111c259..9e45429d49 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple( "GroupsStreamRow", ("group_id", "user_id", "type", "content"), # str # str # str # dict ) +UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str class Stream(object): @@ -438,3 +439,20 @@ class GroupServerStream(Stream): self.update_function = store.get_all_groups_changes super(GroupServerStream, self).__init__(hs) + + +class UserSignatureStream(Stream): + """A user has signed their own device with their user-signing key + """ + + NAME = "user_signature" + _LIMITED = False + ROW_TYPE = UserSignatureStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_device_stream_token + self.update_function = store.get_all_user_signature_changes_for_remotes + + super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index a96f09ea7b..f7a3542348 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -543,20 +543,9 @@ class DeviceWorkerStore(SQLBaseStore): LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id, destination - UNION - SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id, NULL AS destination - FROM user_signature_stream - WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id """ return self._execute( - "get_all_device_list_changes_for_remotes", - None, - sql, - from_key, - to_key, - from_key, - to_key, + "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @cached(max_entries=10000) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index a0bc6f2d18..073412a78d 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore): from_user_id, ) + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): + """Return a list of changes from the user signature stream to notify remotes. + Note that the user signature stream represents when a user signs their + device with their user-signing key, which is not published to other + users or servers, so no `destination` is needed in the returned + list. However, this is needed to poke workers. + + Args: + from_key (int): the stream ID to start at (exclusive) + to_key (int): the stream ID to end at (inclusive) + + Returns: + Deferred[list[(int,str)]] a list of `(stream_id, user_id)` + """ + sql = """ + SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id + FROM user_signature_stream + WHERE ? < stream_id AND stream_id <= ? + GROUP BY user_id + """ + return self._execute( + "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key + ) + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): -- cgit 1.4.1 From 54fef094b31e0401d6d35efdf7d5d6b0b9e5d51f Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 31 Oct 2019 10:23:24 +0000 Subject: Remove usage of deprecated logger.warn method from codebase (#6271) Replace every instance of `logger.warn` with `logger.warning` as the former is deprecated. --- changelog.d/6271.misc | 1 + scripts/move_remote_media_to_new_store.py | 2 +- scripts/synapse_port_db | 6 ++-- synapse/api/auth.py | 2 +- synapse/app/__init__.py | 4 ++- synapse/app/appservice.py | 4 +-- synapse/app/client_reader.py | 4 +-- synapse/app/event_creator.py | 4 +-- synapse/app/federation_reader.py | 4 +-- synapse/app/federation_sender.py | 4 +-- synapse/app/frontend_proxy.py | 4 +-- synapse/app/homeserver.py | 6 ++-- synapse/app/media_repository.py | 4 +-- synapse/app/pusher.py | 4 +-- synapse/app/synchrotron.py | 4 +-- synapse/app/user_dir.py | 4 +-- synapse/config/key.py | 4 +-- synapse/config/logger.py | 2 +- synapse/event_auth.py | 2 +- synapse/federation/federation_base.py | 6 ++-- synapse/federation/federation_client.py | 8 +++-- synapse/federation/federation_server.py | 20 ++++++------ synapse/federation/sender/transaction_manager.py | 4 +-- synapse/federation/transport/server.py | 8 +++-- synapse/groups/attestations.py | 2 +- synapse/groups/groups_server.py | 2 +- synapse/handlers/auth.py | 6 ++-- synapse/handlers/device.py | 4 +-- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/federation.py | 36 ++++++++++++---------- synapse/handlers/groups_local.py | 2 +- synapse/handlers/identity.py | 6 ++-- synapse/handlers/message.py | 2 +- synapse/handlers/profile.py | 2 +- synapse/handlers/room.py | 2 +- synapse/http/client.py | 4 +-- synapse/http/federation/srv_resolver.py | 2 +- synapse/http/matrixfederationclient.py | 10 +++--- synapse/http/request_metrics.py | 2 +- synapse/http/server.py | 2 +- synapse/http/servlet.py | 4 +-- synapse/http/site.py | 4 +-- synapse/logging/context.py | 2 +- synapse/push/httppusher.py | 4 +-- synapse/push/push_rule_evaluator.py | 4 +-- synapse/replication/http/_base.py | 2 +- synapse/replication/http/membership.py | 2 +- synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/protocol.py | 2 +- synapse/rest/admin/__init__.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/rest/client/v2_alpha/account.py | 14 ++++----- synapse/rest/client/v2_alpha/register.py | 10 +++--- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/media/v1/media_repository.py | 12 +++++--- synapse/rest/media/v1/preview_url_resource.py | 16 +++++----- synapse/rest/media/v1/thumbnail_resource.py | 4 +-- .../resource_limits_server_notices.py | 2 +- synapse/storage/_base.py | 6 ++-- synapse/storage/data_stores/main/pusher.py | 2 +- synapse/storage/data_stores/main/search.py | 2 +- synapse/util/async_helpers.py | 2 +- synapse/util/caches/__init__.py | 2 +- synapse/util/metrics.py | 6 ++-- synapse/util/rlimit.py | 2 +- 65 files changed, 164 insertions(+), 149 deletions(-) create mode 100644 changelog.d/6271.misc (limited to 'synapse') diff --git a/changelog.d/6271.misc b/changelog.d/6271.misc new file mode 100644 index 0000000000..2369760272 --- /dev/null +++ b/changelog.d/6271.misc @@ -0,0 +1 @@ +Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated. \ No newline at end of file diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index 12747c6024..b5b63933ab 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -72,7 +72,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths): # check that the original exists original_file = src_paths.remote_media_filepath(origin_server, file_id) if not os.path.exists(original_file): - logger.warn( + logger.warning( "Original for %s/%s (%s) does not exist", origin_server, file_id, diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 54faed1e83..0d3321682c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -157,7 +157,7 @@ class Store( ) except self.database_engine.module.DatabaseError as e: if self.database_engine.is_deadlock(e): - logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) + logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) if i < N: i += 1 conn.rollback() @@ -432,7 +432,7 @@ class Porter(object): for row in rows: d = dict(zip(headers, row)) if "\0" in d['value']: - logger.warn('dropping search row %s', d) + logger.warning('dropping search row %s', d) else: rows_dict.append(d) @@ -647,7 +647,7 @@ class Porter(object): if isinstance(col, bytes): return bytearray(col) elif isinstance(col, string_types) and "\0" in col: - logger.warn( + logger.warning( "DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 53f3bb0fa8..5d0b7d2801 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -497,7 +497,7 @@ class Auth(object): token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: - logger.warn("Unrecognised appservice access token.") + logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender return defer.succeed(service) diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index d877c77834..a01bac2997 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -44,6 +44,8 @@ def check_bind_error(e, address, bind_addresses): bind_addresses (list): Addresses on which the service listens. """ if address == "0.0.0.0" and "::" in bind_addresses: - logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]") + logger.warning( + "Failed to listen on 0.0.0.0, continuing because listening on [::]" + ) else: raise e diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 767b87d2db..02b900f382 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -94,7 +94,7 @@ class AppserviceServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -103,7 +103,7 @@ class AppserviceServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dbcc414c42..dadb487d5f 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -153,7 +153,7 @@ class ClientReaderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -162,7 +162,7 @@ class ClientReaderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index f20d810ece..d110599a35 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -147,7 +147,7 @@ class EventCreatorServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -156,7 +156,7 @@ class EventCreatorServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 1ef027a88c..418c086254 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -132,7 +132,7 @@ class FederationReaderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -141,7 +141,7 @@ class FederationReaderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 04fbb407af..139221ad34 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -123,7 +123,7 @@ class FederationSenderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -132,7 +132,7 @@ class FederationSenderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 9504bfbc70..e647459d0e 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -204,7 +204,7 @@ class FrontendProxyServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -213,7 +213,7 @@ class FrontendProxyServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index eb54f56853..8997c1f9e7 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -282,7 +282,7 @@ class SynapseHomeServer(HomeServer): reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -291,7 +291,7 @@ class SynapseHomeServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) def run_startup_checks(self, db_conn, database_engine): all_users_native = are_all_users_on_domain( @@ -569,7 +569,7 @@ def run(hs): hs.config.report_stats_endpoint, stats ) except Exception as e: - logger.warn("Error reporting stats: %s", e) + logger.warning("Error reporting stats: %s", e) def performance_stats_init(): try: diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 6bc7202f33..2c6dd3ef02 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -120,7 +120,7 @@ class MediaRepositoryServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -129,7 +129,7 @@ class MediaRepositoryServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index d84732ee3c..01a5ffc363 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -114,7 +114,7 @@ class PusherServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -123,7 +123,7 @@ class PusherServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 6a7e2fa707..b14da09f47 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -326,7 +326,7 @@ class SynchrotronServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -335,7 +335,7 @@ class SynchrotronServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index a5d6dc7915..6cb100319f 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -150,7 +150,7 @@ class UserDirectoryServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -159,7 +159,7 @@ class UserDirectoryServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/config/key.py b/synapse/config/key.py index ec5d430afb..52ff1b2621 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -125,7 +125,7 @@ class KeyConfig(Config): # if neither trusted_key_servers nor perspectives are given, use the default. if "perspectives" not in config and "trusted_key_servers" not in config: - logger.warn(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) + logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) key_servers = [{"server_name": "matrix.org"}] else: key_servers = config.get("trusted_key_servers", []) @@ -156,7 +156,7 @@ class KeyConfig(Config): if not self.macaroon_secret_key: # Unfortunately, there are people out there that don't have this # set. Lets just be "nice" and derive one from their secret key. - logger.warn("Config is missing macaroon_secret_key") + logger.warning("Config is missing macaroon_secret_key") seed = bytes(self.signing_key[0]) self.macaroon_secret_key = hashlib.sha256(seed).digest() diff --git a/synapse/config/logger.py b/synapse/config/logger.py index be92e33f93..2d2c1e54df 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -182,7 +182,7 @@ def _reload_stdlib_logging(*args, log_config=None): logger = logging.getLogger("") if not log_config: - logger.warn("Reloaded a blank config?") + logger.warning("Reloaded a blank config?") logging.config.dictConfig(log_config) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index e7b722547b..ec3243b27b 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -77,7 +77,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if auth_events is None: # Oh, we don't know what the state of the room was, so we # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) + logger.warning("Trusting event: %s", event.event_id) return if event.type == EventTypes.Create: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 223aace0d9..0e22183280 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -102,7 +102,7 @@ class FederationBase(object): pass if not res: - logger.warn( + logger.warning( "Failed to find copy of %s with valid signature", pdu.event_id ) @@ -173,7 +173,7 @@ class FederationBase(object): return redacted_event if self.spam_checker.check_event_for_spam(pdu): - logger.warn( + logger.warning( "Event contains spam, redacting %s: %s", pdu.event_id, pdu.get_pdu_json(), @@ -185,7 +185,7 @@ class FederationBase(object): def errback(failure, pdu): failure.trap(SynapseError) with PreserveLoggingContext(ctx): - logger.warn( + logger.warning( "Signature check failed for %s: %s", pdu.event_id, failure.getErrorMessage(), diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f5c1632916..595706d01a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -522,12 +522,12 @@ class FederationClient(FederationBase): res = yield callback(destination) return res except InvalidResponseError as e: - logger.warn("Failed to %s via %s: %s", description, destination, e) + logger.warning("Failed to %s via %s: %s", description, destination, e) except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() else: - logger.warn( + logger.warning( "Failed to %s via %s: %i %s", description, destination, @@ -535,7 +535,9 @@ class FederationClient(FederationBase): e.args[0], ) except Exception: - logger.warn("Failed to %s via %s", description, destination, exc_info=1) + logger.warning( + "Failed to %s via %s", description, destination, exc_info=1 + ) raise SynapseError(502, "Failed to %s via any server" % (description,)) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d5a19764d2..d942d77a72 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -220,7 +220,7 @@ class FederationServer(FederationBase): try: await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: - logger.warn("Ignoring PDUs for room %s from banned server", room_id) + logger.warning("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: event_id = pdu.event_id pdu_results[event_id] = e.error_dict() @@ -233,7 +233,7 @@ class FederationServer(FederationBase): await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: - logger.warn("Error handling PDU %s: %s", event_id, e) + logger.warning("Error handling PDU %s: %s", event_id, e) pdu_results[event_id] = {"error": str(e)} except Exception as e: f = failure.Failure() @@ -333,7 +333,9 @@ class FederationServer(FederationBase): room_version = await self.store.get_room_version(room_id) if room_version not in supported_versions: - logger.warn("Room version %s not in %s", room_version, supported_versions) + logger.warning( + "Room version %s not in %s", room_version, supported_versions + ) raise IncompatibleRoomVersionError(room_version=room_version) pdu = await self.handler.on_make_join_request(origin, room_id, user_id) @@ -679,7 +681,7 @@ def server_matches_acl_event(server_name, acl_event): # server name is a literal IP allow_ip_literals = acl_event.content.get("allow_ip_literals", True) if not isinstance(allow_ip_literals, bool): - logger.warn("Ignorning non-bool allow_ip_literals flag") + logger.warning("Ignorning non-bool allow_ip_literals flag") allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. @@ -693,7 +695,7 @@ def server_matches_acl_event(server_name, acl_event): # next, check the deny list deny = acl_event.content.get("deny", []) if not isinstance(deny, (list, tuple)): - logger.warn("Ignorning non-list deny ACL %s", deny) + logger.warning("Ignorning non-list deny ACL %s", deny) deny = [] for e in deny: if _acl_entry_matches(server_name, e): @@ -703,7 +705,7 @@ def server_matches_acl_event(server_name, acl_event): # then the allow list. allow = acl_event.content.get("allow", []) if not isinstance(allow, (list, tuple)): - logger.warn("Ignorning non-list allow ACL %s", allow) + logger.warning("Ignorning non-list allow ACL %s", allow) allow = [] for e in allow: if _acl_entry_matches(server_name, e): @@ -717,7 +719,7 @@ def server_matches_acl_event(server_name, acl_event): def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): - logger.warn( + logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) return False @@ -772,7 +774,7 @@ class FederationHandlerRegistry(object): async def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if not handler: - logger.warn("No handler registered for EDU type %s", edu_type) + logger.warning("No handler registered for EDU type %s", edu_type) with start_active_span_from_edu(content, "handle_edu"): try: @@ -785,7 +787,7 @@ class FederationHandlerRegistry(object): def on_query(self, query_type, args): handler = self.query_handlers.get(query_type) if not handler: - logger.warn("No handler registered for query type %s", query_type) + logger.warning("No handler registered for query type %s", query_type) raise NotFoundError("No handler for Query type '%s'" % (query_type,)) return handler(args) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 5b6c79c51a..67b3e1ab6e 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -146,7 +146,7 @@ class TransactionManager(object): if code == 200: for e_id, r in response.get("pdus", {}).items(): if "error" in r: - logger.warn( + logger.warning( "TX [%s] {%s} Remote returned error for %s: %s", destination, txn_id, @@ -155,7 +155,7 @@ class TransactionManager(object): ) else: for p in pdus: - logger.warn( + logger.warning( "TX [%s] {%s} Failed to send event %s", destination, txn_id, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 0f16f21c2d..d6c23f22bd 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -202,7 +202,7 @@ def _parse_auth_header(header_bytes): sig = strip_quotes(param_dict["sig"]) return origin, key, sig except Exception as e: - logger.warn( + logger.warning( "Error parsing auth header '%s': %s", header_bytes.decode("ascii", "replace"), e, @@ -287,10 +287,12 @@ class BaseFederationServlet(object): except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: - logger.warn("authenticate_request failed: missing authentication") + logger.warning( + "authenticate_request failed: missing authentication" + ) raise except Exception as e: - logger.warn("authenticate_request failed: %s", e) + logger.warning("authenticate_request failed: %s", e) raise request_tags = { diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index dfd7ae041b..d950a8b246 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -181,7 +181,7 @@ class GroupAttestionRenewer(object): elif not self.is_mine_id(user_id): destination = get_domain_from_id(user_id) else: - logger.warn( + logger.warning( "Incorrectly trying to do attestations for user: %r in %r", user_id, group_id, diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 8f10b6adbb..29e8ffc295 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -488,7 +488,7 @@ class GroupsServerHandler(object): profile = yield self.profile_handler.get_profile_from_cache(user_id) user_profile.update(profile) except Exception as e: - logger.warn("Error getting profile for %s: %s", user_id, e) + logger.warning("Error getting profile for %s: %s", user_id, e) user_profiles.append(user_profile) return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 333eb30625..7a0f54ca24 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -525,7 +525,7 @@ class AuthHandler(BaseHandler): result = None if not user_infos: - logger.warn("Attempted to login as %s but they do not exist", user_id) + logger.warning("Attempted to login as %s but they do not exist", user_id) elif len(user_infos) == 1: # a single match (possibly not exact) result = user_infos.popitem() @@ -534,7 +534,7 @@ class AuthHandler(BaseHandler): result = (user_id, user_infos[user_id]) else: # multiple matches, none of them exact - logger.warn( + logger.warning( "Attempted to login as %s but it matches more than one user " "inexactly: %r", user_id, @@ -728,7 +728,7 @@ class AuthHandler(BaseHandler): result = yield self.validate_hash(password, password_hash) if not result: - logger.warn("Failed password login for user %s", user_id) + logger.warning("Failed password login for user %s", user_id) return None return user_id diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5f23ee4488..befef2cf3d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -656,7 +656,7 @@ class DeviceListUpdater(object): except (NotRetryingDestination, RequestSendFailed, HttpResponseException): # TODO: Remember that we are now out of sync and try again # later - logger.warn("Failed to handle device list update for %s", user_id) + logger.warning("Failed to handle device list update for %s", user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -694,7 +694,7 @@ class DeviceListUpdater(object): # up on storing the total list of devices and only handle the # delta instead. if len(devices) > 1000: - logger.warn( + logger.warning( "Ignoring device list snapshot for %s as it has >1K devs (%d)", user_id, len(devices), diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 0043cbea17..73b9e120f5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -52,7 +52,7 @@ class DeviceMessageHandler(object): local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): - logger.warn( + logger.warning( "Dropping device message from %r with spoofed sender %r", origin, sender_user_id, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 08276fdebf..f1547e3039 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -181,7 +181,7 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warn( + logger.warning( "[%s %s] Received event failed sanity checks", room_id, event_id ) raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) @@ -302,7 +302,7 @@ class FederationHandler(BaseHandler): # following. if sent_to_us_directly: - logger.warn( + logger.warning( "[%s %s] Rejecting: failed to fetch %d prev events: %s", room_id, event_id, @@ -406,7 +406,7 @@ class FederationHandler(BaseHandler): state = [event_map[e] for e in six.itervalues(state_map)] auth_chain = list(auth_chains) except Exception: - logger.warn( + logger.warning( "[%s %s] Error attempting to resolve state at missing " "prev_events", room_id, @@ -519,7 +519,9 @@ class FederationHandler(BaseHandler): # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) + logger.warning( + "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e + ) return logger.info( @@ -546,7 +548,7 @@ class FederationHandler(BaseHandler): yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: - logger.warn( + logger.warning( "[%s %s] Received prev_event %s failed history check.", room_id, event_id, @@ -1060,7 +1062,7 @@ class FederationHandler(BaseHandler): SynapseError if the event does not pass muster """ if len(ev.prev_event_ids()) > 20: - logger.warn( + logger.warning( "Rejecting event %s which has %i prev_events", ev.event_id, len(ev.prev_event_ids()), @@ -1068,7 +1070,7 @@ class FederationHandler(BaseHandler): raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: - logger.warn( + logger.warning( "Rejecting event %s which has %i auth_events", ev.event_id, len(ev.auth_event_ids()), @@ -1204,7 +1206,7 @@ class FederationHandler(BaseHandler): with nested_logging_context(p.event_id): yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: - logger.warn( + logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e ) @@ -1251,7 +1253,7 @@ class FederationHandler(BaseHandler): builder=builder ) except AuthError as e: - logger.warn("Failed to create join to %s because %s", room_id, e) + logger.warning("Failed to create join to %s because %s", room_id, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( @@ -1495,7 +1497,7 @@ class FederationHandler(BaseHandler): room_version, event, context, do_sig_check=False ) except AuthError as e: - logger.warn("Failed to create new leave %r because %s", event, e) + logger.warning("Failed to create new leave %r because %s", event, e) raise e return event @@ -1789,7 +1791,7 @@ class FederationHandler(BaseHandler): # cause SynapseErrors in auth.check. We don't want to give up # the attempt to federate altogether in such cases. - logger.warn("Rejecting %s because %s", e.event_id, err.msg) + logger.warning("Rejecting %s because %s", e.event_id, err.msg) if e == event: raise @@ -1845,7 +1847,9 @@ class FederationHandler(BaseHandler): try: yield self.do_auth(origin, event, context, auth_events=auth_events) except AuthError as e: - logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) + logger.warning( + "[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg + ) context.rejected = RejectedReason.AUTH_ERROR @@ -1939,7 +1943,7 @@ class FederationHandler(BaseHandler): try: event_auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: - logger.warn("Soft-failing %r because %s", event, e) + logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @defer.inlineCallbacks @@ -2038,7 +2042,7 @@ class FederationHandler(BaseHandler): try: event_auth.check(room_version, event, auth_events=auth_events) except AuthError as e: - logger.warn("Failed auth resolution for %r because %s", event, e) + logger.warning("Failed auth resolution for %r because %s", event, e) raise e @defer.inlineCallbacks @@ -2432,7 +2436,7 @@ class FederationHandler(BaseHandler): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying new third party invite %r because %s", event, e) + logger.warning("Denying new third party invite %r because %s", event, e) raise e yield self._check_signature(event, context) @@ -2488,7 +2492,7 @@ class FederationHandler(BaseHandler): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying third party invite %r because %s", event, e) + logger.warning("Denying third party invite %r because %s", event, e) raise e yield self._check_signature(event, context) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 46eb9ee88b..92fecbfc44 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -392,7 +392,7 @@ class GroupsLocalHandler(object): try: user_profile = yield self.profile_handler.get_profile(user_id) except Exception as e: - logger.warn("No profile for user %s: %s", user_id, e) + logger.warning("No profile for user %s: %s", user_id, e) user_profile = {} return {"state": "invite", "user_profile": user_profile} diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index ba99ddf76d..000fbf090f 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -272,7 +272,7 @@ class IdentityHandler(BaseHandler): changed = False if e.code in (400, 404, 501): # The remote server probably doesn't support unbinding (yet) - logger.warn("Received %d response while unbinding threepid", e.code) + logger.warning("Received %d response while unbinding threepid", e.code) else: logger.error("Failed to unbind threepid on identity server: %s", e) raise SynapseError(500, "Failed to contact identity server") @@ -403,7 +403,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". ' "Please consult the sample config at docs/sample_config.yaml for " @@ -457,7 +457,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". ' "Please consult the sample config at docs/sample_config.yaml for " diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7908a2d52c..5698e5fee0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -688,7 +688,7 @@ class EventCreationHandler(object): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as err: - logger.warn("Denying new event %r because %s", event, err) + logger.warning("Denying new event %r because %s", event, err) raise err # Ensure that we can round trip before trying to persist in db diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 8690f69d45..22e0a04da4 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -275,7 +275,7 @@ class BaseProfileHandler(BaseHandler): ratelimit=False, # Try to hide that these events aren't atomic. ) except Exception as e: - logger.warn( + logger.warning( "Failed to update join event for room %s - %s", room_id, str(e) ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2816bd8f87..445a08f445 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -922,7 +922,7 @@ class RoomEventSource(object): from_token = RoomStreamToken.parse(from_key) if from_token.topological: - logger.warn("Stream has topological part!!!! %r", from_key) + logger.warning("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) app_service = self.store.get_app_service_by_user_id(user.to_string()) diff --git a/synapse/http/client.py b/synapse/http/client.py index cdf828a4ff..2df5b383b5 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -535,7 +535,7 @@ class SimpleHttpClient(object): b"Content-Length" in resp_headers and int(resp_headers[b"Content-Length"][0]) > max_size ): - logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + logger.warning("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, "Requested file is too large > %r bytes" % (self.max_size,), @@ -543,7 +543,7 @@ class SimpleHttpClient(object): ) if response.code > 299: - logger.warn("Got %d when downloading %s" % (response.code, url)) + logger.warning("Got %d when downloading %s" % (response.code, url)) raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) # TODO: if our Content-Type is HTML or something, just read the first diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 3fe4ffb9e5..021b233a7d 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -148,7 +148,7 @@ class SrvResolver(object): # Try something in the cache, else rereaise cache_entry = self._cache.get(service_name, None) if cache_entry: - logger.warn( + logger.warning( "Failed to resolve %r, falling back to cache. %r", service_name, e ) return list(cache_entry) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3f7c93ffcb..691380abda 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -149,7 +149,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): body = yield make_deferred_yieldable(d) except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, @@ -457,7 +457,7 @@ class MatrixFederationHttpClient(object): except Exception as e: # Eh, we're already going to raise an exception so lets # ignore if this fails. - logger.warn( + logger.warning( "{%s} [%s] Failed to get error response: %s %s: %s", request.txn_id, request.destination, @@ -478,7 +478,7 @@ class MatrixFederationHttpClient(object): break except RequestSendFailed as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -513,7 +513,7 @@ class MatrixFederationHttpClient(object): raise except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -889,7 +889,7 @@ class MatrixFederationHttpClient(object): d.addTimeout(self.default_timeout, self.reactor) length = yield make_deferred_yieldable(d) except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 46af27c8f6..58f9cc61c8 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -170,7 +170,7 @@ class RequestMetrics(object): tag = context.tag if context != self.start_context: - logger.warn( + logger.warning( "Context have unexpectedly changed %r, %r", context, self.start_context, diff --git a/synapse/http/server.py b/synapse/http/server.py index 2ccb210fd6..943d12c907 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -454,7 +454,7 @@ def respond_with_json( # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: - logger.warn( + logger.warning( "Not sending response to request %s, already disconnected.", request ) return diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 274c1a6a87..e9a5e46ced 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -219,13 +219,13 @@ def parse_json_value_from_request(request, allow_empty_body=False): try: content_unicode = content_bytes.decode("utf8") except UnicodeDecodeError: - logger.warn("Unable to decode UTF-8") + logger.warning("Unable to decode UTF-8") raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) try: content = json.loads(content_unicode) except Exception as e: - logger.warn("Unable to parse JSON: %s", e) + logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) return content diff --git a/synapse/http/site.py b/synapse/http/site.py index df5274c177..ff8184a3d0 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -199,7 +199,7 @@ class SynapseRequest(Request): # It's useful to log it here so that we can get an idea of when # the client disconnects. with PreserveLoggingContext(self.logcontext): - logger.warn( + logger.warning( "Error processing request %r: %s %s", self, reason.type, reason.value ) @@ -305,7 +305,7 @@ class SynapseRequest(Request): try: self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: - logger.warn("Failed to stop metrics: %r", e) + logger.warning("Failed to stop metrics: %r", e) class XForwardedForRequest(SynapseRequest): diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 370000e377..2c1fb9ddac 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -294,7 +294,7 @@ class LoggingContext(object): """Enters this logging context into thread local storage""" old_context = self.set_current_context(self) if self.previous_context != old_context: - logger.warn( + logger.warning( "Expected previous context %r, found %r", self.previous_context, old_context, diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 6299587808..23d3678420 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -246,7 +246,7 @@ class HttpPusher(object): # we really only give up so that if the URL gets # fixed, we don't suddenly deliver a load # of old notifications. - logger.warn( + logger.warning( "Giving up on a notification to user %s, " "pushkey %s", self.user_id, self.pushkey, @@ -299,7 +299,7 @@ class HttpPusher(object): if pk != self.pushkey: # for sanity, we only remove the pushkey if it # was the one we actually sent... - logger.warn( + logger.warning( ("Ignoring rejected pushkey %s because we" " didn't send it"), pk, ) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 5ed9147de4..b1587183a8 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -117,7 +117,7 @@ class PushRuleEvaluatorForEvent(object): pattern = UserID.from_string(user_id).localpart if not pattern: - logger.warn("event_match condition with no pattern") + logger.warning("event_match condition with no pattern") return False # XXX: optimisation: cache our pattern regexps @@ -173,7 +173,7 @@ def _glob_matches(glob, value, word_boundary=False): regex_cache[(glob, word_boundary)] = r return r.search(value) except re.error: - logger.warn("Failed to parse glob to regex: %r", glob) + logger.warning("Failed to parse glob to regex: %r", glob) return False diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 9be37cd998..c8056b0c0c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -180,7 +180,7 @@ class ReplicationEndpoint(object): if e.code != 504 or not cls.RETRY_ON_TIMEOUT: raise - logger.warn("%s request timed out", cls.NAME) + logger.warning("%s request timed out", cls.NAME) # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index b5f5f13a62..cc1f249740 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -144,7 +144,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # The 'except' clause is very broad, but we need to # capture everything from DNS failures upwards # - logger.warn("Failed to reject invite: %s", e) + logger.warning("Failed to reject invite: %s", e) await self.store.locally_reject_invite(user_id, room_id) ret = {} diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index a44ceb00e7..563ce0fc53 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -168,7 +168,7 @@ class ReplicationClientHandler(object): if self.connection: self.connection.send_command(cmd) else: - logger.warn("Queuing command as not connected: %r", cmd.NAME) + logger.warning("Queuing command as not connected: %r", cmd.NAME) self.pending_commands.append(cmd) def send_federation_ack(self, token): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 5ffdf2675d..b64f3f44b5 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -249,7 +249,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return handler(cmd) def close(self): - logger.warn("[%s] Closing connection", self.id()) + logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() self.transport.loseConnection() self.on_connection_closed() diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 939418ee2b..5c2a2eb593 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -286,7 +286,7 @@ class PurgeHistoryRestServlet(RestServlet): room_id, stream_ordering ) if not r: - logger.warn( + logger.warning( "[purge] purging events not possible: No event found " "(received_ts %i => stream_ordering %i)", ts, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8414af08cb..39a5c5e9de 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -221,7 +221,7 @@ class LoginRestServlet(RestServlet): medium, address ) if not user_id: - logger.warn( + logger.warning( "unknown 3pid identifier medium %s, address %r", medium, address ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 80cf7126a0..332d7138b1 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -71,7 +71,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User password resets have been disabled due to lack of email config" ) raise SynapseError( @@ -162,7 +162,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Password reset emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -183,7 +183,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -350,7 +350,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -441,7 +441,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -488,7 +488,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): def on_GET(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -515,7 +515,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 4f24a124a6..6c7d25d411 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -106,7 +106,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Email registration has been disabled due to lack of email config" ) raise SynapseError( @@ -207,7 +207,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -266,7 +266,7 @@ class RegistrationSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User registration via email has been disabled due to lack of email config" ) raise SynapseError( @@ -287,7 +287,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -480,7 +480,7 @@ class RegisterRestServlet(RestServlet): # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params - logger.warn("Ignoring initial_device_display_name without password") + logger.warning("Ignoring initial_device_display_name without password") del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 541a6b0e10..ccd8b17b23 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -394,7 +394,7 @@ class SyncRestServlet(RestServlet): # We've had bug reports that events were coming down under the # wrong room. if event.room_id != room.room_id: - logger.warn( + logger.warning( "Event %r is under room %r instead of %r", event.event_id, room.room_id, diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index b972e152a9..bd9186fe50 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -363,7 +363,7 @@ class MediaRepository(object): }, ) except RequestSendFailed as e: - logger.warn( + logger.warning( "Request failed fetching remote media %s/%s: %r", server_name, media_id, @@ -372,7 +372,7 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except HttpResponseException as e: - logger.warn( + logger.warning( "HTTP error fetching remote media %s/%s: %s", server_name, media_id, @@ -383,10 +383,12 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except SynapseError: - logger.warn("Failed to fetch remote media %s/%s", server_name, media_id) + logger.warning( + "Failed to fetch remote media %s/%s", server_name, media_id + ) raise except NotRetryingDestination: - logger.warn("Not retrying destination %r", server_name) + logger.warning("Not retrying destination %r", server_name) raise SynapseError(502, "Failed to fetch remote media") except Exception: logger.exception( @@ -691,7 +693,7 @@ class MediaRepository(object): try: os.remove(full_path) except OSError as e: - logger.warn("Failed to remove file: %r", full_path) + logger.warning("Failed to remove file: %r", full_path) if e.errno == errno.ENOENT: pass else: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 094ebad770..5a25b6b3fc 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -136,7 +136,7 @@ class PreviewUrlResource(DirectServeResource): match = False continue if match: - logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) + logger.warning("URL %s blocked by url_blacklist entry %s", url, entry) raise SynapseError( 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN ) @@ -208,7 +208,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s" % url) + logger.warning("Couldn't get dims for %s" % url) # define our OG response for this media elif _is_html(media_info["media_type"]): @@ -256,7 +256,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s", og["og:image"]) + logger.warning("Couldn't get dims for %s", og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( self.server_name, @@ -267,7 +267,7 @@ class PreviewUrlResource(DirectServeResource): else: del og["og:image"] else: - logger.warn("Failed to find any OG data in %s", url) + logger.warning("Failed to find any OG data in %s", url) og = {} logger.debug("Calculated OG for %s as %s", url, og) @@ -319,7 +319,7 @@ class PreviewUrlResource(DirectServeResource): ) except Exception as e: # FIXME: pass through 404s and other error messages nicely - logger.warn("Error downloading %s: %r", url, e) + logger.warning("Error downloading %s: %r", url, e) raise SynapseError( 500, @@ -400,7 +400,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) @@ -432,7 +432,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue try: @@ -448,7 +448,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 08329884ac..931ce79be8 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -182,7 +182,7 @@ class ThumbnailResource(DirectServeResource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks @@ -245,7 +245,7 @@ class ThumbnailResource(DirectServeResource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index c0e7f475c9..9fae2e0afe 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -83,7 +83,7 @@ class ResourceLimitsServerNotices(object): room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) if not room_id: - logger.warn("Failed to get server notices room") + logger.warning("Failed to get server notices room") return yield self._check_and_set_tags(user_id, room_id) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f5906fcd54..1a2b7ebe25 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -494,7 +494,7 @@ class SQLBaseStore(object): exception_callbacks = [] if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warn("Starting db txn '%s' from sentinel context", desc) + logger.warning("Starting db txn '%s' from sentinel context", desc) try: result = yield self.runWithConnection( @@ -532,7 +532,7 @@ class SQLBaseStore(object): """ parent_context = LoggingContext.current_context() if parent_context == LoggingContext.sentinel: - logger.warn( + logger.warning( "Starting db connection from sentinel context: metrics will be lost" ) parent_context = None @@ -719,7 +719,7 @@ class SQLBaseStore(object): raise # presumably we raced with another transaction: let's retry. - logger.warn( + logger.warning( "IntegrityError when upserting into %s; retrying: %s", table, e ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index f005c1ae0a..d76861cdc0 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -44,7 +44,7 @@ class PusherWorkerStore(SQLBaseStore): r["data"] = json.loads(dataJson) except Exception as e: - logger.warn( + logger.warning( "Invalid JSON in data for pusher %d: %s, %s", r["id"], dataJson, diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 0e08497452..a59b8331e1 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -196,7 +196,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): " ON event_search USING GIN (vector)" ) except psycopg2.ProgrammingError as e: - logger.warn( + logger.warning( "Ignoring error %r when trying to switch from GIST to GIN", e ) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index b60a604474..5c4de2e69f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -309,7 +309,7 @@ class Linearizer(object): ) else: - logger.warn( + logger.warning( "Unexpected exception waiting for linearizer lock %r for key %r", self.name, key, diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 43fd65d693..da5077b471 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -107,7 +107,7 @@ def register_cache(cache_type, cache_name, cache, collect_callback=None): if collect_callback: collect_callback() except Exception as e: - logger.warn("Error calculating metrics for %s: %s", cache_name, e) + logger.warning("Error calculating metrics for %s: %s", cache_name, e) raise yield GaugeMetricFamily("__unused", "") diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4b1bcdf23c..3286804322 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -119,7 +119,7 @@ class Measure(object): context = LoggingContext.current_context() if context != self.start_context: - logger.warn( + logger.warning( "Context has unexpectedly changed from '%s' to '%s'. (%r)", self.start_context, context, @@ -128,7 +128,7 @@ class Measure(object): return if not context: - logger.warn("Expected context. (%r)", self.name) + logger.warning("Expected context. (%r)", self.name) return current = context.get_resource_usage() @@ -140,7 +140,7 @@ class Measure(object): block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: - logger.warn( + logger.warning( "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current ) diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index 6c0f2bb0cf..207cd17c2a 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no): resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) ) except (ValueError, resource.error) as e: - logger.warn("Failed to set file or core limit: %s", e) + logger.warning("Failed to set file or core limit: %s", e) -- cgit 1.4.1 From c6bcd388414c88fff1418de072af60906c001a10 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Oct 2019 11:17:23 +0000 Subject: Fix /purge_room API. It fails trying to clean the `topic` table which was recently removed. --- synapse/storage/data_stores/main/events.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index a4dab86a13..64a8a05279 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1838,7 +1838,6 @@ class EventsStore( "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", - "topics", "users_in_public_rooms", "users_who_share_private_rooms", # no useful index, but let's clear them anyway -- cgit 1.4.1 From 64f2b8c3d8672273bf173cc125e339a297e5a29a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Oct 2019 15:44:31 +0100 Subject: Apply suggestions from code review Fix docstring Co-Authored-By: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/storage/state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b382a06dcc..3735846899 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -338,7 +338,8 @@ class StateGroupStorage(object): the old and the new. Returns: - (prev_group, delta_ids), where both may be None. + Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]): + (prev_group, delta_ids) """ return self.stores.main.get_state_group_delta(state_group) -- cgit 1.4.1 From 3a74c03ffb5532831c8412b52a2d682bdeb9f322 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Thu, 31 Oct 2019 09:16:14 -0600 Subject: Expose some homeserver functionality to spam checkers (#6259) * Offer the homeserver instance to the spam checker * Newsfile * Linting * Expose a Spam Checker API instead of passing the homeserver object * Alter changelog * s/hs/api --- changelog.d/6259.misc | 1 + synapse/events/spamcheck.py | 14 +++++++++- synapse/spam_checker_api/__init__.py | 51 ++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6259.misc create mode 100644 synapse/spam_checker_api/__init__.py (limited to 'synapse') diff --git a/changelog.d/6259.misc b/changelog.d/6259.misc new file mode 100644 index 0000000000..3ff81b1ac7 --- /dev/null +++ b/changelog.d/6259.misc @@ -0,0 +1 @@ +Expose some homeserver functionality to spam checkers. diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 129771f183..5a907718d6 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect + +from synapse.spam_checker_api import SpamCheckerApi + class SpamChecker(object): def __init__(self, hs): @@ -26,7 +31,14 @@ class SpamChecker(object): pass if module is not None: - self.spam_checker = module(config=config) + # Older spam checkers don't accept the `api` argument, so we + # try and detect support. + spam_args = inspect.getfullargspec(module) + if "api" in spam_args.args: + api = SpamCheckerApi(hs) + self.spam_checker = module(config=config, api=api) + else: + self.spam_checker = module(config=config) def check_event_for_spam(self, event): """Checks if a given event is considered "spammy" by this server. diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py new file mode 100644 index 0000000000..efcc10f808 --- /dev/null +++ b/synapse/spam_checker_api/__init__.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +class SpamCheckerApi(object): + """A proxy object that gets passed to spam checkers so they can get + access to rooms and other relevant information. + """ + + def __init__(self, hs): + self.hs = hs + + self._store = hs.get_datastore() + + @defer.inlineCallbacks + def get_state_events_in_room(self, room_id, types): + """Gets state events for the given room. + + Args: + room_id (string): The room ID to get state events in. + types (tuple): The event type and state key (using None + to represent 'any') of the room state to acquire. + + Returns: + twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: + The filtered state events in the room. + """ + state_ids = yield self._store.get_filtered_current_state_ids( + room_id=room_id, state_filter=StateFilter.from_types(types) + ) + state = yield self._store.get_events(state_ids.values()) + return state.values() -- cgit 1.4.1 From 020add50997f697c7847ac84b86b457ba2f3e32d Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 1 Nov 2019 02:43:24 +1100 Subject: Update black to 19.10b0 (#6304) * update version of black and also fix the mypy config being overridden --- changelog.d/6304.misc | 1 + contrib/experiments/test_messaging.py | 4 +-- mypy.ini | 11 ++++--- synapse/federation/sender/per_destination_queue.py | 11 ++++--- synapse/handlers/account_data.py | 7 ++-- synapse/handlers/appservice.py | 5 ++- synapse/handlers/e2e_keys.py | 37 ++++++++++++++-------- synapse/handlers/federation.py | 9 +++--- synapse/handlers/initial_sync.py | 4 +-- synapse/handlers/message.py | 14 ++++---- synapse/handlers/pagination.py | 13 ++++---- synapse/handlers/register.py | 4 +-- synapse/handlers/room.py | 29 +++++++++-------- synapse/handlers/room_member.py | 35 ++++++++++---------- synapse/handlers/search.py | 12 +++---- synapse/handlers/stats.py | 5 ++- synapse/handlers/sync.py | 16 ++++++---- synapse/logging/_structured.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 7 ++-- synapse/push/emailpusher.py | 14 ++++---- synapse/push/httppusher.py | 14 ++++---- synapse/push/pusherpool.py | 4 +-- synapse/rest/client/v1/login.py | 13 ++++---- synapse/rest/client/v2_alpha/account.py | 4 +-- synapse/rest/client/v2_alpha/register.py | 4 +-- synapse/rest/key/v2/remote_key_resource.py | 2 +- synapse/server.pyi | 16 +++++----- synapse/storage/data_stores/main/__init__.py | 4 +-- .../storage/data_stores/main/event_push_actions.py | 2 +- synapse/storage/data_stores/main/events.py | 8 ++--- .../storage/data_stores/main/events_bg_updates.py | 2 +- synapse/storage/data_stores/main/group_server.py | 4 +-- .../data_stores/main/monthly_active_users.py | 2 +- synapse/storage/data_stores/main/push_rule.py | 2 +- synapse/storage/data_stores/main/registration.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- synapse/storage/data_stores/main/search.py | 2 +- synapse/storage/data_stores/main/state.py | 20 ++++++------ synapse/storage/data_stores/main/stats.py | 4 +-- synapse/storage/util/id_generators.py | 2 +- tox.ini | 4 +-- 41 files changed, 191 insertions(+), 166 deletions(-) create mode 100644 changelog.d/6304.misc (limited to 'synapse') diff --git a/changelog.d/6304.misc b/changelog.d/6304.misc new file mode 100644 index 0000000000..20372b4f7c --- /dev/null +++ b/changelog.d/6304.misc @@ -0,0 +1 @@ +Update the version of black used to 19.10b0. diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 6b22400a60..3bbbcfa1b4 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -78,7 +78,7 @@ class InputOutput(object): m = re.match("^join (\S+)$", line) if m: # The `sender` wants to join a room. - room_name, = m.groups() + (room_name,) = m.groups() self.print_line("%s joining %s" % (self.user, room_name)) self.server.join_room(room_name, self.user, self.user) # self.print_line("OK.") @@ -105,7 +105,7 @@ class InputOutput(object): m = re.match("^backfill (\S+)$", line) if m: # we want to backfill a room - room_name, = m.groups() + (room_name,) = m.groups() self.print_line("backfill %s" % room_name) self.server.backfill(room_name) return diff --git a/mypy.ini b/mypy.ini index ffadaddc0b..1d77c0ecc8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,8 +1,11 @@ [mypy] -namespace_packages=True -plugins=mypy_zope:plugin -follow_imports=skip -mypy_path=stubs +namespace_packages = True +plugins = mypy_zope:plugin +follow_imports = normal +check_untyped_defs = True +show_error_codes = True +show_traceback = True +mypy_path = stubs [mypy-zope] ignore_missing_imports = True diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index cc75c39476..b754a09d7a 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -192,15 +192,16 @@ class PerDestinationQueue(object): # We have to keep 2 free slots for presence and rr_edus limit = MAX_EDUS_PER_TRANSACTION - 2 - device_update_edus, dev_list_id = ( - yield self._get_device_update_edus(limit) + device_update_edus, dev_list_id = yield self._get_device_update_edus( + limit ) limit -= len(device_update_edus) - to_device_edus, device_stream_id = ( - yield self._get_to_device_message_edus(limit) - ) + ( + to_device_edus, + device_stream_id, + ) = yield self._get_to_device_message_edus(limit) pending_edus = device_update_edus + to_device_edus diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 38bc67191c..2d7e6df6e4 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -38,9 +38,10 @@ class AccountDataEventSource(object): {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} ) - account_data, room_account_data = ( - yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) - ) + ( + account_data, + room_account_data, + ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 3e9b298154..fe62f78e67 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -73,7 +73,10 @@ class ApplicationServicesHandler(object): try: limit = 100 while True: - upper_bound, events = yield self.store.get_new_events_for_appservice( + ( + upper_bound, + events, + ) = yield self.store.get_new_events_for_appservice( self.current_max, limit ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5ea54f60be..0449034a4e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -119,9 +119,10 @@ class E2eKeysHandler(object): else: query_list.append((user_id, None)) - user_ids_not_in_cache, remote_results = ( - yield self.store.get_user_devices_from_cache(query_list) - ) + ( + user_ids_not_in_cache, + remote_results, + ) = yield self.store.get_user_devices_from_cache(query_list) for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) for device_id, device in iteritems(devices): @@ -688,17 +689,21 @@ class E2eKeysHandler(object): try: # get our self-signing key to verify the signatures - _, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "self_signing" - ) + ( + _, + self_signing_key_id, + self_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") # get our master key, since we may have received a signature of it. # We need to fetch it here so that we know what its key ID is, so # that we can check if a signature that was sent is a signature of # the master key or of a device - master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "master" - ) + ( + master_key, + _, + master_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what @@ -838,9 +843,11 @@ class E2eKeysHandler(object): try: # get our user-signing key to verify the signatures - user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "user_signing" - ) + ( + user_signing_key, + user_signing_key_id, + user_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): @@ -859,7 +866,11 @@ class E2eKeysHandler(object): try: # get the target user's master key, to make sure it matches # what was sent - master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key( + ( + master_key, + master_key_id, + _, + ) = yield self._get_e2e_cross_signing_verify_key( target_user, "master", user_id ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d2d9f8c26a..a932d3085f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -352,10 +352,11 @@ class FederationHandler(BaseHandler): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - remote_state, got_auth_chain = ( - yield self.federation_client.get_state_for_room( - origin, room_id, p - ) + ( + remote_state, + got_auth_chain, + ) = yield self.federation_client.get_state_for_room( + origin, room_id, p ) # we want the state *after* p; get_state_for_room returns the diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 49c9e031f9..81dce96f4b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -128,8 +128,8 @@ class InitialSyncHandler(BaseHandler): tags_by_room = yield self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(user_id) + account_data, account_data_by_room = yield self.store.get_account_data_for_user( + user_id ) public_room_ids = yield self.store.get_public_room_ids() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0d546d2487..d682dc2b7a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -76,9 +76,10 @@ class MessageHandler(object): Raises: SynapseError if something went wrong. """ - membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + ( + membership, + membership_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if membership == Membership.JOIN: data = yield self.state.get_current_state(room_id, event_type, state_key) @@ -153,9 +154,10 @@ class MessageHandler(object): % (user_id, room_id, at_token), ) else: - membership, membership_event_id = ( - yield self.auth.check_in_room_or_world_readable(room_id, user_id) - ) + ( + membership, + membership_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index b7185fe7a0..97f15a1c32 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -212,9 +212,10 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") with (yield self.pagination_lock.read(room_id)): - membership, member_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + ( + membership, + member_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if source_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -297,10 +298,8 @@ class PaginationHandler(object): } if state: - chunk["state"] = ( - yield self._event_serializer.serialize_events( - state, time_now, as_client_event=as_client_event - ) + chunk["state"] = yield self._event_serializer.serialize_events( + state, time_now, as_client_event=as_client_event ) return chunk diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 53410f120b..cff6b0d375 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -396,8 +396,8 @@ class RegistrationHandler(BaseHandler): room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = ( - yield room_member_handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( + room_alias ) room_id = room_id.to_string() else: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 650bd28abb..0182e5b432 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -147,21 +147,22 @@ class RoomCreationHandler(BaseHandler): # we create and auth the tombstone event before properly creating the new # room, to check our user has perms in the old room. - tombstone_event, tombstone_context = ( - yield self.event_creation_handler.create_event( - requester, - { - "type": EventTypes.Tombstone, - "state_key": "", - "room_id": old_room_id, - "sender": user_id, - "content": { - "body": "This room has been replaced", - "replacement_room": new_room_id, - }, + ( + tombstone_event, + tombstone_context, + ) = yield self.event_creation_handler.create_event( + requester, + { + "type": EventTypes.Tombstone, + "state_key": "", + "room_id": old_room_id, + "sender": user_id, + "content": { + "body": "This room has been replaced", + "replacement_room": new_room_id, }, - token_id=requester.access_token_id, - ) + }, + token_id=requester.access_token_id, ) old_room_version = yield self.store.get_room_version(old_room_id) yield self.auth.check_from_context( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 380e2fad5e..9a940d2c05 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -759,22 +759,25 @@ class RoomMemberHandler(object): if room_avatar_event: room_avatar_url = room_avatar_event.content.get("url", "") - token, public_keys, fallback_public_key, display_name = ( - yield self.identity_handler.ask_id_server_for_third_party_invite( - requester=requester, - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url, - id_access_token=id_access_token, - ) + ( + token, + public_keys, + fallback_public_key, + display_name, + ) = yield self.identity_handler.ask_id_server_for_third_party_invite( + requester=requester, + id_server=id_server, + medium=medium, + address=address, + room_id=room_id, + inviter_user_id=user.to_string(), + room_alias=canonical_room_alias, + room_avatar_url=room_avatar_url, + room_join_rules=room_join_rules, + room_name=room_name, + inviter_display_name=inviter_display_name, + inviter_avatar_url=inviter_avatar_url, + id_access_token=id_access_token, ) yield self.event_creation_handler.create_and_send_nonmember_event( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index f4d8a60774..56ed262a1f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -396,15 +396,11 @@ class SearchHandler(BaseHandler): time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = ( - yield self._event_serializer.serialize_events( - context["events_before"], time_now - ) + context["events_before"] = yield self._event_serializer.serialize_events( + context["events_before"], time_now ) - context["events_after"] = ( - yield self._event_serializer.serialize_events( - context["events_after"], time_now - ) + context["events_after"] = yield self._event_serializer.serialize_events( + context["events_after"], time_now ) state_results = {} diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 26bc276692..7f7d56390e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -108,7 +108,10 @@ class StatsHandler(StateDeltasHandler): user_deltas = {} # Then count deltas for total_events and total_event_bytes. - room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes( + ( + room_count, + user_count, + ) = yield self.store.get_changes_room_total_events_and_bytes( self.pos, max_pos ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 43a082dcda..b536d410e5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1206,10 +1206,11 @@ class SyncHandler(object): since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: - account_data, account_data_by_room = ( - yield self.store.get_updated_account_data_for_user( - user_id, since_token.account_data_key - ) + ( + account_data, + account_data_by_room, + ) = yield self.store.get_updated_account_data_for_user( + user_id, since_token.account_data_key ) push_rules_changed = yield self.store.have_push_rules_changed_for_user( @@ -1221,9 +1222,10 @@ class SyncHandler(object): sync_config.user ) else: - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(sync_config.user.to_string()) - ) + ( + account_data, + account_data_by_room, + ) = yield self.store.get_account_data_for_user(sync_config.user.to_string()) account_data["m.push_rules"] = yield self.push_rules_for_user( sync_config.user diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 3220e985a9..334ddaf39a 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}} def parse_drain_configs( - drains: dict + drains: dict, ) -> typing.Generator[DrainConfiguration, None, None]: """ Parse the drain configurations. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 2bbdd11941..1ba7bcd4d8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -149,9 +149,10 @@ class BulkPushRuleEvaluator(object): room_members = yield self.store.get_joined_users_from_context(event, context) - (power_levels, sender_power_level) = ( - yield self._get_power_levels_and_sender_level(event, context) - ) + ( + power_levels, + sender_power_level, + ) = yield self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 42e5b0c0a5..8c818a86bf 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -234,14 +234,12 @@ class EmailPusher(object): return self.last_stream_ordering = last_stream_ordering - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.email, - self.user_id, - last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 9a1bb64887..7dde2ad055 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -211,14 +211,12 @@ class HttpPusher(object): http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.pushkey, - self.user_id, - self.last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.pushkey, + self.user_id, + self.last_stream_ordering, + self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 08e840fdc2..0f6992202d 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -103,9 +103,7 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = ( - yield self.store.get_latest_push_action_stream_ordering() - ) + last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() yield self.store.add_pusher( user_id=user_id, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 39a5c5e9de..00a7dd6d09 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -203,10 +203,11 @@ class LoginRestServlet(RestServlet): address = address.lower() # Check for login providers that support 3pid login types - canonical_user_id, callback_3pid = ( - yield self.auth_handler.check_password_provider_3pid( - medium, address, login_submission["password"] - ) + ( + canonical_user_id, + callback_3pid, + ) = yield self.auth_handler.check_password_provider_3pid( + medium, address, login_submission["password"] ) if canonical_user_id: # Authentication through password provider and 3pid succeeded @@ -280,8 +281,8 @@ class LoginRestServlet(RestServlet): def do_token_login(self, login_submission): token = login_submission["token"] auth_handler = self.auth_handler - user_id = ( - yield auth_handler.validate_short_term_login_token_and_get_user_id(token) + user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( + token ) result = yield self._register_device_with_callback(user_id, login_submission) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 332d7138b1..f26eae794c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -148,7 +148,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_password_reset_template_failure_html], ) @@ -479,7 +479,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_add_threepid_template_failure_html], ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6c7d25d411..91db923814 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -247,13 +247,13 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_registration_template_failure_html], ) if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_registration_template_failure_html], ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 55580bc59e..e7fc3f0431 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -102,7 +102,7 @@ class RemoteKey(DirectServeResource): @wrap_json_request_handler async def _async_render_GET(self, request): if len(request.postpath) == 1: - server, = request.postpath + (server,) = request.postpath query = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath diff --git a/synapse/server.pyi b/synapse/server.pyi index 16f8f6b573..83d1f11283 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -39,7 +39,7 @@ class HomeServer(object): def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass def get_deactivate_account_handler( - self + self, ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: pass def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: @@ -47,32 +47,32 @@ class HomeServer(object): def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass def get_event_creation_handler( - self + self, ) -> synapse.handlers.message.EventCreationHandler: pass def get_set_password_handler( - self + self, ) -> synapse.handlers.set_password.SetPasswordHandler: pass def get_federation_sender(self) -> synapse.federation.sender.FederationSender: pass def get_federation_transport_client( - self + self, ) -> synapse.federation.transport.client.TransportLayerClient: pass def get_media_repository_resource( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: pass def get_media_repository( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepository: pass def get_server_notices_manager( - self + self, ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: pass def get_server_notices_sender( - self + self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index b185ba0b3e..60ae01d972 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -317,7 +317,7 @@ class DataStore( ) u """ txn.execute(sql, (time_from,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count def count_r30_users(self): @@ -396,7 +396,7 @@ class DataStore( txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - count, = txn.fetchone() + (count,) = txn.fetchone() results["all"] = count return results diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 22025effbc..04ce21ac66 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) stream_row = txn.fetchone() if stream_row: - offset_stream_ordering, = stream_row + (offset_stream_ordering,) = stream_row rotate_to_stream_ordering = min( self.stream_ordering_day_ago, offset_stream_ordering ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 64a8a05279..aafc2007d3 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1125,7 +1125,7 @@ class EventsStore( AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_messages", _count_messages) @@ -1146,7 +1146,7 @@ class EventsStore( """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) @@ -1161,7 +1161,7 @@ class EventsStore( AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_daily_active_rooms", _count) @@ -1646,7 +1646,7 @@ class EventsStore( """, (room_id,), ) - min_depth, = txn.fetchone() + (min_depth,) = txn.fetchone() logger.info("[purge] updating room_depth to %d", min_depth) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 31ea6f917f..51352b9966 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -438,7 +438,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if not rows: return 0 - upper_event_id, = rows[-1] + (upper_event_id,) = rows[-1] # Update the redactions with the received_ts. # diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index aeae5a2b28..b3a2771f1b 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -249,7 +249,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} @@ -509,7 +509,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND role_id = ? """ txn.execute(sql, (group_id, role_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index e6ee1e4aaa..b41c3d317a 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -171,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" txn.execute(sql) - count, = txn.fetchone() + (count,) = txn.fetchone() return count return self.runInteraction("count_users", _count_users) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index cd95f1ce60..b520062d84 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -143,7 +143,7 @@ class PushRulesWorkerStore( " WHERE user_id = ? AND ? < stream_id" ) txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() + (count,) = txn.fetchone() return bool(count) return self.runInteraction( diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 6c5b29288a..f70d41ecab 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore): WHERE appservice_id IS NULL """ ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_users", _count_users) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index bc04bfd7d4..2af24a20b7 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -927,7 +927,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): if not row or not row[0]: return processed, True - next_room, = row + (next_room,) = row sql = """ UPDATE current_state_events diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index a59b8331e1..d1d7c6863d 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -672,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore): ) ) txn.execute(query, (value, search_query)) - headline, = txn.fetchall()[0] + (headline,) = txn.fetchall()[0] # Now we need to pick the possible highlights out of the haedline # result. diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 9b2207075b..3132848034 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -725,16 +725,18 @@ class StateGroupWorkerStore( member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches - non_member_state, incomplete_groups_nm, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter ) - member_state, incomplete_groups_m, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter ) state = dict(non_member_state) @@ -1076,7 +1078,7 @@ class StateBackgroundUpdateStore( " WHERE id < ? AND room_id = ?", (state_group, room_id), ) - prev_group, = txn.fetchone() + (prev_group,) = txn.fetchone() new_last_state_group = state_group if prev_group: diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 4d59b7833f..45b3de7d56 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore): (room_id,), ) - current_state_events_count, = txn.fetchone() + (current_state_events_count,) = txn.fetchone() users_in_room = self.get_users_in_room_txn(txn, room_id) @@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore): """, (user_id,), ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count, pos joined_rooms, pos = yield self.runInteraction( diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index cbb0a4810a..9d851beaa5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1): cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: cur.execute("SELECT MIN(%s) FROM %s" % (column, table)) - val, = cur.fetchone() + (val,) = cur.fetchone() cur.close() current_id = int(val) if val else step return (max if step > 0 else min)(current_id, step) diff --git a/tox.ini b/tox.ini index 50b6afe611..afe9bc909b 100644 --- a/tox.ini +++ b/tox.ini @@ -114,7 +114,7 @@ skip_install = True basepython = python3.6 deps = flake8 - black==19.3b0 # We pin so that our tests don't start failing on new releases of black. + black==19.10b0 # We pin so that our tests don't start failing on new releases of black. commands = python -m black --check --diff . /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}" @@ -167,6 +167,6 @@ deps = env = MYPYPATH = stubs/ extras = all -commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \ +commands = mypy \ synapse/logging/ \ synapse/config/ -- cgit 1.4.1 From f7e4a582ef1f1fc14edc86fa677a7d880a5ef01b Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 31 Oct 2019 12:01:00 -0400 Subject: clean up code a bit --- synapse/replication/slave/storage/devices.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index f416d73b2e..de50748c30 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -15,6 +15,7 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage.data_stores.main.devices import DeviceWorkerStore from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -42,17 +43,20 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def stream_positions(self): result = super(SlavedDeviceStore, self).stream_positions() - result["user_signature"] = result[ - "device_lists" - ] = self._device_list_id_gen.get_current_token() + # The user signature stream uses the same stream ID generator as the + # device list stream, so set them both to the device list ID + # generator's current token. + current_token = self._device_list_id_gen.get_current_token() + result[DeviceListsStream.NAME] = current_token + result[UserSignatureStream.NAME] = current_token return result def process_replication_rows(self, stream_name, token, rows): - if stream_name == "device_lists": + if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) for row in rows: self._invalidate_caches_for_devices(token, row.user_id, row.destination) - elif stream_name == "user_signature": + elif stream_name == UserSignatureStream.NAME: for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedDeviceStore, self).process_replication_rows( -- cgit 1.4.1 From 42e707c663505cecb63da579d5fa2c4d30811db7 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Thu, 31 Oct 2019 17:32:25 +0000 Subject: rstrip slashes from url on appservice (#6306) --- changelog.d/6306.bugfix | 1 + synapse/appservice/__init__.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6306.bugfix (limited to 'synapse') diff --git a/changelog.d/6306.bugfix b/changelog.d/6306.bugfix new file mode 100644 index 0000000000..c7dcbcdce8 --- /dev/null +++ b/changelog.d/6306.bugfix @@ -0,0 +1 @@ +Appservice requests will no longer contain a double slash prefix when the appservice url provided ends in a slash. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 33b3579425..aea3985a5f 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -94,7 +94,9 @@ class ApplicationService(object): ip_range_whitelist=None, ): self.token = token - self.url = url + self.url = ( + url.rstrip("/") if isinstance(url, str) else None + ) # url must not end with a slash self.hs_token = hs_token self.sender = sender self.server_name = hostname -- cgit 1.4.1 From c3fc176c6047c8194262a64599d000e9cb43f7f8 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 31 Oct 2019 22:49:48 -0400 Subject: Update synapse/storage/data_stores/main/devices.py Co-Authored-By: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/storage/data_stores/main/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 717eab4159..71f62036c0 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -203,7 +203,7 @@ class DeviceWorkerStore(SQLBaseStore): result = cross_signing_keys_by_user.setdefault(user_id, {}) result["master_key"] = master_key_by_user[user_id]["key_info"] elif ( - user_id in master_key_by_user + user_id in self_signing_key_by_user and device_id == self_signing_key_by_user[user_id]["device_id"] ): result = cross_signing_keys_by_user.setdefault(user_id, {}) -- cgit 1.4.1 From c61db131837eb5bb50355b6c34542cfd73654b41 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 31 Oct 2019 11:28:18 -0400 Subject: fix hidden field in devices table for older sqlite --- .../schema/delta/56/hidden_devices_fix.sql.sqlite | 42 ++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite new file mode 100644 index 0000000000..e8b1fd35d8 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite @@ -0,0 +1,42 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Change the hidden column from a default value of FALSE to a default value of + * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the + * string 'FALSE', which is truthy. + * + * Since sqlite doesn't allow us to just change the default value, we have to + * recreate the table, copy the data, fix the rows that have incorrect data, and + * replace the old table with the new table. + */ + +CREATE TABLE IF NOT EXISTS devices2 ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + last_seen BIGINT, + ip TEXT, + user_agent TEXT, + hidden BOOLEAN DEFAULT 0, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); + +INSERT INTO devices2 SELECT * FROM devices; + +UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE'; + +DROP TABLE devices; + +ALTER TABLE devices2 RENAME TO devices; -- cgit 1.4.1 From ace947e8da30c37ead3357abe34adee8a1528296 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 1 Nov 2019 10:28:09 +0000 Subject: Depublish a room from the public rooms list when it is upgraded (#6232) --- changelog.d/6232.bugfix | 1 + synapse/federation/federation_client.py | 2 +- synapse/handlers/federation.py | 30 +++++++++++- synapse/handlers/room.py | 8 +++- synapse/handlers/room_member.py | 81 ++++++++++++++++++++++----------- 5 files changed, 92 insertions(+), 30 deletions(-) create mode 100644 changelog.d/6232.bugfix (limited to 'synapse') diff --git a/changelog.d/6232.bugfix b/changelog.d/6232.bugfix new file mode 100644 index 0000000000..12718ba934 --- /dev/null +++ b/changelog.d/6232.bugfix @@ -0,0 +1 @@ +Remove a room from a server's public rooms list on room upgrade. \ No newline at end of file diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 595706d01a..545d719652 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -555,7 +555,7 @@ class FederationClient(FederationBase): Note that this does not append any events to any graphs. Args: - destinations (str): Candidate homeservers which are probably + destinations (Iterable[str]): Candidate homeservers which are probably participating in the room. room_id (str): The room in which the event will happen. user_id (str): The user whose membership is being evented. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a932d3085f..dab6be9573 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1106,7 +1106,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_invite_join(self, target_hosts, room_id, joinee, content): """ Attempts to join the `joinee` to the room `room_id` via the - server `target_host`. + servers contained in `target_hosts`. This first triggers a /make_join/ request that returns a partial event that we can fill out and sign. This is then sent to the @@ -1115,6 +1115,15 @@ class FederationHandler(BaseHandler): We suspend processing of any received events from this room until we have finished processing the join. + + Args: + target_hosts (Iterable[str]): List of servers to attempt to join the room with. + + room_id (str): The ID of the room to join. + + joinee (str): The User ID of the joining user. + + content (dict): The event content to use for the join event. """ logger.debug("Joining %s to %s", joinee, room_id) @@ -1174,6 +1183,22 @@ class FederationHandler(BaseHandler): yield self._persist_auth_tree(origin, auth_chain, state, event) + # Check whether this room is the result of an upgrade of a room we already know + # about. If so, migrate over user information + predecessor = yield self.store.get_room_predecessor(room_id) + if not predecessor: + return + old_room_id = predecessor["room_id"] + logger.debug( + "Found predecessor for %s during remote join: %s", room_id, old_room_id + ) + + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + yield member_handler.transfer_room_state_on_room_upgrade( + old_room_id, room_id + ) + logger.debug("Finished joining %s to %s", joinee, room_id) finally: room_queue = self.room_queues[room_id] @@ -2442,6 +2467,8 @@ class FederationHandler(BaseHandler): raise e yield self._check_signature(event, context) + + # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) else: @@ -2502,6 +2529,7 @@ class FederationHandler(BaseHandler): # though the sender isn't a local user. event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender) + # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0182e5b432..e92b2eafd5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -129,6 +129,7 @@ class RoomCreationHandler(BaseHandler): old_room_id, new_version, # args for _upgrade_room ) + return ret @defer.inlineCallbacks @@ -189,7 +190,12 @@ class RoomCreationHandler(BaseHandler): requester, old_room_id, new_room_id, old_room_state ) - # and finally, shut down the PLs in the old room, and update them in the new + # Copy over user push rules, tags and migrate room directory state + yield self.room_member_handler.transfer_room_state_on_room_upgrade( + old_room_id, new_room_id + ) + + # finally, shut down the PLs in the old room, and update them in the new # room. yield self._update_upgraded_room_pls( requester, old_room_id, new_room_id, old_room_state diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 9a940d2c05..06d09c2947 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -203,10 +203,6 @@ class RoomMemberHandler(object): prev_member_event = yield self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - # Copy over user state if we're joining an upgraded room - yield self.copy_user_state_if_room_upgrade( - room_id, requester.user.to_string() - ) yield self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: @@ -455,11 +451,6 @@ class RoomMemberHandler(object): requester, remote_room_hosts, room_id, target, content ) - # Copy over user state if this is a join on an remote upgraded room - yield self.copy_user_state_if_room_upgrade( - room_id, requester.user.to_string() - ) - return remote_join_response elif effective_membership_state == Membership.LEAVE: @@ -498,36 +489,72 @@ class RoomMemberHandler(object): return res @defer.inlineCallbacks - def copy_user_state_if_room_upgrade(self, new_room_id, user_id): - """Copy user-specific information when they join a new room if that new room is the + def transfer_room_state_on_room_upgrade(self, old_room_id, room_id): + """Upon our server becoming aware of an upgraded room, either by upgrading a room + ourselves or joining one, we can transfer over information from the previous room. + + Copies user state (tags/push rules) for every local user that was in the old room, as + well as migrating the room directory state. + + Args: + old_room_id (str): The ID of the old room + + room_id (str): The ID of the new room + + Returns: + Deferred + """ + # Find all local users that were in the old room and copy over each user's state + users = yield self.store.get_users_in_room(old_room_id) + yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users) + + # Add new room to the room directory if the old room was there + # Remove old room from the room directory + old_room = yield self.store.get_room(old_room_id) + if old_room and old_room["is_public"]: + yield self.store.set_room_is_public(old_room_id, False) + yield self.store.set_room_is_public(room_id, True) + + @defer.inlineCallbacks + def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids): + """Copy user-specific information when they join a new room when that new room is the result of a room upgrade Args: - new_room_id (str): The ID of the room the user is joining - user_id (str): The ID of the user + old_room_id (str): The ID of upgraded room + new_room_id (str): The ID of the new room + user_ids (Iterable[str]): User IDs to copy state for Returns: Deferred """ - # Check if the new room is an upgraded room - predecessor = yield self.store.get_room_predecessor(new_room_id) - if not predecessor: - return logger.debug( - "Found predecessor for %s: %s. Copying over room tags and push " "rules", + "Copying over room tags and push rules from %s to %s for users %s", + old_room_id, new_room_id, - predecessor, + user_ids, ) - # It is an upgraded room. Copy over old tags - yield self.copy_room_tags_and_direct_to_room( - predecessor["room_id"], new_room_id, user_id - ) - # Copy over push rules - yield self.store.copy_push_rules_from_room_to_room_for_user( - predecessor["room_id"], new_room_id, user_id - ) + for user_id in user_ids: + try: + # It is an upgraded room. Copy over old tags + yield self.copy_room_tags_and_direct_to_room( + old_room_id, new_room_id, user_id + ) + # Copy over push rules + yield self.store.copy_push_rules_from_room_to_room_for_user( + old_room_id, new_room_id, user_id + ) + except Exception: + logger.exception( + "Error copying tags and/or push rules from rooms %s to %s for user %s. " + "Skipping...", + old_room_id, + new_room_id, + user_id, + ) + continue @defer.inlineCallbacks def send_membership_event(self, requester, event, context, ratelimit=True): -- cgit 1.4.1 From c6dbca2422bf77ccbf0b52d9245d28c258dac4f3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 1 Nov 2019 10:30:51 +0000 Subject: Incorporate review --- changelog.d/6301.feature | 2 +- synapse/api/constants.py | 5 ++++- synapse/api/filtering.py | 6 ++++-- synapse/storage/data_stores/main/events.py | 12 ++++++++++-- tests/api/test_filtering.py | 10 +++++----- tests/rest/client/v1/test_rooms.py | 10 +++++----- tests/rest/client/v2_alpha/test_sync.py | 10 +++++----- 7 files changed, 34 insertions(+), 21 deletions(-) (limited to 'synapse') diff --git a/changelog.d/6301.feature b/changelog.d/6301.feature index b7ff3fad3b..78a187a1dc 100644 --- a/changelog.d/6301.feature +++ b/changelog.d/6301.feature @@ -1 +1 @@ -Implement label-based filtering. +Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 999ec02fd9..cf4ce5f5a2 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -140,4 +140,7 @@ class LimitBlockingTypes(object): HS_DISABLED = "hs_disabled" -LabelsField = "org.matrix.labels" +class EventContentFields(object): + """Fields found in events' content, regardless of type.""" + # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 + Labels = "org.matrix.labels" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bd91b9f018..30a7ee0a7a 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -20,7 +20,7 @@ from jsonschema import FormatChecker from twisted.internet import defer -from synapse.api.constants import LabelsField +from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState from synapse.types import RoomID, UserID @@ -67,6 +67,8 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + # Include or exclude events with the provided labels. + # cf https://github.com/matrix-org/matrix-doc/pull/2326 "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, }, @@ -307,7 +309,7 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) - labels = content.get(LabelsField, []) + labels = content.get(EventContentFields.Labels, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 2b900f1ce1..42ffa9066a 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -29,7 +29,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer import synapse.metrics -from synapse.api.constants import EventTypes, LabelsField +from synapse.api.constants import EventTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 @@ -1491,7 +1491,7 @@ class EventsStore( self._handle_event_relations(txn, event) # Store the labels for this event. - labels = event.content.get(LabelsField) + labels = event.content.get(EventContentFields.Labels) if labels: self.insert_labels_for_event_txn(txn, event.event_id, labels) @@ -2483,6 +2483,14 @@ class EventsStore( ) def insert_labels_for_event_txn(self, txn, event_id, labels): + """Store the mapping between an event's ID and its labels, with one row per + (event_id, label) tuple. + + Args: + txn (LoggingTransaction): The transaction to execute. + event_id (str): The event's ID. + labels (list[str]): A list of text labels. + """ return self._simple_insert_many_txn( txn=txn, table="event_labels", diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index e004ab1ee5..8ec48c4154 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -19,7 +19,7 @@ import jsonschema from twisted.internet import defer -from synapse.api.constants import LabelsField +from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import FrozenEvent @@ -329,7 +329,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={LabelsField: ["#fun"]}, + content={EventContentFields.Labels: ["#fun"]}, ) self.assertTrue(Filter(definition).check(event)) @@ -338,7 +338,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={LabelsField: ["#notfun"]}, + content={EventContentFields.Labels: ["#notfun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -349,7 +349,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={LabelsField: ["#fun"]}, + content={EventContentFields.Labels: ["#fun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -358,7 +358,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={LabelsField: ["#notfun"]}, + content={EventContentFields.Labels: ["#notfun"]}, ) self.assertTrue(Filter(definition).check(event)) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 188f47bd7d..0dc0faa0e5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import EventTypes, LabelsField, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.rest.client.v1 import login, profile, room from tests import unittest @@ -860,7 +860,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with right label", - LabelsField: ["#fun"], + EventContentFields.Labels: ["#fun"], }, ) @@ -876,7 +876,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with wrong label", - LabelsField: ["#work"], + EventContentFields.Labels: ["#work"], }, ) @@ -886,7 +886,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with two wrong labels", - LabelsField: ["#work", "#notfun"], + EventContentFields.Labels: ["#work", "#notfun"], }, ) @@ -896,7 +896,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with right label", - LabelsField: ["#fun"], + EventContentFields.Labels: ["#fun"], }, ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index c5c199d412..c3c6f75ced 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -17,7 +17,7 @@ import json from mock import Mock import synapse.rest.admin -from synapse.api.constants import EventTypes, LabelsField +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync @@ -157,7 +157,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with right label", - LabelsField: ["#fun"], + EventContentFields.Labels: ["#fun"], }, tok=tok, ) @@ -175,7 +175,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with wrong label", - LabelsField: ["#work"], + EventContentFields.Labels: ["#work"], }, tok=tok, ) @@ -186,7 +186,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with two wrong labels", - LabelsField: ["#work", "#notfun"], + EventContentFields.Labels: ["#work", "#notfun"], }, tok=tok, ) @@ -197,7 +197,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with right label", - LabelsField: ["#fun"], + EventContentFields.Labels: ["#fun"], }, tok=tok, ) -- cgit 1.4.1 From 57cdb046e48c6837fc9b41ade9b06d3ff68ec91b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 1 Nov 2019 10:39:14 +0000 Subject: Lint --- synapse/api/constants.py | 1 + synapse/storage/data_stores/main/events.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index cf4ce5f5a2..066ce18704 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -142,5 +142,6 @@ class LimitBlockingTypes(object): class EventContentFields(object): """Fields found in events' content, regardless of type.""" + # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 Labels = "org.matrix.labels" diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 42ffa9066a..0480161056 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -29,7 +29,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer import synapse.metrics -from synapse.api.constants import EventTypes, EventContentFields +from synapse.api.constants import EventContentFields, EventTypes from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 -- cgit 1.4.1 From e3689ac6f75681c75f9931f79f248269811704c5 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 1 Nov 2019 10:41:23 +0000 Subject: Add unstable feature flag --- synapse/rest/client/versions.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse') diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 1044ae7b4e..bb30ce3f34 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -65,6 +65,9 @@ class VersionsRestServlet(RestServlet): "m.require_identity_server": False, # as per MSC2290 "m.separate_add_and_bind": True, + # Implements support for label-based filtering as described in + # MSC2326. + "org.matrix.label_based_filtering": True, }, }, ) -- cgit 1.4.1 From a2c63c619ad1116cad07564c055aa6af8943d899 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 1 Nov 2019 11:47:28 +0000 Subject: Add more data to the event_labels table and fix the indexes --- synapse/storage/data_stores/main/events.py | 20 +++++++++++++++++--- .../main/schema/delta/56/event_labels.sql | 4 +++- synapse/storage/data_stores/main/stream.py | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 0480161056..577e79bcf9 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1493,7 +1493,9 @@ class EventsStore( # Store the labels for this event. labels = event.content.get(EventContentFields.Labels) if labels: - self.insert_labels_for_event_txn(txn, event.event_id, labels) + self.insert_labels_for_event_txn( + txn, event.event_id, labels, event.room_id, event.depth + ) # Insert into the room_memberships table. self._store_room_members_txn( @@ -2482,7 +2484,9 @@ class EventsStore( get_all_updated_current_state_deltas_txn, ) - def insert_labels_for_event_txn(self, txn, event_id, labels): + def insert_labels_for_event_txn( + self, txn, event_id, labels, room_id, topological_ordering + ): """Store the mapping between an event's ID and its labels, with one row per (event_id, label) tuple. @@ -2490,11 +2494,21 @@ class EventsStore( txn (LoggingTransaction): The transaction to execute. event_id (str): The event's ID. labels (list[str]): A list of text labels. + room_id (str): The ID of the room the event was sent to. + topological_ordering (int): The position of the event in the room's topology. """ return self._simple_insert_many_txn( txn=txn, table="event_labels", - values=[{"event_id": event_id, "label": label} for label in labels], + values=[ + { + "event_id": event_id, + "label": label, + "room_id": room_id, + "topological_ordering": topological_ordering, + } + for label in labels + ], ) diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql index 9550b0adaa..765124d131 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql @@ -16,7 +16,9 @@ CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, label TEXT, + room_id TEXT NOT NULL, + topological_ordering bigint NOT NULL, PRIMARY KEY(event_id, label) ); -CREATE INDEX event_labels_label_idx ON event_labels(label); +CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering); diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index cfa34ba1e7..616ef91d4e 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -874,7 +874,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): sql = ( "SELECT DISTINCT event_id, topological_ordering, stream_ordering" " FROM events" - " LEFT JOIN event_labels USING (event_id)" + " LEFT JOIN event_labels USING (event_id, room_id, topological_ordering)" " WHERE outlier = ? AND room_id = ? AND %(bounds)s" " ORDER BY topological_ordering %(order)s," " stream_ordering %(order)s LIMIT ?" -- cgit 1.4.1 From fe1f2b452073e5939cddd23acc6f2d226673a03f Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 1 Nov 2019 12:03:44 +0000 Subject: Remove last usages of deprecated logging.warn method (#6314) --- changelog.d/6314.misc | 1 + synapse/config/logger.py | 4 ++-- synapse/handlers/directory.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/6314.misc (limited to 'synapse') diff --git a/changelog.d/6314.misc b/changelog.d/6314.misc new file mode 100644 index 0000000000..2369760272 --- /dev/null +++ b/changelog.d/6314.misc @@ -0,0 +1 @@ +Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated. \ No newline at end of file diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 2d2c1e54df..75bb904718 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -234,8 +234,8 @@ def setup_logging( # make sure that the first thing we log is a thing we can grep backwards # for - logging.warn("***** STARTING SERVER *****") - logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) + logging.warning("***** STARTING SERVER *****") + logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) return logger diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 526379c6f7..c4632f8984 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -250,7 +250,7 @@ class DirectoryHandler(BaseHandler): ignore_backoff=True, ) except CodeMessageException as e: - logging.warn("Error retrieving alias") + logging.warning("Error retrieving alias") if e.code == 404: result = None else: -- cgit 1.4.1 From 1cb84c6486a5131dd284f341bb657434becda255 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 1 Nov 2019 14:07:44 +0000 Subject: Support for routing outbound HTTP requests via a proxy (#6239) The `http_proxy` and `HTTPS_PROXY` env vars can be set to a `host[:port]` value which should point to a proxy. The address of the proxy should be excluded from IP blacklists such as the `url_preview_ip_range_blacklist`. The proxy will then be used for * push * url previews * phone-home stats * recaptcha validation * CAS auth validation It will *not* be used for: * Application Services * Identity servers * Outbound federation * In worker configurations, connections from workers to masters Fixes #4198. --- changelog.d/6238.feature | 1 + synapse/app/homeserver.py | 2 +- synapse/handlers/ui_auth/checkers.py | 2 +- synapse/http/client.py | 17 +- synapse/http/connectproxyclient.py | 195 ++++++++++++ synapse/http/proxyagent.py | 195 ++++++++++++ synapse/push/httppusher.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 + synapse/server.py | 9 + synapse/server.pyi | 9 + tests/http/__init__.py | 17 ++ .../federation/test_matrix_federation_agent.py | 11 +- tests/http/test_proxyagent.py | 334 +++++++++++++++++++++ tests/push/test_http.py | 2 +- tests/server.py | 24 +- 16 files changed, 812 insertions(+), 12 deletions(-) create mode 100644 changelog.d/6238.feature create mode 100644 synapse/http/connectproxyclient.py create mode 100644 synapse/http/proxyagent.py create mode 100644 tests/http/test_proxyagent.py (limited to 'synapse') diff --git a/changelog.d/6238.feature b/changelog.d/6238.feature new file mode 100644 index 0000000000..d225ac33b6 --- /dev/null +++ b/changelog.d/6238.feature @@ -0,0 +1 @@ +Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8997c1f9e7..8d28076d92 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -565,7 +565,7 @@ def run(hs): "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats) ) try: - yield hs.get_simple_http_client().put_json( + yield hs.get_proxied_http_client().put_json( hs.config.report_stats_endpoint, stats ) except Exception as e: diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 29aa1e5aaf..8363d887a9 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs): super().__init__(hs) self._enabled = bool(hs.config.recaptcha_private_key) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() self._url = hs.config.recaptcha_siteverify_api self._secret = hs.config.recaptcha_private_key diff --git a/synapse/http/client.py b/synapse/http/client.py index 2df5b383b5..d4c285445e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,6 +45,7 @@ from synapse.http import ( cancelled_to_request_timed_out_error, redact_uri, ) +from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred @@ -183,7 +184,15 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): + def __init__( + self, + hs, + treq_args={}, + ip_whitelist=None, + ip_blacklist=None, + http_proxy=None, + https_proxy=None, + ): """ Args: hs (synapse.server.HomeServer) @@ -192,6 +201,8 @@ class SimpleHttpClient(object): we may not request. ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. + http_proxy (bytes): proxy server to use for http connections. host[:port] + https_proxy (bytes): proxy server to use for https connections. host[:port] """ self.hs = hs @@ -236,11 +247,13 @@ class SimpleHttpClient(object): # The default context factory in Twisted 14.0.0 (which we require) is # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' - self.agent = Agent( + self.agent = ProxyAgent( self.reactor, connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, + http_proxy=http_proxy, + https_proxy=https_proxy, ) if self._ip_blacklist: diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py new file mode 100644 index 0000000000..be7b2ceb8e --- /dev/null +++ b/synapse/http/connectproxyclient.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from zope.interface import implementer + +from twisted.internet import defer, protocol +from twisted.internet.error import ConnectError +from twisted.internet.interfaces import IStreamClientEndpoint +from twisted.internet.protocol import connectionDone +from twisted.web import http + +logger = logging.getLogger(__name__) + + +class ProxyConnectError(ConnectError): + pass + + +@implementer(IStreamClientEndpoint) +class HTTPConnectProxyEndpoint(object): + """An Endpoint implementation which will send a CONNECT request to an http proxy + + Wraps an existing HostnameEndpoint for the proxy. + + When we get the connect() request from the connection pool (via the TLS wrapper), + we'll first connect to the proxy endpoint with a ProtocolFactory which will make the + CONNECT request. Once that completes, we invoke the protocolFactory which was passed + in. + + Args: + reactor: the Twisted reactor to use for the connection + proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the + proxy + host (bytes): hostname that we want to CONNECT to + port (int): port that we want to connect to + """ + + def __init__(self, reactor, proxy_endpoint, host, port): + self._reactor = reactor + self._proxy_endpoint = proxy_endpoint + self._host = host + self._port = port + + def __repr__(self): + return "" % (self._proxy_endpoint,) + + def connect(self, protocolFactory): + f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + d = self._proxy_endpoint.connect(f) + # once the tcp socket connects successfully, we need to wait for the + # CONNECT to complete. + d.addCallback(lambda conn: f.on_connection) + return d + + +class HTTPProxiedClientFactory(protocol.ClientFactory): + """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect. + + Once the CONNECT completes, invokes the original ClientFactory to build the + HTTP Protocol object and run the rest of the connection. + + Args: + dst_host (bytes): hostname that we want to CONNECT to + dst_port (int): port that we want to connect to + wrapped_factory (protocol.ClientFactory): The original Factory + """ + + def __init__(self, dst_host, dst_port, wrapped_factory): + self.dst_host = dst_host + self.dst_port = dst_port + self.wrapped_factory = wrapped_factory + self.on_connection = defer.Deferred() + + def startedConnecting(self, connector): + return self.wrapped_factory.startedConnecting(connector) + + def buildProtocol(self, addr): + wrapped_protocol = self.wrapped_factory.buildProtocol(addr) + + return HTTPConnectProtocol( + self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + ) + + def clientConnectionFailed(self, connector, reason): + logger.debug("Connection to proxy failed: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector, reason): + logger.debug("Connection to proxy lost: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionLost(connector, reason) + + +class HTTPConnectProtocol(protocol.Protocol): + """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect + + Args: + host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + to put in the CONNECT request + + port (int): The original HTTP(s) port to put in the CONNECT request + + wrapped_protocol (interfaces.IProtocol): the original protocol (probably + HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + + connected_deferred (Deferred): a Deferred which will be callbacked with + wrapped_protocol when the CONNECT completes + """ + + def __init__(self, host, port, wrapped_protocol, connected_deferred): + self.host = host + self.port = port + self.wrapped_protocol = wrapped_protocol + self.connected_deferred = connected_deferred + self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.http_setup_client.on_connected.addCallback(self.proxyConnected) + + def connectionMade(self): + self.http_setup_client.makeConnection(self.transport) + + def connectionLost(self, reason=connectionDone): + if self.wrapped_protocol.connected: + self.wrapped_protocol.connectionLost(reason) + + self.http_setup_client.connectionLost(reason) + + if not self.connected_deferred.called: + self.connected_deferred.errback(reason) + + def proxyConnected(self, _): + self.wrapped_protocol.makeConnection(self.transport) + + self.connected_deferred.callback(self.wrapped_protocol) + + # Get any pending data from the http buf and forward it to the original protocol + buf = self.http_setup_client.clearLineBuffer() + if buf: + self.wrapped_protocol.dataReceived(buf) + + def dataReceived(self, data): + # if we've set up the HTTP protocol, we can send the data there + if self.wrapped_protocol.connected: + return self.wrapped_protocol.dataReceived(data) + + # otherwise, we must still be setting up the connection: send the data to the + # setup client + return self.http_setup_client.dataReceived(data) + + +class HTTPConnectSetupClient(http.HTTPClient): + """HTTPClient protocol to send a CONNECT message for proxies and read the response. + + Args: + host (bytes): The hostname to send in the CONNECT message + port (int): The port to send in the CONNECT message + """ + + def __init__(self, host, port): + self.host = host + self.port = port + self.on_connected = defer.Deferred() + + def connectionMade(self): + logger.debug("Connected to proxy, sending CONNECT") + self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + self.endHeaders() + + def handleStatus(self, version, status, message): + logger.debug("Got Status: %s %s %s", status, message, version) + if status != b"200": + raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) + + def handleEndHeaders(self): + logger.debug("End Headers") + self.on_connected.callback(None) + + def handleResponse(self, body): + pass diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py new file mode 100644 index 0000000000..332da02a8d --- /dev/null +++ b/synapse/http/proxyagent.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.python.failure import Failure +from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase +from twisted.web.error import SchemeNotSupported +from twisted.web.iweb import IAgent + +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint + +logger = logging.getLogger(__name__) + +_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") + + +@implementer(IAgent) +class ProxyAgent(_AgentBase): + """An Agent implementation which will use an HTTP proxy if one was requested + + Args: + reactor: twisted reactor to place outgoing + connections. + + contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the + verification parameters of OpenSSL. The default is to use a + `BrowserLikePolicyForHTTPS`, so unless you have special + requirements you can leave this as-is. + + connectTimeout (float): The amount of time that this Agent will wait + for the peer to accept a connection. + + bindAddress (bytes): The local address for client sockets to bind to. + + pool (HTTPConnectionPool|None): connection pool to be used. If None, a + non-persistent pool instance will be created. + """ + + def __init__( + self, + reactor, + contextFactory=BrowserLikePolicyForHTTPS(), + connectTimeout=None, + bindAddress=None, + pool=None, + http_proxy=None, + https_proxy=None, + ): + _AgentBase.__init__(self, reactor, pool) + + self._endpoint_kwargs = {} + if connectTimeout is not None: + self._endpoint_kwargs["timeout"] = connectTimeout + if bindAddress is not None: + self._endpoint_kwargs["bindAddress"] = bindAddress + + self.http_proxy_endpoint = _http_proxy_endpoint( + http_proxy, reactor, **self._endpoint_kwargs + ) + + self.https_proxy_endpoint = _http_proxy_endpoint( + https_proxy, reactor, **self._endpoint_kwargs + ) + + self._policy_for_https = contextFactory + self._reactor = reactor + + def request(self, method, uri, headers=None, bodyProducer=None): + """ + Issue a request to the server indicated by the given uri. + + Supports `http` and `https` schemes. + + An existing connection from the connection pool may be used or a new one may be + created. + + See also: twisted.web.iweb.IAgent.request + + Args: + method (bytes): The request method to use, such as `GET`, `POST`, etc + + uri (bytes): The location of the resource to request. + + headers (Headers|None): Extra headers to send with the request + + bodyProducer (IBodyProducer|None): An object which can generate bytes to + make up the body of this request (for example, the properly encoded + contents of a file for a file upload). Or, None if the request is to + have no body. + + Returns: + Deferred[IResponse]: completes when the header of the response has + been received (regardless of the response status code). + """ + uri = uri.strip() + if not _VALID_URI.match(uri): + raise ValueError("Invalid URI {!r}".format(uri)) + + parsed_uri = URI.fromBytes(uri) + pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) + request_path = parsed_uri.originForm + + if parsed_uri.scheme == b"http" and self.http_proxy_endpoint: + # Cache *all* connections under the same key, since we are only + # connecting to a single destination, the proxy: + pool_key = ("http-proxy", self.http_proxy_endpoint) + endpoint = self.http_proxy_endpoint + request_path = uri + elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: + endpoint = HTTPConnectProxyEndpoint( + self._reactor, + self.https_proxy_endpoint, + parsed_uri.host, + parsed_uri.port, + ) + else: + # not using a proxy + endpoint = HostnameEndpoint( + self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs + ) + + logger.debug("Requesting %s via %s", uri, endpoint) + + if parsed_uri.scheme == b"https": + tls_connection_creator = self._policy_for_https.creatorForNetloc( + parsed_uri.host, parsed_uri.port + ) + endpoint = wrapClientTLS(tls_connection_creator, endpoint) + elif parsed_uri.scheme == b"http": + pass + else: + return defer.fail( + Failure( + SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,)) + ) + ) + + return self._requestWithEndpoint( + pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path + ) + + +def _http_proxy_endpoint(proxy, reactor, **kwargs): + """Parses an http proxy setting and returns an endpoint for the proxy + + Args: + proxy (bytes|None): the proxy setting + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint + + Returns: + interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy, + or None + """ + if proxy is None: + return None + + # currently we only support hostname:port. Some apps also support + # protocol://[:port], which allows a way of requiring a TLS connection to the + # proxy. + + host, port = parse_host_port(proxy, default_port=1080) + return HostnameEndpoint(reactor, host, port, **kwargs) + + +def parse_host_port(hostport, default_port=None): + # could have sworn we had one of these somewhere else... + if b":" in hostport: + host, port = hostport.rsplit(b":", 1) + try: + port = int(port) + return host, port + except ValueError: + # the thing after the : wasn't a valid port; presumably this is an + # IPv6 address. + pass + + return hostport, default_port diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 7dde2ad055..e994037be6 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -103,7 +103,7 @@ class HttpPusher(object): if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") self.url = self.data["url"] - self.http_client = hs.get_simple_http_client() + self.http_client = hs.get_proxied_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 00a7dd6d09..24a0ce74f2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -381,7 +381,7 @@ class CasTicketServlet(RestServlet): self.cas_displayname_attribute = hs.config.cas_displayname_attribute self.cas_required_attributes = hs.config.cas_required_attributes self._sso_auth_handler = SSOAuthHandler(hs) - self._http_client = hs.get_simple_http_client() + self._http_client = hs.get_proxied_http_client() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 5a25b6b3fc..531d923f76 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource): treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_blacklist=hs.config.url_preview_ip_range_blacklist, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path diff --git a/synapse/server.py b/synapse/server.py index 0b81af646c..f8aeebcff8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -23,6 +23,7 @@ # Imports required for the default HomeServer() implementation import abc import logging +import os from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail @@ -168,6 +169,7 @@ class HomeServer(object): "filtering", "http_client_context_factory", "simple_http_client", + "proxied_http_client", "media_repository", "media_repository_resource", "federation_transport_client", @@ -311,6 +313,13 @@ class HomeServer(object): def build_simple_http_client(self): return SimpleHttpClient(self) + def build_proxied_http_client(self): + return SimpleHttpClient( + self, + http_proxy=os.getenv("http_proxy"), + https_proxy=os.getenv("HTTPS_PROXY"), + ) + def build_room_creation_handler(self): return RoomCreationHandler(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index 83d1f11283..b5e0b57095 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -12,6 +12,7 @@ import synapse.handlers.message import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password +import synapse.http.client import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -38,6 +39,14 @@ class HomeServer(object): pass def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass + def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + or support any HTTP_PROXY settings""" + pass + def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + but does support HTTP_PROXY settings""" + pass def get_deactivate_account_handler( self, ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 2d5dba6464..2096ba3c91 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -20,6 +20,23 @@ from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import Connection from twisted.internet.interfaces import IOpenSSLServerConnectionCreator +from twisted.internet.ssl import Certificate, trustRootFromCertificates +from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 +from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 + + +def get_test_https_policy(): + """Get a test IPolicyForHTTPS which trusts the test CA cert + + Returns: + IPolicyForHTTPS + """ + ca_file = get_test_ca_cert_file() + with open(ca_file) as stream: + content = stream.read() + cert = Certificate.loadPEM(content) + trust_root = trustRootFromCertificates([cert]) + return BrowserLikePolicyForHTTPS(trustRoot=trust_root) def get_test_ca_cert_file(): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 71d7025264..cfcd98ff7d 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase): FakeTransport(client_protocol, self.reactor, server_tls_protocol) ) + # grab a hold of the TLS connection, in case it gets torn down + server_tls_connection = server_tls_protocol._tlsConnection + + # fish the test server back out of the server-side TLS protocol. + http_protocol = server_tls_protocol.wrappedProtocol + # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) # check the SNI - server_name = server_tls_protocol._tlsConnection.get_servername() + server_name = server_tls_connection.get_servername() self.assertEqual( server_name, expected_sni, "Expected SNI %s but got %s" % (expected_sni, server_name), ) - # fish the test server back out of the server-side TLS protocol. - return server_tls_protocol.wrappedProtocol + return http_protocol @defer.inlineCallbacks def _make_get_request(self, uri): diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py new file mode 100644 index 0000000000..22abf76515 --- /dev/null +++ b/tests/http/test_proxyagent.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import treq + +from twisted.internet import interfaces # noqa: F401 +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel + +from synapse.http.proxyagent import ProxyAgent + +from tests.http import TestServerTLSConnectionFactory, get_test_https_policy +from tests.server import FakeTransport, ThreadedMemoryReactorClock +from tests.unittest import TestCase + +logger = logging.getLogger(__name__) + +HTTPFactory = Factory.forProtocol(HTTPChannel) + + +class MatrixFederationAgentTests(TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + def _make_connection( + self, client_factory, server_factory, ssl=False, expected_sni=None + ): + """Builds a test server, and completes the outgoing client connection + + Args: + client_factory (interfaces.IProtocolFactory): the the factory that the + application is trying to use to make the outbound connection. We will + invoke it to build the client Protocol + + server_factory (interfaces.IProtocolFactory): a factory to build the + server-side protocol + + ssl (bool): If true, we will expect an ssl connection and wrap + server_factory with a TLSMemoryBIOFactory + + expected_sni (bytes|None): the expected SNI value + + Returns: + IProtocol: the server Protocol returned by server_factory + """ + if ssl: + server_factory = _wrap_server_factory_for_tls(server_factory) + + server_protocol = server_factory.buildProtocol(None) + + # now, tell the client protocol factory to build the client protocol, + # and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_protocol, self.reactor, client_protocol) + ) + + # tell the server protocol to send its stuff back to the client, too + server_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_protocol) + ) + + if ssl: + http_protocol = server_protocol.wrappedProtocol + tls_connection = server_protocol._tlsConnection + else: + http_protocol = server_protocol + tls_connection = None + + # give the reactor a pump to get the TLS juices flowing (if needed) + self.reactor.advance(0) + + if expected_sni is not None: + server_name = tls_connection.get_servername() + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + return http_protocol + + def test_http_request(self): + agent = ProxyAgent(self.reactor) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 80) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request(self): + agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + _get_test_protocol_factory(), + ssl=True, + expected_sni=b"test.com", + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_http_request_via_proxy(self): + agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888") + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 8888) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"http://test.com") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request_via_proxy(self): + agent = ProxyAgent( + self.reactor, + contextFactory=get_test_https_policy(), + https_proxy=b"proxy.com", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 1080) + + # make a test HTTP server, and wire up the client + proxy_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # fish the transports back out so that we can do the old switcheroo + s2c_transport = proxy_server.transport + client_protocol = s2c_transport.other + c2s_transport = client_protocol.transport + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"test.com:443") + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + # this just stops the http Request trying to do a chunked response + # request.setHeader(b"Content-Length", b"0") + request.finish() + + # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel + ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) + ssl_protocol = ssl_factory.buildProtocol(None) + http_server = ssl_protocol.wrappedProtocol + + ssl_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, ssl_protocol) + ) + c2s_transport.other = ssl_protocol + + self.reactor.advance(0) + + server_name = ssl_protocol._tlsConnection.get_servername() + expected_sni = b"test.com" + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + +def _wrap_server_factory_for_tls(factory, sanlist=None): + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory + + The resultant factory will create a TLS server which presents a certificate + signed by our test CA, valid for the domains in `sanlist` + + Args: + factory (interfaces.IProtocolFactory): protocol factory to wrap + sanlist (iterable[bytes]): list of domains the cert should be valid for + + Returns: + interfaces.IProtocolFactory + """ + if sanlist is None: + sanlist = [b"DNS:test.com"] + + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + + +def _get_test_protocol_factory(): + """Get a protocol Factory which will build an HTTPChannel + + Returns: + interfaces.IProtocolFactory + """ + server_factory = Factory.forProtocol(HTTPChannel) + + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + return server_factory + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 8ce6bb62da..af2327fb66 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase): config = self.default_config() config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, simple_http_client=m) + hs = self.setup_test_homeserver(config=config, proxied_http_client=m) return hs diff --git a/tests/server.py b/tests/server.py index 469efb4edb..f878aeaada 100644 --- a/tests/server.py +++ b/tests/server.py @@ -395,11 +395,24 @@ class FakeTransport(object): self.disconnecting = True if self._protocol: self._protocol.connectionLost(reason) - self.disconnected = True + + # if we still have data to write, delay until that is done + if self.buffer: + logger.info( + "FakeTransport: Delaying disconnect until buffer is flushed" + ) + else: + self.disconnected = True def abortConnection(self): logger.info("FakeTransport: abortConnection()") - self.loseConnection() + + if not self.disconnecting: + self.disconnecting = True + if self._protocol: + self._protocol.connectionLost(None) + + self.disconnected = True def pauseProducing(self): if not self.producer: @@ -430,6 +443,9 @@ class FakeTransport(object): self._reactor.callLater(0.0, _produce) def write(self, byt): + if self.disconnecting: + raise Exception("Writing to disconnecting FakeTransport") + self.buffer = self.buffer + byt # always actually do the write asynchronously. Some protocols (notably the @@ -474,6 +490,10 @@ class FakeTransport(object): if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) + if not self.buffer and self.disconnecting: + logger.info("FakeTransport: Buffer now empty, completing disconnect") + self.disconnected = True + def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: """ -- cgit 1.4.1 From 559844565515334d1168746f96fcc0db4fce3f6e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 1 Nov 2019 16:18:34 +0000 Subject: Update synapse/storage/data_stores/main/schema/delta/56/event_labels.sql Co-Authored-By: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/storage/data_stores/main/schema/delta/56/event_labels.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql index 765124d131..2acd8e1be5 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql @@ -17,7 +17,7 @@ CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, label TEXT, room_id TEXT NOT NULL, - topological_ordering bigint NOT NULL, + topological_ordering BIGINT NOT NULL, PRIMARY KEY(event_id, label) ); -- cgit 1.4.1 From c6516adbe03a0acdd614ba6eb9d6f447dd4259e9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 1 Nov 2019 16:19:09 +0000 Subject: Factor out an _AsyncEventContextImpl (#6298) The intention here is to make it clearer which fields we can expect to be populated when: notably, that the _event_type etc aren't used for the synchronous impl of EventContext. --- changelog.d/6298.misc | 1 + synapse/events/snapshot.py | 107 ++++++++++++++++------------------------- synapse/handlers/federation.py | 38 +++++++-------- tests/test_federation.py | 4 +- 4 files changed, 65 insertions(+), 85 deletions(-) create mode 100644 changelog.d/6298.misc (limited to 'synapse') diff --git a/changelog.d/6298.misc b/changelog.d/6298.misc new file mode 100644 index 0000000000..d4190730b2 --- /dev/null +++ b/changelog.d/6298.misc @@ -0,0 +1 @@ +Refactor EventContext for clarity. \ No newline at end of file diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 27cd8a63ff..a269de5482 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -37,9 +37,6 @@ class EventContext: delta_ids (dict[(str, str), str]): Delta from ``prev_group``. (type, state_key) -> event_id. ``None`` for an outlier. - prev_state_events (?): XXX: is this ever set to anything other than - the empty list? - app_service: FIXME _current_state_ids (dict[(str, str), str]|None): @@ -51,36 +48,16 @@ class EventContext: The current state map excluding the current event. None if outlier or we haven't fetched the state from DB yet. (type, state_key) -> event_id - - _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have - been calculated. None if we haven't started calculating yet - - _event_type (str): The type of the event the context is associated with. - Only set when state has not been fetched yet. - - _event_state_key (str|None): The state_key of the event the context is - associated with. Only set when state has not been fetched yet. - - _prev_state_id (str|None): If the event associated with the context is - a state event, then `_prev_state_id` is the event_id of the state - that was replaced. - Only set when state has not been fetched yet. """ state_group = attr.ib(default=None) rejected = attr.ib(default=False) prev_group = attr.ib(default=None) delta_ids = attr.ib(default=None) - prev_state_events = attr.ib(default=attr.Factory(list)) app_service = attr.ib(default=None) - _current_state_ids = attr.ib(default=None) _prev_state_ids = attr.ib(default=None) - _prev_state_id = attr.ib(default=None) - - _event_type = attr.ib(default=None) - _event_state_key = attr.ib(default=None) - _fetching_state_deferred = attr.ib(default=None) + _current_state_ids = attr.ib(default=None) @staticmethod def with_state( @@ -90,7 +67,6 @@ class EventContext: current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, state_group=state_group, - fetching_state_deferred=defer.succeed(None), prev_group=prev_group, delta_ids=delta_ids, ) @@ -125,7 +101,6 @@ class EventContext: "rejected": self.rejected, "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, "app_service_id": self.app_service.id if self.app_service else None, } @@ -141,7 +116,7 @@ class EventContext: Returns: EventContext """ - context = EventContext( + context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. prev_state_id=input["prev_state_id"], @@ -151,7 +126,6 @@ class EventContext: prev_group=input["prev_group"], delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], - prev_state_events=input["prev_state_events"], ) app_service_id = input["app_service_id"] @@ -170,14 +144,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._current_state_ids @defer.inlineCallbacks @@ -190,14 +157,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched(store) return self._prev_state_ids def get_cached_current_state_ids(self): @@ -211,6 +171,44 @@ class EventContext: return self._current_state_ids + def _ensure_fetched(self, store): + return defer.succeed(None) + + +@attr.s(slots=True) +class _AsyncEventContextImpl(EventContext): + """ + An implementation of EventContext which fetches _current_state_ids and + _prev_state_ids from the database on demand. + + Attributes: + + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. None if we haven't started calculating yet + + _event_type (str): The type of the event the context is associated with. + + _event_state_key (str): The state_key of the event the context is + associated with. + + _prev_state_id (str|None): If the event associated with the context is + a state event, then `_prev_state_id` is the event_id of the state + that was replaced. + """ + + _prev_state_id = attr.ib(default=None) + _event_type = attr.ib(default=None) + _event_state_key = attr.ib(default=None) + _fetching_state_deferred = attr.ib(default=None) + + def _ensure_fetched(self, store): + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background( + self._fill_out_state, store + ) + + return make_deferred_yieldable(self._fetching_state_deferred) + @defer.inlineCallbacks def _fill_out_state(self, store): """Called to populate the _current_state_ids and _prev_state_ids @@ -228,27 +226,6 @@ class EventContext: else: self._prev_state_ids = self._current_state_ids - @defer.inlineCallbacks - def update_state( - self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids - ): - """Replace the state in the context - """ - - # We need to make sure we wait for any ongoing fetching of state - # to complete so that the updated state doesn't get clobbered - if self._fetching_state_deferred: - yield make_deferred_yieldable(self._fetching_state_deferred) - - self.state_group = state_group - self._prev_state_ids = prev_state_ids - self.prev_group = prev_group - self._current_state_ids = current_state_ids - self.delta_ids = delta_ids - - # We need to ensure that that we've marked as having fetched the state - self._fetching_state_deferred = defer.succeed(None) - def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index dab6be9573..8cafcfdab0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( make_deferred_yieldable, @@ -1871,14 +1872,7 @@ class FederationHandler(BaseHandler): if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c - try: - yield self.do_auth(origin, event, context, auth_events=auth_events) - except AuthError as e: - logger.warning( - "[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg - ) - - context.rejected = RejectedReason.AUTH_ERROR + context = yield self.do_auth(origin, event, context, auth_events=auth_events) if not context.rejected: yield self._check_for_soft_fail(event, state, backfilled) @@ -2047,12 +2041,12 @@ class FederationHandler(BaseHandler): Also NB that this function adds entries to it. Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context object """ room_version = yield self.store.get_room_version(event.room_id) try: - yield self._update_auth_events_and_context_for_auth( + context = yield self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2070,7 +2064,9 @@ class FederationHandler(BaseHandler): event_auth.check(room_version, event, auth_events=auth_events) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) - raise e + context.rejected = RejectedReason.AUTH_ERROR + + return context @defer.inlineCallbacks def _update_auth_events_and_context_for_auth( @@ -2094,7 +2090,7 @@ class FederationHandler(BaseHandler): auth_events (dict[(str, str)->synapse.events.EventBase]): Returns: - defer.Deferred[None] + defer.Deferred[EventContext]: updated context """ event_auth_events = set(event.auth_event_ids()) @@ -2133,7 +2129,7 @@ class FederationHandler(BaseHandler): # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e) - return + return context seen_remotes = yield self.store.have_seen_events( [e.event_id for e in remote_auth_chain] @@ -2174,7 +2170,7 @@ class FederationHandler(BaseHandler): if event.internal_metadata.is_outlier(): logger.info("Skipping auth_event fetch for outlier") - return + return context # FIXME: Assumes we have and stored all the state for all the # prev_events @@ -2183,7 +2179,7 @@ class FederationHandler(BaseHandler): ) if not different_auth: - return + return context logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2230,10 +2226,12 @@ class FederationHandler(BaseHandler): auth_events.update(new_state) - yield self._update_context_for_auth_events( + context = yield self._update_context_for_auth_events( event, context, auth_events, event_key ) + return context + @defer.inlineCallbacks def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, @@ -2242,14 +2240,16 @@ class FederationHandler(BaseHandler): Args: event (Event): The event we're handling the context for - context (synapse.events.snapshot.EventContext): event context - to be updated + context (synapse.events.snapshot.EventContext): initial event context auth_events (dict[(str, str)->str]): Events to update in the event context. event_key ((str, str)): (type, state_key) for the current event. this will not be included in the current_state in the context. + + Returns: + Deferred[EventContext]: new event context """ state_updates = { k: a.event_id for k, a in iteritems(auth_events) if k != event_key @@ -2274,7 +2274,7 @@ class FederationHandler(BaseHandler): current_state_ids=current_state_ids, ) - yield context.update_state( + return EventContext.with_state( state_group=state_group, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, diff --git a/tests/test_federation.py b/tests/test_federation.py index d1acb16f30..7d82b58466 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase): ) self.handler = self.homeserver.get_handlers().federation_handler - self.handler.do_auth = lambda *a, **b: succeed(True) + self.handler.do_auth = lambda origin, event, context, auth_events: succeed( + context + ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( pdus -- cgit 1.4.1 From 988d8d6507a0e8b34f2c352c77b5742197762190 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 1 Nov 2019 16:22:44 +0000 Subject: Incorporate review --- synapse/api/constants.py | 2 +- synapse/api/filtering.py | 2 +- synapse/storage/data_stores/main/events.py | 2 +- synapse/storage/data_stores/main/schema/delta/56/event_labels.sql | 6 ++++++ tests/api/test_filtering.py | 8 ++++---- tests/rest/client/v1/test_rooms.py | 8 ++++---- tests/rest/client/v2_alpha/test_sync.py | 8 ++++---- 7 files changed, 21 insertions(+), 15 deletions(-) (limited to 'synapse') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 066ce18704..49c4b85054 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -144,4 +144,4 @@ class EventContentFields(object): """Fields found in events' content, regardless of type.""" # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 - Labels = "org.matrix.labels" + LABELS = "org.matrix.labels" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 30a7ee0a7a..bec13f08d8 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -309,7 +309,7 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) - labels = content.get(EventContentFields.Labels, []) + labels = content.get(EventContentFields.LABELS, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 577e79bcf9..1045c7fa2e 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1491,7 +1491,7 @@ class EventsStore( self._handle_event_relations(txn, event) # Store the labels for this event. - labels = event.content.get(EventContentFields.Labels) + labels = event.content.get(EventContentFields.LABELS) if labels: self.insert_labels_for_event_txn( txn, event.event_id, labels, event.room_id, event.depth diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql index 2acd8e1be5..5e29c1da19 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql @@ -13,6 +13,8 @@ * limitations under the License. */ +-- room_id and topoligical_ordering are denormalised from the events table in order to +-- make the index work. CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, label TEXT, @@ -21,4 +23,8 @@ CREATE TABLE IF NOT EXISTS event_labels ( PRIMARY KEY(event_id, label) ); + +-- This index enables an event pagination looking for a particular label to index the +-- event_labels table first, which is much quicker than scanning the events table and then +-- filtering by label, if the label is rarely used relative to the size of the room. CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering); diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 8ec48c4154..2dc5052249 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -329,7 +329,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={EventContentFields.Labels: ["#fun"]}, + content={EventContentFields.LABELS: ["#fun"]}, ) self.assertTrue(Filter(definition).check(event)) @@ -338,7 +338,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={EventContentFields.Labels: ["#notfun"]}, + content={EventContentFields.LABELS: ["#notfun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -349,7 +349,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={EventContentFields.Labels: ["#fun"]}, + content={EventContentFields.LABELS: ["#fun"]}, ) self.assertFalse(Filter(definition).check(event)) @@ -358,7 +358,7 @@ class FilteringTestCase(unittest.TestCase): sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown", - content={EventContentFields.Labels: ["#notfun"]}, + content={EventContentFields.LABELS: ["#notfun"]}, ) self.assertTrue(Filter(definition).check(event)) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 0dc0faa0e5..5e38fd6ced 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -860,7 +860,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with right label", - EventContentFields.Labels: ["#fun"], + EventContentFields.LABELS: ["#fun"], }, ) @@ -876,7 +876,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with wrong label", - EventContentFields.Labels: ["#work"], + EventContentFields.LABELS: ["#work"], }, ) @@ -886,7 +886,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with two wrong labels", - EventContentFields.Labels: ["#work", "#notfun"], + EventContentFields.LABELS: ["#work", "#notfun"], }, ) @@ -896,7 +896,7 @@ class RoomMessageListTestCase(RoomBase): content={ "msgtype": "m.text", "body": "with right label", - EventContentFields.Labels: ["#fun"], + EventContentFields.LABELS: ["#fun"], }, ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index c3c6f75ced..3283c0e47b 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -157,7 +157,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with right label", - EventContentFields.Labels: ["#fun"], + EventContentFields.LABELS: ["#fun"], }, tok=tok, ) @@ -175,7 +175,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with wrong label", - EventContentFields.Labels: ["#work"], + EventContentFields.LABELS: ["#work"], }, tok=tok, ) @@ -186,7 +186,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with two wrong labels", - EventContentFields.Labels: ["#work", "#notfun"], + EventContentFields.LABELS: ["#work", "#notfun"], }, tok=tok, ) @@ -197,7 +197,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): content={ "msgtype": "m.text", "body": "with right label", - EventContentFields.Labels: ["#fun"], + EventContentFields.LABELS: ["#fun"], }, tok=tok, ) -- cgit 1.4.1 From cc6243b4c08bfae77c9ff29d23c40568ab284924 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 4 Nov 2019 12:40:18 +0000 Subject: document the REPLICATE command a bit better (#6305) since I found myself wonder how it works --- changelog.d/6305.misc | 1 + docs/tcp_replication.md | 15 +++++- synapse/replication/slave/storage/_base.py | 10 +++- synapse/replication/tcp/client.py | 20 +++++--- synapse/replication/tcp/protocol.py | 74 +++++++++++++++++++++++++++++- 5 files changed, 110 insertions(+), 10 deletions(-) create mode 100644 changelog.d/6305.misc (limited to 'synapse') diff --git a/changelog.d/6305.misc b/changelog.d/6305.misc new file mode 100644 index 0000000000..f047fc3062 --- /dev/null +++ b/changelog.d/6305.misc @@ -0,0 +1 @@ +Add some documentation about worker replication. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index e099d8a87b..ba9e874d07 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -199,7 +199,20 @@ client (C): #### REPLICATE (C) - Asks the server to replicate a given stream +Asks the server to replicate a given stream. The syntax is: + +``` + REPLICATE +``` + +Where `` may be either: + * a numeric stream_id to stream updates since (exclusive) + * `NOW` to stream all subsequent updates. + +The `` is the name of a replication stream to subscribe +to (see [here](../synapse/replication/tcp/streams/_base.py) for a list +of streams). It can also be `ALL` to subscribe to all known streams, +in which case the `` must be set to `NOW`. #### USER_SYNC (C) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 182cb2a1d8..456bc005a0 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict import six @@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore): self.hs = hs - def stream_positions(self): + def stream_positions(self) -> Dict[str, int]: + """ + Get the current positions of all the streams this store wants to subscribe to + + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) + """ pos = {} if self._cache_id_gen: pos["caches"] = self._cache_id_gen.get_current_token() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 563ce0fc53..fead78388c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,10 +16,17 @@ """ import logging +from typing import Dict from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.tcp.protocol import ( + AbstractReplicationClientHandler, + ClientReplicationStreamProtocol, +) + from .commands import ( FederationAckCommand, InvalidateCacheCommand, @@ -27,7 +34,6 @@ from .commands import ( UserIpCommand, UserSyncCommand, ) -from .protocol import ClientReplicationStreamProtocol logger = logging.getLogger(__name__) @@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): maxDelay = 30 # Try at least once every N seconds - def __init__(self, hs, client_name, handler): + def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name @@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(object): +class ReplicationClientHandler(AbstractReplicationClientHandler): """A base handler that can be passed to the ReplicationClientFactory. By default proxies incoming replication data to the SlaveStore. """ - def __init__(self, store): + def __init__(self, store: BaseSlavedStore): self.store = store # The current connection. None if we are currently (re)connecting @@ -138,11 +144,13 @@ class ReplicationClientHandler(object): if d: d.callback(data) - def get_streams_to_replicate(self): + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. - Returns a dictionary of stream name to token. + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) """ args = self.store.stream_positions() user_account_data = args.pop("user_account_data", None) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index b64f3f44b5..afaf002fe6 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ - +import abc import fcntl import logging import struct @@ -65,6 +65,7 @@ from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock from synapse.util.stringutils import random_string from .commands import ( @@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer.lost_connection(self) +class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): + """ + The interface for the handler that should be passed to + ClientReplicationStreamProtocol + """ + + @abc.abstractmethod + def on_rdata(self, stream_name, token, rows): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name (str): name of the replication stream for this batch of rows + token (int): stream token for this batch of rows + rows (list): a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + + Returns: + Deferred|None + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_position(self, stream_name, token): + """Called when we get new position data.""" + raise NotImplementedError() + + @abc.abstractmethod + def on_sync(self, data): + """Called when get a new SYNC command.""" + raise NotImplementedError() + + @abc.abstractmethod + def get_streams_to_replicate(self): + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users.""" + raise NotImplementedError() + + @abc.abstractmethod + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + raise NotImplementedError() + + @abc.abstractmethod + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + raise NotImplementedError() + + class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS - def __init__(self, client_name, server_name, clock, handler): + def __init__( + self, + client_name: str, + server_name: str, + clock: Clock, + handler: AbstractReplicationClientHandler, + ): BaseReplicationStreamProtocol.__init__(self, clock) self.client_name = client_name -- cgit 1.4.1 From 4e1c7b79fa3498c48106c17d0edbab2f7bcc0c38 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 5 Nov 2019 05:05:48 +1100 Subject: Remove the psutil dependency (#6318) * remove psutil and replace with resource --- changelog.d/6318.misc | 1 + synapse/app/homeserver.py | 174 ++++++++++++++++++++++------------------- synapse/python_dependencies.py | 1 - synapse/server.py | 2 + tests/test_phone_home.py | 51 ++++++++++++ 5 files changed, 146 insertions(+), 83 deletions(-) create mode 100644 changelog.d/6318.misc create mode 100644 tests/test_phone_home.py (limited to 'synapse') diff --git a/changelog.d/6318.misc b/changelog.d/6318.misc new file mode 100644 index 0000000000..63527ccef4 --- /dev/null +++ b/changelog.d/6318.misc @@ -0,0 +1 @@ +Remove the dependency on psutil and replace functionality with the stdlib `resource` module. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8d28076d92..00a7f8330e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -19,12 +19,13 @@ from __future__ import print_function import gc import logging +import math import os +import resource import sys from six import iteritems -import psutil from prometheus_client import Gauge from twisted.application import service @@ -471,6 +472,87 @@ class SynapseService(service.Service): return self._port.stopListening() +# Contains the list of processes we will be monitoring +# currently either 0 or 1 +_stats_process = [] + + +@defer.inlineCallbacks +def phone_stats_home(hs, stats, stats_process=_stats_process): + logger.info("Gathering stats for reporting") + now = int(hs.get_clock().time()) + uptime = int(now - hs.start_time) + if uptime < 0: + uptime = 0 + + stats["homeserver"] = hs.config.server_name + stats["server_context"] = hs.config.server_context + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) + stats["total_users"] = yield hs.get_datastore().count_all_users() + + total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() + stats["total_nonbridged_users"] = total_nonbridged_users + + daily_user_type_results = yield hs.get_datastore().count_daily_user_type() + for name, count in iteritems(daily_user_type_results): + stats["daily_user_type_" + name] = count + + room_count = yield hs.get_datastore().get_room_count() + stats["total_room_count"] = room_count + + stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + + r30_results = yield hs.get_datastore().count_r30_users() + for name, count in iteritems(r30_results): + stats["r30_users_" + name] = count + + daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = CACHE_SIZE_FACTOR + stats["event_cache_size"] = hs.config.event_cache_size + + # + # Performance statistics + # + old = stats_process[0] + new = (now, resource.getrusage(resource.RUSAGE_SELF)) + stats_process[0] = new + + # Get RSS in bytes + stats["memory_rss"] = new[1].ru_maxrss + + # Get CPU time in % of a single core, not % of all cores + used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( + old[1].ru_utime + old[1].ru_stime + ) + if used_cpu_time == 0 or new[0] == old[0]: + stats["cpu_average"] = 0 + else: + stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + + # + # Database version + # + + stats["database_engine"] = hs.get_datastore().database_engine_name + stats["database_server_version"] = hs.get_datastore().get_server_version() + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) + try: + yield hs.get_proxied_http_client().put_json( + hs.config.report_stats_endpoint, stats + ) + except Exception as e: + logger.warning("Error reporting stats: %s", e) + + def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: @@ -497,91 +579,19 @@ def run(hs): reactor.run = profile(reactor.run) clock = hs.get_clock() - start_time = clock.time() stats = {} - # Contains the list of processes we will be monitoring - # currently either 0 or 1 - stats_process = [] + def performance_stats_init(): + _stats_process.clear() + _stats_process.append( + (int(hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))) + ) def start_phone_stats_home(): - return run_as_background_process("phone_stats_home", phone_stats_home) - - @defer.inlineCallbacks - def phone_stats_home(): - logger.info("Gathering stats for reporting") - now = int(hs.get_clock().time()) - uptime = int(now - start_time) - if uptime < 0: - uptime = 0 - - stats["homeserver"] = hs.config.server_name - stats["server_context"] = hs.config.server_context - stats["timestamp"] = now - stats["uptime_seconds"] = uptime - version = sys.version_info - stats["python_version"] = "{}.{}.{}".format( - version.major, version.minor, version.micro - ) - stats["total_users"] = yield hs.get_datastore().count_all_users() - - total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() - stats["total_nonbridged_users"] = total_nonbridged_users - - daily_user_type_results = yield hs.get_datastore().count_daily_user_type() - for name, count in iteritems(daily_user_type_results): - stats["daily_user_type_" + name] = count - - room_count = yield hs.get_datastore().get_room_count() - stats["total_room_count"] = room_count - - stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() - stats[ - "daily_active_rooms" - ] = yield hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() - - r30_results = yield hs.get_datastore().count_r30_users() - for name, count in iteritems(r30_results): - stats["r30_users_" + name] = count - - daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = CACHE_SIZE_FACTOR - stats["event_cache_size"] = hs.config.event_cache_size - - if len(stats_process) > 0: - stats["memory_rss"] = 0 - stats["cpu_average"] = 0 - for process in stats_process: - stats["memory_rss"] += process.memory_info().rss - stats["cpu_average"] += int(process.cpu_percent(interval=None)) - - stats["database_engine"] = hs.get_datastore().database_engine_name - stats["database_server_version"] = hs.get_datastore().get_server_version() - logger.info( - "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats) + return run_as_background_process( + "phone_stats_home", phone_stats_home, hs, stats ) - try: - yield hs.get_proxied_http_client().put_json( - hs.config.report_stats_endpoint, stats - ) - except Exception as e: - logger.warning("Error reporting stats: %s", e) - - def performance_stats_init(): - try: - process = psutil.Process() - # Ensure we can fetch both, and make the initial request for cpu_percent - # so the next request will use this as the initial point. - process.memory_info().rss - process.cpu_percent(interval=None) - logger.info("report_stats can use psutil") - stats_process.append(process) - except (AttributeError): - logger.warning("Unable to read memory/cpu stats. Disabling reporting.") def generate_user_daily_visit_stats(): return run_as_background_process( @@ -626,7 +636,7 @@ def run(hs): if hs.config.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") - clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) + clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000, hs, stats) # We need to defer this init for the cases that we daemonize # otherwise the process ID we get is that of the non-daemon process @@ -634,7 +644,7 @@ def run(hs): # We wait 5 minutes to send the first set of stats as the server can # be quite busy the first few minutes - clock.call_later(5 * 60, start_phone_stats_home) + clock.call_later(5 * 60, start_phone_stats_home, hs, stats) _base.start_reactor( "synapse-homeserver", diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index aa7da1c543..5871feaafd 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -61,7 +61,6 @@ REQUIREMENTS = [ "bcrypt>=3.1.0", "pillow>=4.3.0", "sortedcontainers>=1.4.4", - "psutil>=2.0.0", "pymacaroons>=0.13.0", "msgpack>=0.5.2", "phonenumbers>=8.2.0", diff --git a/synapse/server.py b/synapse/server.py index f8aeebcff8..90c3b072e8 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -221,6 +221,7 @@ class HomeServer(object): self.hostname = hostname self._building = {} self._listening_services = [] + self.start_time = None self.clock = Clock(reactor) self.distributor = Distributor() @@ -240,6 +241,7 @@ class HomeServer(object): datastore = self.DATASTORE_CLASS(conn, self) self.datastores = DataStores(datastore, conn, self) conn.commit() + self.start_time = int(self.get_clock().time()) logger.info("Finished setting up.") def setup_master(self): diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py new file mode 100644 index 0000000000..7657bddea5 --- /dev/null +++ b/tests/test_phone_home.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import resource + +import mock + +from synapse.app.homeserver import phone_stats_home + +from tests.unittest import HomeserverTestCase + + +class PhoneHomeStatsTestCase(HomeserverTestCase): + def test_performance_frozen_clock(self): + """ + If time doesn't move, don't error out. + """ + past_stats = [ + (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF)) + ] + stats = {} + self.get_success(phone_stats_home(self.hs, stats, past_stats)) + self.assertEqual(stats["cpu_average"], 0) + + def test_performance_100(self): + """ + 1 second of usage over 1 second is 100% CPU usage. + """ + real_res = resource.getrusage(resource.RUSAGE_SELF) + old_resource = mock.Mock(spec=real_res) + old_resource.ru_utime = real_res.ru_utime - 1 + old_resource.ru_stime = real_res.ru_stime + old_resource.ru_maxrss = real_res.ru_maxrss + + past_stats = [(self.hs.get_clock().time(), old_resource)] + stats = {} + self.reactor.advance(1) + self.get_success(phone_stats_home(self.hs, stats, past_stats)) + self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5) -- cgit 1.4.1 From 408600282774391fade9e4a6606f4967184865c0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 5 Nov 2019 13:23:25 +0000 Subject: Improve documentation for EventContext fields (#6319) --- changelog.d/6319.misc | 1 + synapse/events/snapshot.py | 91 +++++++++++++++++++++++++++++++++------------- synapse/state/__init__.py | 3 ++ 3 files changed, 69 insertions(+), 26 deletions(-) create mode 100644 changelog.d/6319.misc (limited to 'synapse') diff --git a/changelog.d/6319.misc b/changelog.d/6319.misc new file mode 100644 index 0000000000..9711ef21ed --- /dev/null +++ b/changelog.d/6319.misc @@ -0,0 +1 @@ +Improve documentation for EventContext fields. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index a269de5482..5f07f6fe4b 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Optional, Tuple, Union + from six import iteritems import attr @@ -19,45 +21,82 @@ from frozendict import frozendict from twisted.internet import defer +from synapse.appservice import ApplicationService from synapse.logging.context import make_deferred_yieldable, run_in_background @attr.s(slots=True) class EventContext: """ + Holds information relevant to persisting an event + Attributes: - state_group (int|None): state group id, if the state has been stored - as a state group. This is usually only None if e.g. the event is - an outlier. - rejected (bool|str): A rejection reason if the event was rejected, else - False - - prev_group (int): Previously persisted state group. ``None`` for an - outlier. - delta_ids (dict[(str, str), str]): Delta from ``prev_group``. - (type, state_key) -> event_id. ``None`` for an outlier. - - app_service: FIXME - - _current_state_ids (dict[(str, str), str]|None): - The current state map including the current event. None if outlier - or we haven't fetched the state from DB yet. + rejected: A rejection reason if the event was rejected, else False + + state_group: The ID of the state group for this event. Note that state events + are persisted with a state group which includes the new event, so this is + effectively the state *after* the event in question. + + For a *rejected* state event, where the state of the rejected event is + ignored, this state_group should never make it into the + event_to_state_groups table. Indeed, inspecting this value for a rejected + state event is almost certainly incorrect. + + For an outlier, where we don't have the state at the event, this will be + None. + + prev_group: If it is known, ``state_group``'s prev_group. Note that this being + None does not necessarily mean that ``state_group`` does not have + a prev_group! + + If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` + will always also be ``None``. + + Note that this *not* (necessarily) the state group associated with + ``_prev_state_ids``. + + delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` + and ``state_group``. + + app_service: If this event is being sent by a (local) application service, that + app service. + + _current_state_ids: The room state map, including this event - ie, the state + in ``state_group``. + (type, state_key) -> event_id - _prev_state_ids (dict[(str, str), str]|None): - The current state map excluding the current event. None if outlier - or we haven't fetched the state from DB yet. + FIXME: what is this for an outlier? it seems ill-defined. It seems like + it could be either {}, or the state we were given by the remote + server, depending on $THINGS + + Note that this is a private attribute: it should be accessed via + ``get_current_state_ids``. _AsyncEventContext impl calculates this + on-demand: it will be None until that happens. + + _prev_state_ids: The room state map, excluding this event. For a non-state + event, this will be the same as _current_state_events. + + Note that it is a completely different thing to prev_group! + (type, state_key) -> event_id + + FIXME: again, what is this for an outlier? + + As with _current_state_ids, this is a private attribute. It should be + accessed via get_prev_state_ids. """ - state_group = attr.ib(default=None) - rejected = attr.ib(default=False) - prev_group = attr.ib(default=None) - delta_ids = attr.ib(default=None) - app_service = attr.ib(default=None) + rejected = attr.ib(default=False, type=Union[bool, str]) + state_group = attr.ib(default=None, type=Optional[int]) + prev_group = attr.ib(default=None, type=Optional[int]) + delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) + app_service = attr.ib(default=None, type=Optional[ApplicationService]) - _prev_state_ids = attr.ib(default=None) - _current_state_ids = attr.ib(default=None) + _current_state_ids = attr.ib( + default=None, type=Optional[Dict[Tuple[str, str], str]] + ) + _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) @staticmethod def with_state( diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4e91eb66fe..2c04ab1854 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -232,6 +232,9 @@ class StateHandler(object): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. + + # FIXME: why do we populate current_state_ids? I thought the point was + # that we weren't supposed to have any state for outliers? if old_state: prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): -- cgit 1.4.1