diff options
author | Erik Johnston <erik@matrix.org> | 2019-10-30 13:37:04 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2019-10-30 13:37:04 +0000 |
commit | ec6de1cc7d915abf6907b1d6a93336f8cd435cdd (patch) | |
tree | 5e0beaa14d2e9abc82116a7a07740abf2508177b /synapse | |
parent | Review comments (diff) | |
parent | Merge pull request #6291 from matrix-org/erikj/fix_cache_descriptor (diff) | |
download | synapse-ec6de1cc7d915abf6907b1d6a93336f8cd435cdd.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/split_out_persistence_store
Diffstat (limited to 'synapse')
45 files changed, 555 insertions, 542 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index ee3313a41c..8587ffa76f 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -36,7 +36,7 @@ try: except ImportError: pass -__version__ = "1.4.1" +__version__ = "1.5.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index cd347fbe1b..53f3bb0fa8 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -25,7 +25,13 @@ from twisted.internet import defer import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, JoinRules, Membership, UserTypes +from synapse.api.constants import ( + EventTypes, + JoinRules, + LimitBlockingTypes, + Membership, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -726,7 +732,7 @@ class Auth(object): self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=self.hs.config.admin_contact, - limit_type=self.hs.config.hs_disabled_limit_type, + limit_type=LimitBlockingTypes.HS_DISABLED, ) if self.hs.config.limit_usage_by_mau is True: assert not (user_id and threepid) @@ -759,5 +765,5 @@ class Auth(object): "Monthly Active User Limit Exceeded", admin_contact=self.hs.config.admin_contact, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type="monthly_active_user", + limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 60e99e4663..312196675e 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -131,3 +131,10 @@ class RelationTypes(object): ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" + + +class LimitBlockingTypes(object): + """Reasons that a server may be blocked""" + + MONTHLY_ACTIVE_USER = "monthly_active_user" + HS_DISABLED = "hs_disabled" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ab41623b2b..1f6dac69da 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -300,7 +300,7 @@ class RegistrationConfig(Config): # If a delegate is specified, the config option public_baseurl must also be filled out. # account_threepid_delegates: - #email: https://example.com # Delegate email sending to example.org + #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process # Users who register on this homeserver will automatically be joined diff --git a/synapse/config/server.py b/synapse/config/server.py index c942841578..d556df308d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -171,6 +171,7 @@ class ServerConfig(Config): ) self.mau_trial_days = config.get("mau_trial_days", 0) + self.mau_limit_alerting = config.get("mau_limit_alerting", True) # How long to keep redacted events in the database in unredacted form # before redacting them. @@ -192,7 +193,6 @@ class ServerConfig(Config): # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") - self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "") # Admin uri to direct users at should their instance become blocked # due to resource constraints @@ -675,7 +675,6 @@ class ServerConfig(Config): # #hs_disabled: false #hs_disabled_message: 'Human readable reason for why the HS is blocked' - #hs_disabled_limit_type: 'error code(str), to help clients decode reason' # Monthly Active User Blocking # @@ -695,9 +694,16 @@ class ServerConfig(Config): # sign up in a short space of time never to return after their initial # session. # + # 'mau_limit_alerting' is a means of limiting client side alerting + # should the mau limit be reached. This is useful for small instances + # where the admin has 5 mau seats (say) for 5 specific people and no + # interest increasing the mau limit further. Defaults to True, which + # means that alerting is enabled + # #limit_usage_by_mau: false #max_mau_value: 50 #mau_trial_days: 2 + #mau_limit_alerting: false # If enabled, the metrics for the number of monthly active users will # be populated, however no one will be limited. If limit_usage_by_mau diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 694fb2c816..ccaa8a9920 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -125,9 +125,11 @@ def compute_event_signature(event_dict, signature_name, signing_key): redact_json = prune_event_dict(event_dict) redact_json.pop("age_ts", None) redact_json.pop("unsigned", None) - logger.debug("Signing event: %s", encode_canonical_json(redact_json)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Signing event: %s", encode_canonical_json(redact_json)) redact_json = sign_json(redact_json, signature_name, signing_key) - logger.debug("Signed event: %s", encode_canonical_json(redact_json)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Signed event: %s", encode_canonical_json(redact_json)) return redact_json["signatures"] diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 4e91df60e6..e7b722547b 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -493,8 +493,7 @@ def _check_power_levels(event, auth_events): new_level_too_big = new_level is not None and new_level > user_level if old_level_too_big or new_level_too_big: raise AuthError( - 403, - "You don't have permission to add ops level greater " "than your own", + 403, "You don't have permission to add ops level greater than your own" ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index acbcbeeced..27cd8a63ff 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from six import iteritems +import attr from frozendict import frozendict from twisted.internet import defer @@ -22,7 +22,8 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background -class EventContext(object): +@attr.s(slots=True) +class EventContext: """ Attributes: state_group (int|None): state group id, if the state has been stored @@ -31,9 +32,6 @@ class EventContext(object): rejected (bool|str): A rejection reason if the event was rejected, else False - push_actions (list[(str, list[object])]): list of (user_id, actions) - tuples - prev_group (int): Previously persisted state group. ``None`` for an outlier. delta_ids (dict[(str, str), str]): Delta from ``prev_group``. @@ -42,6 +40,8 @@ class EventContext(object): prev_state_events (?): XXX: is this ever set to anything other than the empty list? + app_service: FIXME + _current_state_ids (dict[(str, str), str]|None): The current state map including the current event. None if outlier or we haven't fetched the state from DB yet. @@ -67,49 +67,33 @@ class EventContext(object): Only set when state has not been fetched yet. """ - __slots__ = [ - "state_group", - "rejected", - "prev_group", - "delta_ids", - "prev_state_events", - "app_service", - "_current_state_ids", - "_prev_state_ids", - "_prev_state_id", - "_event_type", - "_event_state_key", - "_fetching_state_deferred", - ] - - def __init__(self): - self.prev_state_events = [] - self.rejected = False - self.app_service = None + state_group = attr.ib(default=None) + rejected = attr.ib(default=False) + prev_group = attr.ib(default=None) + delta_ids = attr.ib(default=None) + prev_state_events = attr.ib(default=attr.Factory(list)) + app_service = attr.ib(default=None) + + _current_state_ids = attr.ib(default=None) + _prev_state_ids = attr.ib(default=None) + _prev_state_id = attr.ib(default=None) + + _event_type = attr.ib(default=None) + _event_state_key = attr.ib(default=None) + _fetching_state_deferred = attr.ib(default=None) @staticmethod def with_state( state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None ): - context = EventContext() - - # The current state including the current event - context._current_state_ids = current_state_ids - # The current state excluding the current event - context._prev_state_ids = prev_state_ids - context.state_group = state_group - - context._prev_state_id = None - context._event_type = None - context._event_state_key = None - context._fetching_state_deferred = defer.succeed(None) - - # A previously persisted state group and a delta between that - # and this state. - context.prev_group = prev_group - context.delta_ids = delta_ids - - return context + return EventContext( + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + state_group=state_group, + fetching_state_deferred=defer.succeed(None), + prev_group=prev_group, + delta_ids=delta_ids, + ) @defer.inlineCallbacks def serialize(self, event, store): @@ -157,24 +141,18 @@ class EventContext(object): Returns: EventContext """ - context = EventContext() - - # We use the state_group and prev_state_id stuff to pull the - # current_state_ids out of the DB and construct prev_state_ids. - context._prev_state_id = input["prev_state_id"] - context._event_type = input["event_type"] - context._event_state_key = input["event_state_key"] - - context._current_state_ids = None - context._prev_state_ids = None - context._fetching_state_deferred = None - - context.state_group = input["state_group"] - context.prev_group = input["prev_group"] - context.delta_ids = _decode_state_dict(input["delta_ids"]) - - context.rejected = input["rejected"] - context.prev_state_events = input["prev_state_events"] + context = EventContext( + # We use the state_group and prev_state_id stuff to pull the + # current_state_ids out of the DB and construct prev_state_ids. + prev_state_id=input["prev_state_id"], + event_type=input["event_type"], + event_state_key=input["event_state_key"], + state_group=input["state_group"], + prev_group=input["prev_group"], + delta_ids=_decode_state_dict(input["delta_ids"]), + rejected=input["rejected"], + prev_state_events=input["prev_state_events"], + ) app_service_id = input["app_service_id"] if app_service_id: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 5a1e23a145..223aace0d9 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -278,9 +278,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): pdu_to_check.sender_domain, e.getErrorMessage(), ) - # XX not really sure if these are the right codes, but they are what - # we've done for ages - raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) for p, d in zip(pdus_to_check_sender, more_deferreds): d.addErrback(sender_err, p) @@ -314,8 +312,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): "event id %s: unable to verify signature for event id domain: %s" % (pdu_to_check.pdu.event_id, e.getErrorMessage()) ) - # XX as above: not really sure if these are the right codes - raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) for p, d in zip(pdus_to_check_event_id, more_deferreds): d.addErrback(event_err, p) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 5b22a39b7f..f5c1632916 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -196,7 +196,7 @@ class FederationClient(FederationBase): dest, room_id, extremities, limit ) - logger.debug("backfill transaction_data=%s", repr(transaction_data)) + logger.debug("backfill transaction_data=%r", transaction_data) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 21e52c9695..d5a19764d2 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -21,7 +21,6 @@ from six import iteritems from canonicaljson import json from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -86,14 +85,12 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_backfill_request(self, origin, room_id, versions, limit): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - pdus = yield self.handler.on_backfill_request( + pdus = await self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -101,9 +98,7 @@ class FederationServer(FederationBase): return 200, res - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, origin, transaction_data): + async def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -118,18 +113,17 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with ( - yield self._transaction_linearizer.queue( + await self._transaction_linearizer.queue( (origin, transaction.transaction_id) ) ): - result = yield self._handle_incoming_transaction( + result = await self._handle_incoming_transaction( origin, transaction, request_time ) return result - @defer.inlineCallbacks - def _handle_incoming_transaction(self, origin, transaction, request_time): + async def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: @@ -140,7 +134,7 @@ class FederationServer(FederationBase): Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(origin, transaction) + response = await self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -151,7 +145,7 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) - # Reject if PDU count > 50 and EDU count > 100 + # Reject if PDU count > 50 or EDU count > 100 if len(transaction.pdus) > 50 or ( hasattr(transaction, "edus") and len(transaction.edus) > 100 ): @@ -159,7 +153,7 @@ class FederationServer(FederationBase): logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} - yield self.transaction_actions.set_response( + await self.transaction_actions.set_response( origin, transaction, 400, response ) return 400, response @@ -195,7 +189,7 @@ class FederationServer(FederationBase): continue try: - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue @@ -221,11 +215,10 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - @defer.inlineCallbacks - def process_pdus_for_room(room_id): + async def process_pdus_for_room(room_id): logger.debug("Processing PDUs for %s", room_id) try: - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: logger.warn("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: @@ -237,7 +230,7 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu(origin, pdu) + await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) @@ -251,36 +244,33 @@ class FederationServer(FederationBase): exc_info=(f.type, f.value, f.getTracebackObject()), ) - yield concurrently_execute( + await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu(origin, edu.edu_type, edu.content) + await self.received_edu(origin, edu.edu_type, edu.content) response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response(origin, transaction, 200, response) + await self.transaction_actions.set_response(origin, transaction, 200, response) return 200, response - @defer.inlineCallbacks - def received_edu(self, origin, edu_type, content): + async def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - yield self.registry.on_edu(edu_type, origin, content) + await self.registry.on_edu(edu_type, origin, content) - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): + async def on_context_state_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -289,8 +279,8 @@ class FederationServer(FederationBase): # in the cache so we could return it without waiting for the linearizer # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.wrap( + with (await self._server_linearizer.queue((origin, room_id))): + resp = await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, room_id, @@ -299,65 +289,58 @@ class FederationServer(FederationBase): return 200, resp - @defer.inlineCallbacks - def on_state_ids_request(self, origin, room_id, event_id): + async def on_state_ids_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) + state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} - @defer.inlineCallbacks - def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu(room_id, event_id) - auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + async def _on_context_state_request_compute(self, room_id, event_id): + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], } - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self.handler.get_persisted_pdu(origin, event_id) + async def on_pdu_request(self, origin, event_id): + pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: return 200, self._transaction_from_pdus([pdu]).get_dict() else: return 404, "" - @defer.inlineCallbacks - def on_query_request(self, query_type, args): + async def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() - resp = yield self.registry.on_query(query_type, args) + resp = await self.registry.on_query(query_type, args) return 200, resp - @defer.inlineCallbacks - def on_make_join_request(self, origin, room_id, user_id, supported_versions): + async def on_make_join_request(self, origin, room_id, user_id, supported_versions): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) if room_version not in supported_versions: logger.warn("Room version %s not in %s", room_version, supported_versions) raise IncompatibleRoomVersionError(room_version=room_version) - pdu = yield self.handler.on_make_join_request(origin, room_id, user_id) + pdu = await self.handler.on_make_join_request(origin, room_id, user_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_invite_request(self, origin, content, room_version): + async def on_invite_request(self, origin, content, room_version): if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( 400, @@ -369,24 +352,27 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) + await self.check_server_matches_acl(origin_host, pdu.room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) + ret_pdu = await self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} - @defer.inlineCallbacks - def on_send_join_request(self, origin, content, room_id): + async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() return ( 200, @@ -398,45 +384,44 @@ class FederationServer(FederationBase): }, ) - @defer.inlineCallbacks - def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) - pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id) + await self.check_server_matches_acl(origin_host, room_id) + pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_send_leave_request(self, origin, content, room_id): + async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - yield self.handler.on_send_leave_request(origin, pdu) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + await self.handler.on_send_leave_request(origin, pdu) return 200, {} - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_event_auth(self, origin, room_id, event_id): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) + auth_pdus = await self.handler.on_event_auth(event_id) res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} return 200, res - @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, room_id, event_id): + async def on_query_auth_request(self, origin, content, room_id, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. @@ -455,22 +440,22 @@ class FederationServer(FederationBase): Returns: Deferred: Results in `dict` with the same format as `content` """ - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] - signed_auth = yield self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_and_fetch( origin, auth_chain, outlier=True, room_version=room_version ) - ret = yield self.handler.on_query_auth( + ret = await self.handler.on_query_auth( origin, event_id, room_id, @@ -496,16 +481,14 @@ class FederationServer(FederationBase): return self.on_query_request("user_devices", user_id) @trace - @defer.inlineCallbacks - @log_function - def on_claim_client_keys(self, origin, content): + async def on_claim_client_keys(self, origin, content): query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = yield self.store.claim_e2e_one_time_keys(query) + results = await self.store.claim_e2e_one_time_keys(query) json_result = {} for user_id, device_keys in results.items(): @@ -529,14 +512,12 @@ class FederationServer(FederationBase): return {"one_time_keys": json_result} - @defer.inlineCallbacks - @log_function - def on_get_missing_events( + async def on_get_missing_events( self, origin, room_id, earliest_events, latest_events, limit ): - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," @@ -546,7 +527,7 @@ class FederationServer(FederationBase): limit, ) - missing_events = yield self.handler.on_get_missing_events( + missing_events = await self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit ) @@ -579,8 +560,7 @@ class FederationServer(FederationBase): destination=None, ) - @defer.inlineCallbacks - def _handle_received_pdu(self, origin, pdu): + async def _handle_received_pdu(self, origin, pdu): """ Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. @@ -633,37 +613,34 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = yield self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. try: - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - @defer.inlineCallbacks - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): - ret = yield self.handler.exchange_third_party_invite( + ret = await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) return ret - @defer.inlineCallbacks - def on_exchange_third_party_invite_request(self, room_id, event_dict): - ret = yield self.handler.on_exchange_third_party_invite_request( + async def on_exchange_third_party_invite_request(self, room_id, event_dict): + ret = await self.handler.on_exchange_third_party_invite_request( room_id, event_dict ) return ret - @defer.inlineCallbacks - def check_server_matches_acl(self, server_name, room_id): + async def check_server_matches_acl(self, server_name, room_id): """Check if the given server is allowed by the server ACLs in the room Args: @@ -673,13 +650,13 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = yield self.store.get_current_state_ids(room_id) + state_ids = await self.store.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: return - acl_event = yield self.store.get_event(acl_event_id) + acl_event = await self.store.get_event(acl_event_id) if server_matches_acl_event(server_name, acl_event): return @@ -792,15 +769,14 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if not handler: logger.warn("No handler registered for EDU type %s", edu_type) with start_active_span_from_edu(content, "handle_edu"): try: - yield handler(origin, content) + await handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: @@ -833,7 +809,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): super(ReplicationFederationHandlerRegistry, self).__init__() - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): """Overrides FederationHandlerRegistry """ if not self.config.use_presence and edu_type == "m.presence": @@ -841,17 +817,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): handler = self.edu_handlers.get(edu_type) if handler: - return super(ReplicationFederationHandlerRegistry, self).on_edu( + return await super(ReplicationFederationHandlerRegistry, self).on_edu( edu_type, origin, content ) - return self._send_edu(edu_type=edu_type, origin=origin, content=content) + return await self._send_edu(edu_type=edu_type, origin=origin, content=content) - def on_query(self, query_type, args): + async def on_query(self, query_type, args): """Overrides FederationHandlerRegistry """ handler = self.query_handlers.get(query_type) if handler: - return handler(args) + return await handler(args) - return self._get_query_client(query_type=query_type, args=args) + return await self._get_query_client(query_type=query_type, args=args) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 454456a52d..ced4925a98 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -36,6 +36,8 @@ from six import iteritems from sortedcontainers import SortedDict +from twisted.internet import defer + from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure @@ -212,7 +214,7 @@ class FederationRemoteSendQueue(object): receipt (synapse.types.ReadReceipt): """ # nothing to do here: the replication listener will handle it. - pass + return defer.succeed(None) def send_presence(self, states): """As per FederationSender diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 7b18408144..920fa86853 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -122,10 +122,10 @@ class TransportLayerClient(object): Deferred: Results in a dict received from the remote homeserver. """ logger.debug( - "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s", + "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", destination, room_id, - repr(event_tuples), + event_tuples, str(limit), ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 7e9c9aee13..08276fdebf 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1223,7 +1223,6 @@ class FederationHandler(BaseHandler): Returns: Deferred[FrozenEvent] """ - if get_domain_from_id(user_id) != origin: logger.info( "Got /make_join request for user %r from different origin %s, ignoring", @@ -1252,7 +1251,7 @@ class FederationHandler(BaseHandler): builder=builder ) except AuthError as e: - logger.warn("Failed to create join %r because %s", event, e) + logger.warn("Failed to create join to %s because %s", room_id, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( @@ -1281,11 +1280,20 @@ class FederationHandler(BaseHandler): event = pdu logger.debug( - "on_send_join_request: Got event: %s, signatures: %s", + "on_send_join_request from %s: Got event: %s, signatures: %s", + origin, event.event_id, event.signatures, ) + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /send_join request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + event.internal_metadata.outlier = False # Send this event on behalf of the origin server. # @@ -1504,6 +1512,14 @@ class FederationHandler(BaseHandler): event.signatures, ) + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /send_leave request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + event.internal_metadata.outlier = False context = yield self._handle_new_event(origin, event) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 3e4d8c93a4..e3b528d271 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler): self.read_marker_linearizer = Linearizer(name="read_marker") self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def received_client_read_marker(self, room_id, user_id, event_id): + async def received_client_read_marker(self, room_id, user_id, event_id): """Updates the read marker for a given user in a given room if the event ID given is ahead in the stream relative to the current read marker. @@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler): the read marker has changed. """ - with (yield self.read_marker_linearizer.queue((room_id, user_id))): - existing_read_marker = yield self.store.get_account_data_for_room_and_type( + with await self.read_marker_linearizer.queue((room_id, user_id)): + existing_read_marker = await self.store.get_account_data_for_room_and_type( user_id, room_id, "m.fully_read" ) @@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler): if existing_read_marker: # Only update if the new marker is ahead in the stream - should_update = yield self.store.is_event_after( + should_update = await self.store.is_event_after( event_id, existing_read_marker["event_id"] ) if should_update: content = {"event_id": event_id} - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 6854c751a6..9283c039e3 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.handlers._base import BaseHandler from synapse.types import ReadReceipt, get_domain_from_id +from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler): self.clock = self.hs.get_clock() self.state = hs.get_state_handler() - @defer.inlineCallbacks - def _received_remote_receipt(self, origin, content): + async def _received_remote_receipt(self, origin, content): """Called when we receive an EDU of type m.receipt from a remote HS. """ receipts = [] @@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler): ) ) - yield self._handle_new_receipts(receipts) + await self._handle_new_receipts(receipts) - @defer.inlineCallbacks - def _handle_new_receipts(self, receipts): + async def _handle_new_receipts(self, receipts): """Takes a list of receipts, stores them and informs the notifier. """ min_batch_id = None max_batch_id = None for receipt in receipts: - res = yield self.store.insert_receipt( + res = await self.store.insert_receipt( receipt.room_id, receipt.receipt_type, receipt.user_id, @@ -99,14 +98,15 @@ class ReceiptsHandler(BaseHandler): self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. - yield self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids + await maybe_awaitable( + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids + ) ) return True - @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, event_id): + async def received_client_receipt(self, room_id, receipt_type, user_id, event_id): """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler): data={"ts": int(self.clock.time_msec())}, ) - is_new = yield self._handle_new_receipts([receipt]) + is_new = await self._handle_new_receipts([receipt]) if not is_new: return - yield self.federation.send_read_receipt(receipt) - - @defer.inlineCallbacks - def get_receipts_for_room(self, room_id, to_key): - """Gets all receipts for a room, upto the given key. - """ - result = yield self.store.get_linearized_receipts_for_room( - room_id, to_key=to_key - ) - - if not result: - return [] - - return result + await self.federation.send_read_receipt(receipt) class ReceiptEventSource(object): diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 466daf9202..26bc276692 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -45,6 +45,8 @@ class StatsHandler(StateDeltasHandler): self.is_mine_id = hs.is_mine_id self.stats_bucket_size = hs.config.stats_bucket_size + self.stats_enabled = hs.config.stats_enabled + # The current position in the current_state_delta stream self.pos = None @@ -61,7 +63,7 @@ class StatsHandler(StateDeltasHandler): def notify_new_event(self): """Called when there may be more deltas to process """ - if not self.hs.config.stats_enabled or self._is_processing: + if not self.stats_enabled or self._is_processing: return self._is_processing = True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 22491f3700..2bbdd11941 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -79,7 +79,7 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = self._get_rules_for_room(room_id) + rules_for_room = yield self._get_rules_for_room(room_id) rules_by_user = yield rules_for_room.get_rules(event, context) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 03560c1f0e..9be37cd998 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -110,14 +110,14 @@ class ReplicationEndpoint(object): return {} @abc.abstractmethod - def _handle_request(self, request, **kwargs): + async def _handle_request(self, request, **kwargs): """Handle incoming request. This is called with the request object and PATH_ARGS. Returns: - Deferred[dict]: A JSON serialisable dict to be used as response - body of request. + tuple[int, dict]: HTTP status code and a JSON serialisable dict + to be used as response body of request. """ pass diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 2f16955954..9af4e7e173 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -82,8 +82,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request): + async def _handle_request(self, request): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) @@ -101,15 +100,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): EventType = event_type_from_format_version(format_ver) event = EventType(event_dict, internal_metadata, rejected_reason) - context = yield EventContext.deserialize( - self.store, event_payload["context"] - ) + context = EventContext.deserialize(self.store, event_payload["context"]) event_and_contexts.append((event, context)) logger.info("Got %d events from federation", len(event_and_contexts)) - yield self.federation_handler.persist_events_and_notify( + await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) @@ -144,8 +141,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def _serialize_payload(edu_type, origin, content): return {"origin": origin, "content": content} - @defer.inlineCallbacks - def _handle_request(self, request, edu_type): + async def _handle_request(self, request, edu_type): with Measure(self.clock, "repl_fed_send_edu_parse"): content = parse_json_object_from_request(request) @@ -154,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): logger.info("Got %r edu from %s", edu_type, origin) - result = yield self.registry.on_edu(edu_type, origin, edu_content) + result = await self.registry.on_edu(edu_type, origin, edu_content) return 200, result @@ -193,8 +189,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): """ return {"args": args} - @defer.inlineCallbacks - def _handle_request(self, request, query_type): + async def _handle_request(self, request, query_type): with Measure(self.clock, "repl_fed_query_parse"): content = parse_json_object_from_request(request) @@ -202,7 +197,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): logger.info("Got %r query", query_type) - result = yield self.registry.on_query(query_type, args) + result = await self.registry.on_query(query_type, args) return 200, result @@ -234,9 +229,8 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): """ return {} - @defer.inlineCallbacks - def _handle_request(self, request, room_id): - yield self.store.clean_room_for_join(room_id) + async def _handle_request(self, request, room_id): + await self.store.clean_room_for_join(room_id) return 200, {} diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 786f5232b2..798b9d3af5 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest ) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index b9ce3477ad..b5f5f13a62 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID @@ -65,8 +63,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -79,7 +76,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - yield self.federation_handler.do_invite_join( + await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) @@ -123,8 +120,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): "remote_room_hosts": remote_room_hosts, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -137,7 +133,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = yield self.federation_handler.do_remotely_reject_invite( + event = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id ) ret = event.get_pdu_json() @@ -150,7 +146,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite(user_id, room_id) + await self.store.locally_reject_invite(user_id, room_id) ret = {} return 200, ret diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 38260256cf..915cfb9430 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -74,11 +72,10 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "address": address, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - yield self.registration_handler.register_with_store( + await self.registration_handler.register_with_store( user_id=user_id, password_hash=content["password_hash"], was_guest=content["was_guest"], @@ -117,14 +114,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): """ return {"auth_result": auth_result, "access_token": access_token} - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) auth_result = content["auth_result"] access_token = content["access_token"] - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=user_id, auth_result=auth_result, access_token=access_token ) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index adb9b2f7f4..9bafd60b14 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -87,8 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request, event_id): + async def _handle_request(self, request, event_id): with Measure(self.clock, "repl_send_event_parse"): content = parse_json_object_from_request(request) @@ -101,7 +100,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = yield EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.store, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] @@ -113,7 +112,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - yield self.event_creation_handler.persist_and_notify_client_event( + await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9c1d41421c..86bbcc0eea 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet): set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet): def on_PUT_no_state_key(self, request, room_id, event_type): return self.on_PUT(request, room_id, event_type, "") - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_type, state_key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_type, state_key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] ) msg_handler = self.message_handler - data = yield msg_handler.get_room_data( + data = await msg_handler.get_room_data( user_id=requester.user.to_string(), room_id=room_id, event_type=event_type, @@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet): elif format == "content": return 200, data.get_dict()["content"] - @defer.inlineCallbacks - def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + requester = await self.auth.get_user_by_req(request) if txn_id: set_tag("txn_id", txn_id) @@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = yield self.room_member_handler.update_membership( + event = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_type, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, event_type, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) event_dict = { @@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet): PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_identifier, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet): elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): server = parse_string(request, "server", default=None) try: - yield self.auth.get_user_by_req(request, allow_guest=True) + await self.auth.get_user_by_req(request, allow_guest=True) except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private @@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token ) return 200, data - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) @@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token, @@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): third_party_instance_id=third_party_instance_id, ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token, search_filter=search_filter, @@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): + async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) handler = self.message_handler # request the state as of a given event, as identified by a stream token, @@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet): membership = parse_string(request, "membership") not_membership = parse_string(request, "not_membership") - events = yield handler.get_state_events( + events = await handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, @@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - users_with_profile = yield self.message_handler.get_joined_members( + users_with_profile = await self.message_handler.get_joined_members( requester, room_id ) @@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet): self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) @@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet): as_client_event = False else: event_filter = None - msgs = yield self.pagination_handler.get_messages( + msgs = await self.pagination_handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, @@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room - events = yield self.message_handler.get_state_events( + events = await self.message_handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), is_guest=requester.is_guest, @@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) - content = yield self.initial_sync_handler.room_initial_sync( + content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) return 200, content @@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: - event = yield self.event_handler.get_event( + event = await self.event_handler.get_event( requester.user, room_id, event_id ) except AuthError: @@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) @@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = yield self.room_context_handler.get_event_context( + results = await self.room_context_handler.get_event_context( requester.user, room_id, event_id, limit, event_filter ) @@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - results["events_before"] = yield self._event_serializer.serialize_events( + results["events_before"] = await self._event_serializer.serialize_events( results["events_before"], time_now ) - results["event"] = yield self._event_serializer.serialize_event( + results["event"] = await self._event_serializer.serialize_event( results["event"], time_now ) - results["events_after"] = yield self._event_serializer.serialize_events( + results["events_after"] = await self._event_serializer.serialize_events( results["events_after"], time_now ) - results["state"] = yield self._event_serializer.serialize_events( + results["state"] = await self._event_serializer.serialize_events( results["state"], time_now ) @@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + async def on_POST(self, request, room_id, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget(user=requester.user, room_id=room_id) + await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} @@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, membership_action, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, @@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - yield self.room_member_handler.do_3pid_invite( + await self.room_member_handler.do_3pid_invite( room_id, requester.user, content["medium"], @@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content and membership_action in ["kick", "ban"]: event_content = {"reason": content["reason"]} - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=target, room_id=room_id, @@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id, txn_id=None): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, @@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, user_id): + requester = await self.auth.get_user_by_req(request) room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) if content["typing"]: - yield self.typing_handler.started_typing( + await self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, timeout=timeout, ) else: - yield self.typing_handler.stopped_typing( + await self.typing_handler.stopped_typing( target_user=target_user, auth_user=requester.user, room_id=room_id ) @@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = yield self.handlers.search_handler.search( + results = await self.handlers.search_handler.search( requester.user, content, batch ) @@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) + room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index b3bf8567e1..67cbc37312 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) read_event_id = body.get("m.read", None) if read_event_id: - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, "m.read", user_id=requester.user.to_string(), @@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet): read_marker_event_id = body.get("m.fully_read", None) if read_marker_event_id: - yield self.read_marker_handler.received_client_read_marker( + await self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), event_id=read_marker_event_id, diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 0dab03d227..92555bd4a9 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet @@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id, receipt_type, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, receipt_type, event_id): + requester = await self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a883c8adda..541a6b0e10 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -112,9 +112,14 @@ class SyncRestServlet(RestServlet): full_state = parse_boolean(request, "full_state", default=False) logger.debug( - "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r, device_id=%r" - % (user, timeout, since, set_presence, filter_id, device_id) + "/sync: user=%r, timeout=%r, since=%r, " + "set_presence=%r, filter_id=%r, device_id=%r", + user, + timeout, + since, + set_presence, + filter_id, + device_id, ) request_key = (user, timeout, since, filter_id, full_state, device_id) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0c68c3aad5..094ebad770 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -117,8 +117,10 @@ class PreviewUrlResource(DirectServeResource): pattern = entry[attrib] value = getattr(url_tuple, attrib) logger.debug( - ("Matching attrib '%s' with value '%s' against" " pattern '%s'") - % (attrib, value, pattern) + "Matching attrib '%s' with value '%s' against" " pattern '%s'", + attrib, + value, + pattern, ) if value is None: @@ -186,7 +188,7 @@ class PreviewUrlResource(DirectServeResource): media_info = yield self._download_url(url, user) - logger.debug("got media_info of '%s'" % media_info) + logger.debug("got media_info of '%s'", media_info) if _is_media(media_info["media_type"]): file_id = media_info["filesystem_id"] @@ -254,7 +256,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s" % og["og:image"]) + logger.warn("Couldn't get dims for %s", og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( self.server_name, @@ -268,7 +270,7 @@ class PreviewUrlResource(DirectServeResource): logger.warn("Failed to find any OG data in %s", url) og = {} - logger.debug("Calculated OG for %s as %s" % (url, og)) + logger.debug("Calculated OG for %s as %s", url, og) jsonog = json.dumps(og) @@ -297,7 +299,7 @@ class PreviewUrlResource(DirectServeResource): with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: - logger.debug("Trying to get url '%s'" % url) + logger.debug("Trying to get url '%s'", url) length, headers, uri, code = yield self.client.get_file( url, output_stream=f, max_size=self.max_spider_size ) diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 81c4aff496..c0e7f475c9 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -20,6 +20,7 @@ from twisted.internet import defer from synapse.api.constants import ( EventTypes, + LimitBlockingTypes, ServerNoticeLimitReached, ServerNoticeMsgType, ) @@ -70,7 +71,7 @@ class ResourceLimitsServerNotices(object): return if not self._server_notices_manager.is_enabled(): - # Don't try and send server notices unles they've been enabled + # Don't try and send server notices unless they've been enabled return timestamp = yield self._store.user_last_seen_monthly_active(user_id) @@ -79,8 +80,6 @@ class ResourceLimitsServerNotices(object): # In practice, not sure we can ever get here return - # Determine current state of room - room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) if not room_id: @@ -88,51 +87,86 @@ class ResourceLimitsServerNotices(object): return yield self._check_and_set_tags(user_id, room_id) + + # Determine current state of room currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + limit_msg = None + limit_type = None try: - # Normally should always pass in user_id if you have it, but in - # this case are checking what would happen to other users if they - # were to arrive. - try: - yield self._auth.check_auth_blocking() - is_auth_blocking = False - except ResourceLimitError as e: - is_auth_blocking = True - event_content = e.msg - event_limit_type = e.limit_type - - if currently_blocked and not is_auth_blocking: - # Room is notifying of a block, when it ought not to be. - # Remove block notification - content = {"pinned": ref_events} - yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, "" - ) + # Normally should always pass in user_id to check_auth_blocking + # if you have it, but in this case are checking what would happen + # to other users if they were to arrive. + yield self._auth.check_auth_blocking() + except ResourceLimitError as e: + limit_msg = e.msg + limit_type = e.limit_type - elif not currently_blocked and is_auth_blocking: + try: + if ( + limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER + and not self._config.mau_limit_alerting + ): + # We have hit the MAU limit, but MAU alerting is disabled: + # reset room if necessary and return + if currently_blocked: + self._remove_limit_block_notification(user_id, ref_events) + return + + if currently_blocked and not limit_msg: + # Room is notifying of a block, when it ought not to be. + yield self._remove_limit_block_notification(user_id, ref_events) + elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. - # Add block notification - content = { - "body": event_content, - "msgtype": ServerNoticeMsgType, - "server_notice_type": ServerNoticeLimitReached, - "admin_contact": self._config.admin_contact, - "limit_type": event_limit_type, - } - event = yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Message + yield self._apply_limit_block_notification( + user_id, limit_msg, limit_type ) - - content = {"pinned": [event.event_id]} - yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, "" - ) - except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) @defer.inlineCallbacks + def _remove_limit_block_notification(self, user_id, ref_events): + """Utility method to remove limit block notifications from the server + notices room. + + Args: + user_id (str): user to notify + ref_events (list[str]): The event_ids of pinned events that are unrelated to + limit blocking and need to be preserved. + """ + content = {"pinned": ref_events} + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, "" + ) + + @defer.inlineCallbacks + def _apply_limit_block_notification(self, user_id, event_body, event_limit_type): + """Utility method to apply limit block notifications in the server + notices room. + + Args: + user_id (str): user to notify + event_body(str): The human readable text that describes the block. + event_limit_type(str): Specifies the type of block e.g. monthly active user + limit has been exceeded. + """ + content = { + "body": event_body, + "msgtype": ServerNoticeMsgType, + "server_notice_type": ServerNoticeLimitReached, + "admin_contact": self._config.admin_contact, + "limit_type": event_limit_type, + } + event = yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Message + ) + + content = {"pinned": [event.event_id]} + yield self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, "" + ) + + @defer.inlineCallbacks def _check_and_set_tags(self, user_id, room_id): """ Since server notices rooms were originally not with tags, diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 80b57a948c..37d469ffd7 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -94,13 +94,16 @@ class BackgroundUpdateStore(SQLBaseStore): self._all_done = False def start_doing_background_updates(self): - run_as_background_process("background_updates", self._run_background_updates) + run_as_background_process("background_updates", self.run_background_updates) @defer.inlineCallbacks - def _run_background_updates(self): + def run_background_updates(self, sleep=True): logger.info("Starting background schema updates") while True: - yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) + if sleep: + yield self.hs.get_clock().sleep( + self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0 + ) try: result = yield self.do_next_background_update( diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index ef88e79293..1cbbae5b63 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -321,9 +321,17 @@ class EndToEndRoomKeyStore(SQLBaseStore): def _delete_e2e_room_keys_version_txn(txn): if version is None: this_version = self._get_current_version(txn, user_id) + if this_version is None: + raise StoreError(404, "No current backup version") else: this_version = version + self._simple_delete_txn( + txn, + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": this_version}, + ) + return self._simple_update_one_txn( txn, table="e2e_room_keys_versions", diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index f5c3ed9dc2..a0bc6f2d18 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -248,6 +248,73 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) + def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being set: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + from_user_id (str): if specified, signatures made by this user on + the key will be included in the result + + Returns: + dict of the key data or None if not found + """ + sql = ( + "SELECT keydata " + " FROM e2e_cross_signing_keys " + " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1" + ) + txn.execute(sql, (user_id, key_type)) + row = txn.fetchone() + if not row: + return None + key = json.loads(row[0]) + + device_id = None + for k in key["keys"].values(): + device_id = k + + if from_user_id is not None: + sql = ( + "SELECT key_id, signature " + " FROM e2e_cross_signing_signatures " + " WHERE user_id = ? " + " AND target_user_id = ? " + " AND target_device_id = ? " + ) + txn.execute(sql, (from_user_id, user_id, device_id)) + row = txn.fetchone() + if row: + key.setdefault("signatures", {}).setdefault(from_user_id, {})[ + row[0] + ] = row[1] + + return key + + def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + user_id (str): the user whose self-signing key is being requested + key_type (str): the type of cross-signing key to get + from_user_id (str): if specified, signatures made by this user on + the self-signing key will be included in the result + + Returns: + dict of the key data or None if not found + """ + return self.runInteraction( + "get_e2e_cross_signing_key", + self._get_e2e_cross_signing_key_txn, + user_id, + key_type, + from_user_id, + ) + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): @@ -426,73 +493,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key, ) - def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): - """Returns a user's cross-signing key. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' - for a master key, 'self_signing' for a self-signing key, or - 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on - the key will be included in the result - - Returns: - dict of the key data or None if not found - """ - sql = ( - "SELECT keydata " - " FROM e2e_cross_signing_keys " - " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1" - ) - txn.execute(sql, (user_id, key_type)) - row = txn.fetchone() - if not row: - return None - key = json.loads(row[0]) - - device_id = None - for k in key["keys"].values(): - device_id = k - - if from_user_id is not None: - sql = ( - "SELECT key_id, signature " - " FROM e2e_cross_signing_signatures " - " WHERE user_id = ? " - " AND target_user_id = ? " - " AND target_device_id = ? " - ) - txn.execute(sql, (from_user_id, user_id, device_id)) - row = txn.fetchone() - if row: - key.setdefault("signatures", {}).setdefault(from_user_id, {})[ - row[0] - ] = row[1] - - return key - - def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): - """Returns a user's cross-signing key. - - Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get - from_user_id (str): if specified, signatures made by this user on - the self-signing key will be included in the result - - Returns: - dict of the key data or None if not found - """ - return self.runInteraction( - "get_e2e_cross_signing_key", - self._get_e2e_cross_signing_key_txn, - user_id, - key_type, - from_user_id, - ) - def store_e2e_cross_signing_signatures(self, user_id, signatures): """Stores cross-signing signatures. diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index a470a48e0f..90bef0cd2c 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -364,9 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) def _get_backfill_events(self, txn, room_id, event_list, limit): - logger.debug( - "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit - ) + logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) event_results = set() diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 813f34528c..7c3607f308 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1881,12 +1881,11 @@ class EventsStore( logger.info("[purge] done") - @defer.inlineCallbacks - def is_event_after(self, event_id1, event_id2): + async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ - to_1, so_1 = yield self._get_event_ordering(event_id1) - to_2, so_2 = yield self._get_event_ordering(event_id2) + to_1, so_1 = await self._get_event_ordering(event_id1) + to_2, so_2 = await self._get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 4428e5c55d..67bb1b6f60 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -201,13 +201,17 @@ class RoomWorkerStore(SQLBaseStore): where_clauses.append( """ ( - name LIKE ? - OR topic LIKE ? - OR canonical_alias LIKE ? + LOWER(name) LIKE ? + OR LOWER(topic) LIKE ? + OR LOWER(canonical_alias) LIKE ? ) """ ) - query_args += [search_term, search_term, search_term] + query_args += [ + search_term.lower(), + search_term.lower(), + search_term.lower(), + ] where_clause = "" if where_clauses: diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index e47ab604dd..bc04bfd7d4 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -720,7 +720,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - cache = self._get_joined_hosts_cache(room_id) + cache = yield self._get_joined_hosts_cache(room_id) joined_hosts = yield cache.get_destinations(state_entry) return joined_hosts diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql new file mode 100644 index 0000000000..1d2ddb1b1a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* delete room keys that belong to deleted room key version, or to room key + * versions that don't exist (anymore) + */ +DELETE FROM e2e_room_keys +WHERE version NOT IN ( + SELECT version + FROM e2e_room_keys_versions + WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id + AND e2e_room_keys_versions.deleted = 0 +); diff --git a/synapse/storage/schema/delta/56/hidden_devices.sql b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql index 67f8b20297..67f8b20297 100644 --- a/synapse/storage/schema/delta/56/hidden_devices.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql diff --git a/synapse/storage/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql index 27a96123e3..27a96123e3 100644 --- a/synapse/storage/schema/delta/56/signing_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index d54442e5fa..9b2207075b 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Iterable, Tuple from six import iteritems, itervalues from six.moves import range @@ -23,6 +24,8 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore @@ -1215,7 +1218,9 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): def __init__(self, db_conn, hs): super(StateStore, self).__init__(db_conn, hs) - def _store_event_state_mappings_txn(self, txn, events_and_contexts): + def _store_event_state_mappings_txn( + self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] + ): state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 5ab639b2ad..4d59b7833f 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -332,7 +332,7 @@ class StatsStore(StateDeltasStore): def _bulk_update_stats_delta_txn(txn): for stats_type, stats_updates in updates.items(): for stats_id, fields in stats_updates.items(): - logger.info( + logger.debug( "Updating %s stats for %s: %s", stats_type, stats_id, fields ) self._update_stats_delta_txn( diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 804dbca443..b60a604474 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -86,11 +86,12 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) - def observe(self): + def observe(self) -> defer.Deferred: """Observe the underlying deferred. - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. + This returns a brand new deferred that is resolved when the underlying + deferred is resolved. Interacting with the returned deferred does not + effect the underdlying deferred. """ if not self._result: d = defer.Deferred() @@ -105,7 +106,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return res if success else defer.fail(res) + return defer.succeed(res) if success else defer.fail(res) def observers(self): return self._observers @@ -138,7 +139,7 @@ def concurrently_execute(func, args, limit): the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred. + func (func): Function to execute, should return a deferred or coroutine. args (list): List of arguments to pass to func, each invocation of func gets a signle argument. limit (int): Maximum number of conccurent executions. @@ -148,11 +149,10 @@ def concurrently_execute(func, args, limit): """ it = iter(args) - @defer.inlineCallbacks - def _concurrently_execute_inner(): + async def _concurrently_execute_inner(): try: while True: - yield func(next(it)) + await maybe_awaitable(func(next(it))) except StopIteration: pass diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5ac2530a6a..0e8da27f53 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -438,7 +438,7 @@ class CacheDescriptor(_CacheDescriptorBase): if isinstance(cached_result_d, ObservableDeferred): observer = cached_result_d.observe() else: - observer = cached_result_d + observer = defer.succeed(cached_result_d) except KeyError: ret = defer.maybeDeferred( @@ -482,9 +482,8 @@ class CacheListDescriptor(_CacheDescriptorBase): Given a list of keys it looks in the cache to find any hits, then passes the list of missing keys to the wrapped function. - Once wrapped, the function returns either a Deferred which resolves to - the list of results, or (if all results were cached), just the list of - results. + Once wrapped, the function returns a Deferred which resolves to the list + of results. """ def __init__( @@ -618,7 +617,7 @@ class CacheListDescriptor(_CacheDescriptorBase): ) return make_deferred_yieldable(d) else: - return results + return defer.succeed(results) obj.__dict__[self.orig.__name__] = wrapped diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index fa404b9d75..ab7d03af3a 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -42,6 +42,7 @@ def get_version_string(module): try: null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) + try: git_branch = ( subprocess.check_output( @@ -51,7 +52,8 @@ def get_version_string(module): .decode("ascii") ) git_branch = "b=" + git_branch - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError can arise when git is not installed git_branch = "" try: @@ -63,7 +65,7 @@ def get_version_string(module): .decode("ascii") ) git_tag = "t=" + git_tag - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_tag = "" try: @@ -74,7 +76,7 @@ def get_version_string(module): .strip() .decode("ascii") ) - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_commit = "" try: @@ -89,7 +91,7 @@ def get_version_string(module): ) git_dirty = "dirty" if is_dirty else "" - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_dirty = "" if git_branch or git_tag or git_commit or git_dirty: |