From e908b8683248328dd92479fc81be350336b9c8f4 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 30 Jul 2018 16:24:02 -0600 Subject: Remove pdu_failures from transactions The field is never read from, and all the opportunities given to populate it are not utilized. It should be very safe to remove this. --- synapse/federation/federation_server.py | 4 --- synapse/federation/send_queue.py | 63 +-------------------------------- synapse/federation/transaction_queue.py | 32 +++-------------- synapse/federation/transport/server.py | 3 +- synapse/federation/units.py | 1 - 5 files changed, 7 insertions(+), 96 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e501251b6e..657935d1ac 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -207,10 +207,6 @@ class FederationServer(FederationBase): edu.content ) - pdu_failures = getattr(transaction, "pdu_failures", []) - for fail in pdu_failures: - logger.info("Got failure %r", fail) - response = { "pdus": pdu_results, } diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 5157c3860d..0bb468385d 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object): self.edus = SortedDict() # stream position -> Edu - self.failures = SortedDict() # stream position -> (destination, Failure) - self.device_messages = SortedDict() # stream position -> destination self.pos = 1 @@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object): for queue_name in [ "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "failures", "device_messages", "pos_time", + "edus", "device_messages", "pos_time", ]: register(queue_name, getattr(self, queue_name)) @@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object): for key in keys[:i]: del self.edus[key] - # Delete things out of failure map - keys = self.failures.keys() - i = self.failures.bisect_left(position_to_delete) - for key in keys[:i]: - del self.failures[key] - # Delete things out of device map keys = self.device_messages.keys() i = self.device_messages.bisect_left(position_to_delete) @@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object): self.notifier.on_new_replication_data() - def send_failure(self, failure, destination): - """As per TransactionQueue""" - pos = self._next_pos() - - self.failures[pos] = (destination, str(failure)) - self.notifier.on_new_replication_data() - def send_device_messages(self, destination): """As per TransactionQueue""" pos = self._next_pos() @@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object): for (pos, edu) in edus: rows.append((pos, EduRow(edu))) - # Fetch changed failures - i = self.failures.bisect_right(from_token) - j = self.failures.bisect_right(to_token) + 1 - failures = self.failures.items()[i:j] - - for (pos, (destination, failure)) in failures: - rows.append((pos, FailureRow( - destination=destination, - failure=failure, - ))) - # Fetch changed device messages i = self.device_messages.bisect_right(from_token) j = self.device_messages.bisect_right(to_token) + 1 @@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ( buff.edus.setdefault(self.edu.destination, []).append(self.edu) -class FailureRow(BaseFederationRow, namedtuple("FailureRow", ( - "destination", # str - "failure", -))): - """Streams failures to a remote server. Failures are issued when there was - something wrong with a transaction the remote sent us, e.g. it included - an event that was invalid. - """ - - TypeId = "f" - - @staticmethod - def from_data(data): - return FailureRow( - destination=data["destination"], - failure=data["failure"], - ) - - def to_data(self): - return { - "destination": self.destination, - "failure": self.failure, - } - - def add_to_buffer(self, buff): - buff.failures.setdefault(self.destination, []).append(self.failure) - - class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( "destination", # str ))): @@ -471,7 +417,6 @@ TypeToRow = { PresenceRow, KeyedEduRow, EduRow, - FailureRow, DeviceRow, ) } @@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( "presence", # list(UserPresenceState) "keyed_edus", # dict of destination -> { key -> Edu } "edus", # dict of destination -> [Edu] - "failures", # dict of destination -> [failures] "device_destinations", # set of destinations )) @@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows): presence=[], keyed_edus={}, edus={}, - failures={}, device_destinations=set(), ) @@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows): edu.destination, edu.edu_type, edu.content, key=None, ) - for destination, failure_list in iteritems(buff.failures): - for failure in failure_list: - transaction_queue.send_failure(destination, failure) - for destination in buff.device_destinations: transaction_queue.send_device_messages(destination) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 6996d6b695..78f9d40a3a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -116,9 +116,6 @@ class TransactionQueue(object): ), ) - # destination -> list of tuple(failure, deferred) - self.pending_failures_by_dest = {} - # destination -> stream_id of last successfully sent to-device message. # NB: may be a long or an int. self.last_device_stream_id_by_dest = {} @@ -382,19 +379,6 @@ class TransactionQueue(object): self._attempt_new_transaction(destination) - def send_failure(self, failure, destination): - if destination == self.server_name or destination == "localhost": - return - - if not self.can_send_to(destination): - return - - self.pending_failures_by_dest.setdefault( - destination, [] - ).append(failure) - - self._attempt_new_transaction(destination) - def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": return @@ -469,7 +453,6 @@ class TransactionQueue(object): pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_presence = self.pending_presence_by_dest.pop(destination, {}) - pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_edus.extend( self.pending_edus_keyed_by_dest.pop(destination, {}).values() @@ -497,7 +480,7 @@ class TransactionQueue(object): logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", destination, len(pending_pdus)) - if not pending_pdus and not pending_edus and not pending_failures: + if not pending_pdus and not pending_edus: logger.debug("TX [%s] Nothing to send", destination) self.last_device_stream_id_by_dest[destination] = ( device_stream_id @@ -507,7 +490,7 @@ class TransactionQueue(object): # END CRITICAL SECTION success = yield self._send_new_transaction( - destination, pending_pdus, pending_edus, pending_failures, + destination, pending_pdus, pending_edus, ) if success: sent_transactions_counter.inc() @@ -584,14 +567,12 @@ class TransactionQueue(object): @measure_func("_send_new_transaction") @defer.inlineCallbacks - def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures): + def _send_new_transaction(self, destination, pending_pdus, pending_edus): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) pdus = [x[0] for x in pending_pdus] edus = pending_edus - failures = [x.get_dict() for x in pending_failures] success = True @@ -601,11 +582,10 @@ class TransactionQueue(object): logger.debug( "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", + " (pdus: %d, edus: %d)", destination, txn_id, len(pdus), len(edus), - len(failures) ) logger.debug("TX [%s] Persisting transaction...", destination) @@ -617,7 +597,6 @@ class TransactionQueue(object): destination=destination, pdus=pdus, edus=edus, - pdu_failures=failures, ) self._next_txn_id += 1 @@ -627,12 +606,11 @@ class TransactionQueue(object): logger.debug("TX [%s] Persisted transaction", destination) logger.info( "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d, failures: %d)", + " (PDUs: %d, EDUs: %d)", destination, txn_id, transaction.transaction_id, len(pdus), len(edus), - len(failures), ) # Actually send the transaction diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8574898f0c..3b5ea9515a 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet): ) logger.info( - "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)", + "Received txn %s from %s. (PDUs: %d, EDUs: %d)", transaction_id, origin, len(transaction_data.get("pdus", [])), len(transaction_data.get("edus", [])), - len(transaction_data.get("failures", [])), ) # We should ideally be getting this from the security layer. diff --git a/synapse/federation/units.py b/synapse/federation/units.py index bb1b3b13f7..c5ab14314e 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject): "previous_ids", "pdus", "edus", - "pdu_failures", ] internal_keys = [ -- cgit 1.5.1 From da7785147df442eb9cdc1031fa5fea12b7b25334 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 2 Aug 2018 00:54:06 +1000 Subject: Python 3: Convert some unicode/bytes uses (#3569) --- changelog.d/3569.bugfix | 1 + synapse/api/auth.py | 4 ++-- synapse/federation/transport/server.py | 2 +- synapse/handlers/auth.py | 29 ++++++++++++++++++-------- synapse/handlers/register.py | 2 +- synapse/http/server.py | 35 +++++++++++++++++++++++--------- synapse/http/servlet.py | 10 ++++++++- synapse/rest/client/v1/admin.py | 22 +++++++++++++------- synapse/rest/client/v2_alpha/register.py | 12 +++++------ synapse/rest/media/v1/media_storage.py | 2 +- synapse/state.py | 2 +- synapse/storage/events.py | 14 +++++++++---- synapse/storage/signatures.py | 2 +- synapse/types.py | 2 +- synapse/util/frozenutils.py | 6 +++--- tests/api/test_auth.py | 35 +++++++++++++++++--------------- tests/utils.py | 9 +++++--- 17 files changed, 122 insertions(+), 67 deletions(-) create mode 100644 changelog.d/3569.bugfix (limited to 'synapse/federation') diff --git a/changelog.d/3569.bugfix b/changelog.d/3569.bugfix new file mode 100644 index 0000000000..d77f035ee0 --- /dev/null +++ b/changelog.d/3569.bugfix @@ -0,0 +1 @@ +Unicode passwords are now normalised before hashing, preventing the instance where two different devices or browsers might send a different UTF-8 sequence for the password. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 073229b4c4..5bbbe8e2e7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -252,10 +252,10 @@ class Auth(object): if ip_address not in app_service.ip_range_whitelist: defer.returnValue((None, None)) - if "user_id" not in request.args: + if b"user_id" not in request.args: defer.returnValue((app_service.sender, app_service)) - user_id = request.args["user_id"][0] + user_id = request.args[b"user_id"][0].decode('utf8') if app_service.sender == user_id: defer.returnValue((app_service.sender, app_service)) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 3b5ea9515a..eae5f2b427 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes): param_dict = dict(kv.split("=") for kv in params) def strip_quotes(value): - if value.startswith(b"\""): + if value.startswith("\""): return value[1:-1] else: return value diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 402e44cdef..5d03bfa5f7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import unicodedata import attr import bcrypt @@ -626,6 +627,7 @@ class AuthHandler(BaseHandler): # special case to check for "password" for the check_password interface # for the auth providers password = login_submission.get("password") + if login_type == LoginType.PASSWORD: if not self._password_enabled: raise SynapseError(400, "Password login has been disabled.") @@ -707,9 +709,10 @@ class AuthHandler(BaseHandler): multiple inexact matches. Args: - user_id (str): complete @user:id + user_id (unicode): complete @user:id + password (unicode): the provided password Returns: - (str) the canonical_user_id, or None if unknown user / bad password + (unicode) the canonical_user_id, or None if unknown user / bad password """ lookupres = yield self._find_user_id_and_pwd_hash(user_id) if not lookupres: @@ -849,14 +852,19 @@ class AuthHandler(BaseHandler): """Computes a secure hash of password. Args: - password (str): Password to hash. + password (unicode): Password to hash. Returns: - Deferred(str): Hashed password. + Deferred(unicode): Hashed password. """ def _do_hash(): - return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, - bcrypt.gensalt(self.bcrypt_rounds)) + # Normalise the Unicode in the password + pw = unicodedata.normalize("NFKC", password) + + return bcrypt.hashpw( + pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), + bcrypt.gensalt(self.bcrypt_rounds), + ).decode('ascii') return make_deferred_yieldable( threads.deferToThreadPool( @@ -868,16 +876,19 @@ class AuthHandler(BaseHandler): """Validates that self.hash(password) == stored_hash. Args: - password (str): Password to hash. - stored_hash (str): Expected hash value. + password (unicode): Password to hash. + stored_hash (unicode): Expected hash value. Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ def _do_validate_hash(): + # Normalise the Unicode in the password + pw = unicodedata.normalize("NFKC", password) + return bcrypt.checkpw( - password.encode('utf8') + self.hs.config.password_pepper, + pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), stored_hash.encode('utf8') ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7caff0cbc8..234f8e8019 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler): Args: localpart : The local part of the user ID to register. If None, one will be generated. - password (str) : The password to assign to this user so they can + password (unicode) : The password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). generate_token (bool): Whether a new access token should be diff --git a/synapse/http/server.py b/synapse/http/server.py index c70fdbdfd2..1940c1c4f4 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -13,12 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import cgi import collections import logging -import urllib -from six.moves import http_client +from six import PY3 +from six.moves import http_client, urllib from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json @@ -264,6 +265,7 @@ class JsonResource(HttpServer, resource.Resource): self.hs = hs def register_paths(self, method, path_patterns, callback): + method = method.encode("utf-8") # method is bytes on py3 for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( @@ -296,8 +298,19 @@ class JsonResource(HttpServer, resource.Resource): # here. If it throws an exception, that is handled by the wrapper # installed by @request_handler. + def _unquote(s): + if PY3: + # On Python 3, unquote is unicode -> unicode + return urllib.parse.unquote(s) + else: + # On Python 2, unquote is bytes -> bytes We need to encode the + # URL again (as it was decoded by _get_handler_for request), as + # ASCII because it's a URL, and then decode it to get the UTF-8 + # characters that were quoted. + return urllib.parse.unquote(s.encode('ascii')).decode('utf8') + kwargs = intern_dict({ - name: urllib.unquote(value).decode("UTF-8") if value else value + name: _unquote(value) if value else value for name, value in group_dict.items() }) @@ -313,9 +326,9 @@ class JsonResource(HttpServer, resource.Resource): request (twisted.web.http.Request): Returns: - Tuple[Callable, dict[str, str]]: callback method, and the dict - mapping keys to path components as specified in the handler's - path match regexp. + Tuple[Callable, dict[unicode, unicode]]: callback method, and the + dict mapping keys to path components as specified in the + handler's path match regexp. The callback will normally be a method registered via register_paths, so will return (possibly via Deferred) either @@ -327,7 +340,7 @@ class JsonResource(HttpServer, resource.Resource): # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path) + m = path_entry.pattern.match(request.path.decode('ascii')) if m: # We found a match! return path_entry.callback, m.groupdict() @@ -383,7 +396,7 @@ class RootRedirect(resource.Resource): self.url = path def render_GET(self, request): - return redirectTo(self.url, request) + return redirectTo(self.url.encode('ascii'), request) def getChild(self, name, request): if len(name) == 0: @@ -404,12 +417,14 @@ def respond_with_json(request, code, json_object, send_cors=False, return if pretty_print: - json_bytes = encode_pretty_printed_json(json_object) + "\n" + json_bytes = (encode_pretty_printed_json(json_object) + "\n" + ).encode("utf-8") else: if canonical_json or synapse.events.USE_FROZEN_DICTS: + # canonicaljson already encodes to bytes json_bytes = encode_canonical_json(json_object) else: - json_bytes = json.dumps(json_object) + json_bytes = json.dumps(json_object).encode("utf-8") return respond_with_json_bytes( request, code, json_bytes, diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 882816dc8f..69f7085291 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False): if not content_bytes and allow_empty_body: return None + # Decode to Unicode so that simplejson will return Unicode strings on + # Python 2 try: - content = json.loads(content_bytes) + content_unicode = content_bytes.decode('utf8') + except UnicodeDecodeError: + logger.warn("Unable to decode UTF-8") + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + try: + content = json.loads(content_unicode) except Exception as e: logger.warn("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 99f6c6e3c3..80d625eecc 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -18,6 +18,7 @@ import hashlib import hmac import logging +from six import text_type from six.moves import http_client from twisted.internet import defer @@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet): 400, "username must be specified", errcode=Codes.BAD_JSON, ) else: - if (not isinstance(body['username'], str) or len(body['username']) > 512): + if ( + not isinstance(body['username'], text_type) + or len(body['username']) > 512 + ): raise SynapseError(400, "Invalid username") username = body["username"].encode("utf-8") @@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet): 400, "password must be specified", errcode=Codes.BAD_JSON, ) else: - if (not isinstance(body['password'], str) or len(body['password']) > 512): + if ( + not isinstance(body['password'], text_type) + or len(body['password']) > 512 + ): raise SynapseError(400, "Invalid password") password = body["password"].encode("utf-8") @@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet): want_mac.update(b"admin" if admin else b"notadmin") want_mac = want_mac.hexdigest() - if not hmac.compare_digest(want_mac, got_mac): - raise SynapseError( - 403, "HMAC incorrect", - ) + if not hmac.compare_digest(want_mac, got_mac.encode('ascii')): + raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication from synapse.rest.client.v2_alpha.register import RegisterRestServlet + register = RegisterRestServlet(self.hs) (user_id, _) = yield register.registration_handler.register( - localpart=username.lower(), password=password, admin=bool(admin), + localpart=body['username'].lower(), + password=body["password"], + admin=bool(admin), generate_token=False, ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index d6cf915d86..2f64155d13 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - kind = "user" - if "kind" in request.args: - kind = request.args["kind"][0] + kind = b"user" + if b"kind" in request.args: + kind = request.args[b"kind"][0] - if kind == "guest": + if kind == b"guest": ret = yield self._do_guest_registration(body) defer.returnValue(ret) return - elif kind != "user": + elif kind != b"user": raise UnrecognizedRequestError( "Do not understand membership kind: %s" % (kind,) ) @@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet): assert_params_in_dict(params, ["password"]) desired_username = params.get("username", None) - new_password = params.get("password", None) guest_access_token = params.get("guest_access_token", None) + new_password = params.get("password", None) if desired_username is not None: desired_username = desired_username.lower() diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index b25993fcb5..a6189224ee 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -177,7 +177,7 @@ class MediaStorage(object): if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "w"), self.hs.get_reactor()) + open(local_path, "wb"), self.hs.get_reactor()) yield res.write_to_consumer(consumer) yield consumer.wait() defer.returnValue(local_path) diff --git a/synapse/state.py b/synapse/state.py index 033f55d967..e1092b97a9 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -577,7 +577,7 @@ def _make_state_cache_entry( def _ordered_events(events): def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest() + return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest() return sorted(events, key=key_func) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c98e524ba1..61223da1a5 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -67,7 +67,13 @@ state_delta_reuse_delta_counter = Counter( def encode_json(json_object): - return frozendict_json_encoder.encode(json_object) + """ + Encode a Python object as JSON and return it in a Unicode string. + """ + out = frozendict_json_encoder.encode(json_object) + if isinstance(out, bytes): + out = out.decode('utf8') + return out class _EventPeristenceQueue(object): @@ -1058,7 +1064,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore metadata_json = encode_json( event.internal_metadata.get_dict() - ).decode("UTF-8") + ) sql = ( "UPDATE event_json SET internal_metadata = ?" @@ -1172,8 +1178,8 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore "room_id": event.room_id, "internal_metadata": encode_json( event.internal_metadata.get_dict() - ).decode("UTF-8"), - "json": encode_json(event_dict(event)).decode("UTF-8"), + ), + "json": encode_json(event_dict(event)), } for event, _ in events_and_contexts ], diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 470212aa2a..5623391f6e 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore): txn (cursor): event_id (str): Id for the Event. Returns: - A dict of algorithm -> hash. + A dict[unicode, bytes] of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/types.py b/synapse/types.py index 08f058f714..41afb27a74 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -137,7 +137,7 @@ class DomainSpecificString( @classmethod def from_string(cls, s): """Parse the string given by 's' into a structure object.""" - if len(s) < 1 or s[0] != cls.SIGIL: + if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError(400, "Expected %s string to start with '%s'" % ( cls.__name__, cls.SIGIL, )) diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 581c6052ac..014edea971 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import string_types +from six import binary_type, text_type from canonicaljson import json from frozendict import frozendict @@ -26,7 +26,7 @@ def freeze(o): if isinstance(o, frozendict): return o - if isinstance(o, string_types): + if isinstance(o, (binary_type, text_type)): return o try: @@ -41,7 +41,7 @@ def unfreeze(o): if isinstance(o, (dict, frozendict)): return dict({k: unfreeze(v) for k, v in o.items()}) - if isinstance(o, string_types): + if isinstance(o, (binary_type, text_type)): return o try: diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 5f158ec4b9..a82d737e71 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -46,7 +46,7 @@ class AuthTestCase(unittest.TestCase): self.auth = Auth(self.hs) self.test_user = "@foo:bar" - self.test_token = "_test_token_" + self.test_token = b"_test_token_" # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) @@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -98,7 +98,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -131,7 +131,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -141,7 +141,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) - request.args["access_token"] = [self.test_token] + request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -158,7 +158,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): - masquerading_user_id = "@doppelganger:matrix.org" + masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None, @@ -169,14 +169,17 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" - request.args["access_token"] = [self.test_token] - request.args["user_id"] = [masquerading_user_id] + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) - self.assertEquals(requester.user.to_string(), masquerading_user_id) + self.assertEquals( + requester.user.to_string(), + masquerading_user_id.decode('utf8') + ) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): - masquerading_user_id = "@doppelganger:matrix.org" + masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None, @@ -187,8 +190,8 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" - request.args["access_token"] = [self.test_token] - request.args["user_id"] = [masquerading_user_id] + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -418,7 +421,7 @@ class AuthTestCase(unittest.TestCase): # check the token works request = Mock(args={}) - request.args["access_token"] = [token] + request.args[b"access_token"] = [token.encode('ascii')] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request, allow_guest=True) self.assertEqual(UserID.from_string(USER_ID), requester.user) @@ -431,7 +434,7 @@ class AuthTestCase(unittest.TestCase): # the token should *not* work now request = Mock(args={}) - request.args["access_token"] = [guest_tok] + request.args[b"access_token"] = [guest_tok.encode('ascii')] request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(AuthError) as cm: diff --git a/tests/utils.py b/tests/utils.py index c3dbff8507..9bff3ff3b9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -193,7 +193,7 @@ class MockHttpResource(HttpServer): self.prefix = prefix def trigger_get(self, path): - return self.trigger("GET", path, None) + return self.trigger(b"GET", path, None) @patch('twisted.web.http.Request') @defer.inlineCallbacks @@ -227,7 +227,7 @@ class MockHttpResource(HttpServer): headers = {} if federation_auth: - headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="] + headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="] mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it @@ -241,6 +241,9 @@ class MockHttpResource(HttpServer): except Exception: pass + if isinstance(path, bytes): + path = path.decode('utf8') + for (method, pattern, func) in self.callbacks: if http_method != method: continue @@ -249,7 +252,7 @@ class MockHttpResource(HttpServer): if matcher: try: args = [ - urlparse.unquote(u).decode("UTF-8") + urlparse.unquote(u) for u in matcher.groups() ] -- cgit 1.5.1 From c82ccd3027dd9599ac0214adfd351996cf658681 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 1 Aug 2018 11:24:19 +0100 Subject: Factor out exception handling in federation_client Factor out the error handling from make_membership_event, send_join, and send_leave, so that it can be shared. --- synapse/federation/federation_client.py | 277 +++++++++++++++++--------------- 1 file changed, 148 insertions(+), 129 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 62d7ed13cf..baa9c3586f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t PDU_RETRY_TIME_MS = 1 * 60 * 1000 +class InvalidResponseError(RuntimeError): + """Helper for _try_destination_list: indicates that the server returned a response + we couldn't parse + """ + pass + + class FederationClient(FederationBase): def __init__(self, hs): super(FederationClient, self).__init__(hs) @@ -458,6 +465,61 @@ class FederationClient(FederationBase): defer.returnValue(signed_auth) @defer.inlineCallbacks + def _try_destination_list(self, description, destinations, callback): + """Try an operation on a series of servers, until it succeeds + + Args: + description (unicode): description of the operation we're doing, for logging + + destinations (Iterable[unicode]): list of server_names to try + + callback (callable): Function to run for each server. Passed a single + argument: the server_name to try. May return a deferred. + + If the callback raises a CodeMessageException with a 300/400 code, + attempts to perform the operation stop immediately and the exception is + reraised. + + Otherwise, if the callback raises an Exception the error is logged and the + next server tried. Normally the stacktrace is logged but this is + suppressed if the exception is an InvalidResponseError. + + Returns: + The [Deferred] result of callback, if it succeeds + + Raises: + CodeMessageException if the chosen remote server returns a 300/400 code. + + RuntimeError if no servers were reachable. + """ + for destination in destinations: + if destination == self.server_name: + continue + + try: + res = yield callback(destination) + defer.returnValue(res) + except InvalidResponseError as e: + logger.warn( + "Failed to %s via %s: %s", + description, destination, e, + ) + except CodeMessageException as e: + if not 500 <= e.code < 600: + raise + else: + logger.warn( + "Failed to %s via %s: %i %s", + description, destination, e.code, e.message, + ) + except Exception: + logger.warn( + "Failed to %s via %s", + description, destination, exc_info=1, + ) + + raise RuntimeError("Failed to %s via any server", description) + def make_membership_event(self, destinations, room_id, user_id, membership, content={},): """ @@ -492,50 +554,35 @@ class FederationClient(FederationBase): "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - for destination in destinations: - if destination == self.server_name: - continue - try: - ret = yield self.transport_layer.make_membership_event( - destination, room_id, user_id, membership - ) + @defer.inlineCallbacks + def send_request(destination): + ret = yield self.transport_layer.make_membership_event( + destination, room_id, user_id, membership + ) - pdu_dict = ret["event"] + pdu_dict = ret["event"] - logger.debug("Got response to make_%s: %s", membership, pdu_dict) + logger.debug("Got response to make_%s: %s", membership, pdu_dict) - pdu_dict["content"].update(content) + pdu_dict["content"].update(content) - # The protoevent received over the JSON wire may not have all - # the required fields. Lets just gloss over that because - # there's some we never care about - if "prev_state" not in pdu_dict: - pdu_dict["prev_state"] = [] + # The protoevent received over the JSON wire may not have all + # the required fields. Lets just gloss over that because + # there's some we never care about + if "prev_state" not in pdu_dict: + pdu_dict["prev_state"] = [] - ev = builder.EventBuilder(pdu_dict) + ev = builder.EventBuilder(pdu_dict) - defer.returnValue( - (destination, ev) - ) - break - except CodeMessageException as e: - if not 500 <= e.code < 600: - raise - else: - logger.warn( - "Failed to make_%s via %s: %s", - membership, destination, e.message - ) - except Exception as e: - logger.warn( - "Failed to make_%s via %s: %s", - membership, destination, e.message - ) + defer.returnValue( + (destination, ev) + ) - raise RuntimeError("Failed to send to any server.") + return self._try_destination_list( + "make_" + membership, destinations, send_request, + ) - @defer.inlineCallbacks def send_join(self, destinations, pdu): """Sends a join event to one of a list of homeservers. @@ -558,87 +605,70 @@ class FederationClient(FederationBase): Fails with a ``RuntimeError`` if no servers were reachable. """ - for destination in destinations: - if destination == self.server_name: - continue - - try: - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + @defer.inlineCallbacks + def send_request(destination): + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_join( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) - logger.debug("Got content: %s", content) + logger.debug("Got content: %s", content) - state = [ - event_from_pdu_json(p, outlier=True) - for p in content.get("state", []) - ] + state = [ + event_from_pdu_json(p, outlier=True) + for p in content.get("state", []) + ] - auth_chain = [ - event_from_pdu_json(p, outlier=True) - for p in content.get("auth_chain", []) - ] + auth_chain = [ + event_from_pdu_json(p, outlier=True) + for p in content.get("auth_chain", []) + ] - pdus = { - p.event_id: p - for p in itertools.chain(state, auth_chain) - } + pdus = { + p.event_id: p + for p in itertools.chain(state, auth_chain) + } - valid_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, list(pdus.values()), - outlier=True, - ) + valid_pdus = yield self._check_sigs_and_hash_and_fetch( + destination, list(pdus.values()), + outlier=True, + ) - valid_pdus_map = { - p.event_id: p - for p in valid_pdus - } - - # NB: We *need* to copy to ensure that we don't have multiple - # references being passed on, as that causes... issues. - signed_state = [ - copy.copy(valid_pdus_map[p.event_id]) - for p in state - if p.event_id in valid_pdus_map - ] + valid_pdus_map = { + p.event_id: p + for p in valid_pdus + } - signed_auth = [ - valid_pdus_map[p.event_id] - for p in auth_chain - if p.event_id in valid_pdus_map - ] + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + signed_state = [ + copy.copy(valid_pdus_map[p.event_id]) + for p in state + if p.event_id in valid_pdus_map + ] - # NB: We *need* to copy to ensure that we don't have multiple - # references being passed on, as that causes... issues. - for s in signed_state: - s.internal_metadata = copy.deepcopy(s.internal_metadata) + signed_auth = [ + valid_pdus_map[p.event_id] + for p in auth_chain + if p.event_id in valid_pdus_map + ] - auth_chain.sort(key=lambda e: e.depth) + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + for s in signed_state: + s.internal_metadata = copy.deepcopy(s.internal_metadata) - defer.returnValue({ - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - }) - except CodeMessageException as e: - if not 500 <= e.code < 600: - raise - else: - logger.exception( - "Failed to send_join via %s: %s", - destination, e.message - ) - except Exception as e: - logger.exception( - "Failed to send_join via %s: %s", - destination, e.message - ) + auth_chain.sort(key=lambda e: e.depth) - raise RuntimeError("Failed to send to any server.") + defer.returnValue({ + "state": signed_state, + "auth_chain": signed_auth, + "origin": destination, + }) + return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): @@ -663,7 +693,6 @@ class FederationClient(FederationBase): defer.returnValue(pdu) - @defer.inlineCallbacks def send_leave(self, destinations, pdu): """Sends a leave event to one of a list of homeservers. @@ -681,34 +710,24 @@ class FederationClient(FederationBase): Deferred: resolves to None. Fails with a ``CodeMessageException`` if the chosen remote server - returns a non-200 code. + returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. """ - for destination in destinations: - if destination == self.server_name: - continue - - try: - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_leave( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + @defer.inlineCallbacks + def send_request(destination): + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_leave( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) - logger.debug("Got content: %s", content) - defer.returnValue(None) - except CodeMessageException: - raise - except Exception as e: - logger.exception( - "Failed to send_leave via %s: %s", - destination, e.message - ) + logger.debug("Got content: %s", content) + defer.returnValue(None) - raise RuntimeError("Failed to send to any server.") + return self._try_destination_list("send_leave", destinations, send_request) def get_public_rooms(self, destination, limit=None, since_token=None, search_filter=None, include_all_networks=False, -- cgit 1.5.1 From fa7dc889f19f581e81624245ce7820525066eff3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 1 Aug 2018 13:47:07 +0100 Subject: Be more careful which errors we send back over the C-S API We really shouldn't be sending all CodeMessageExceptions back over the C-S API; it will include things like 401s which we shouldn't proxy. That means that we need to explicitly turn a few HttpResponseExceptions into SynapseErrors in the federation layer. The effect of the latter is that the matrix errcode will get passed through correctly to calling clients, which might help with some of the random M_UNKNOWN errors when trying to join rooms. --- synapse/api/errors.py | 11 ----------- synapse/federation/federation_client.py | 29 +++++++++++++++++------------ synapse/http/server.py | 14 +++++--------- 3 files changed, 22 insertions(+), 32 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 6074df292f..cf48829c8b 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -69,9 +69,6 @@ class CodeMessageException(RuntimeError): self.code = code self.msg = msg - def error_dict(self): - return cs_error(self.msg) - class MatrixCodeMessageException(CodeMessageException): """An error from a general matrix endpoint, eg. from a proxied Matrix API call. @@ -308,14 +305,6 @@ class LimitExceededError(SynapseError): ) -def cs_exception(exception): - if isinstance(exception, CodeMessageException): - return exception.error_dict() - else: - logger.error("Unknown exception type: %s", type(exception)) - return {} - - def cs_error(msg, code=Codes.UNKNOWN, **kwargs): """ Utility method for constructing an error response for client-server interactions. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index baa9c3586f..0b09f93ca9 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -488,7 +488,7 @@ class FederationClient(FederationBase): The [Deferred] result of callback, if it succeeds Raises: - CodeMessageException if the chosen remote server returns a 300/400 code. + SynapseError if the chosen remote server returns a 300/400 code. RuntimeError if no servers were reachable. """ @@ -504,9 +504,9 @@ class FederationClient(FederationBase): "Failed to %s via %s: %s", description, destination, e, ) - except CodeMessageException as e: + except HttpResponseException as e: if not 500 <= e.code < 600: - raise + raise SynapseError.from_http_response_exception(e) else: logger.warn( "Failed to %s via %s: %i %s", @@ -543,7 +543,7 @@ class FederationClient(FederationBase): Deferred: resolves to a tuple of (origin (str), event (object)) where origin is the remote homeserver which generated the event. - Fails with a ``CodeMessageException`` if the chosen remote server + Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. @@ -599,7 +599,7 @@ class FederationClient(FederationBase): giving the serer the event was sent to, ``state`` (?) and ``auth_chain``. - Fails with a ``CodeMessageException`` if the chosen remote server + Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. @@ -673,12 +673,17 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): time_now = self._clock.time_msec() - code, content = yield self.transport_layer.send_invite( - destination=destination, - room_id=room_id, - event_id=event_id, - content=pdu.get_pdu_json(time_now), - ) + try: + code, content = yield self.transport_layer.send_invite( + destination=destination, + room_id=room_id, + event_id=event_id, + content=pdu.get_pdu_json(time_now), + ) + except HttpResponseException as e: + if e.code == 403: + raise SynapseError.from_http_response_exception(e) + raise pdu_dict = content["event"] @@ -709,7 +714,7 @@ class FederationClient(FederationBase): Return: Deferred: resolves to None. - Fails with a ``CodeMessageException`` if the chosen remote server + Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. diff --git a/synapse/http/server.py b/synapse/http/server.py index 1940c1c4f4..6dacb31037 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -36,7 +36,6 @@ from synapse.api.errors import ( Codes, SynapseError, UnrecognizedRequestError, - cs_exception, ) from synapse.http.request_metrics import requests_counter from synapse.util.caches import intern_dict @@ -77,16 +76,13 @@ def wrap_json_request_handler(h): def wrapped_request_handler(self, request): try: yield h(self, request) - except CodeMessageException as e: + except SynapseError as e: code = e.code - if isinstance(e, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) - else: - logger.exception(e) + logger.info( + "%s SynapseError: %s - %s", request, code, e.msg + ) respond_with_json( - request, code, cs_exception(e), send_cors=True, + request, code, e.error_dict(), send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) -- cgit 1.5.1 From 018d75a148ced6945aca7b095a272e0edba5aae1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 1 Aug 2018 14:58:16 +0100 Subject: Refactor code for turning HttpResponseException into SynapseError This commit replaces SynapseError.from_http_response_exception with HttpResponseException.to_synapse_error. The new method actually returns a ProxiedRequestError, which allows us to pass through additional metadata from the API call. --- synapse/api/errors.py | 84 +++++++++++++++++++------------ synapse/federation/federation_client.py | 4 +- synapse/rest/media/v1/media_repository.py | 2 +- 3 files changed, 56 insertions(+), 34 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index cf48829c8b..7476c90ed3 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -105,38 +105,28 @@ class SynapseError(CodeMessageException): self.errcode, ) - @classmethod - def from_http_response_exception(cls, err): - """Make a SynapseError based on an HTTPResponseException - - This is useful when a proxied request has failed, and we need to - decide how to map the failure onto a matrix error to send back to the - client. - An attempt is made to parse the body of the http response as a matrix - error. If that succeeds, the errcode and error message from the body - are used as the errcode and error message in the new synapse error. - - Otherwise, the errcode is set to M_UNKNOWN, and the error message is - set to the reason code from the HTTP response. - - Args: - err (HttpResponseException): +class ProxiedRequestError(SynapseError): + """An error from a general matrix endpoint, eg. from a proxied Matrix API call. - Returns: - SynapseError: - """ - # try to parse the body as json, to get better errcode/msg, but - # default to M_UNKNOWN with the HTTP status as the error text - try: - j = json.loads(err.response) - except ValueError: - j = {} - errcode = j.get('errcode', Codes.UNKNOWN) - errmsg = j.get('error', err.msg) + Attributes: + errcode (str): Matrix error code e.g 'M_FORBIDDEN' + """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): + super(ProxiedRequestError, self).__init__( + code, msg, errcode + ) + if additional_fields is None: + self._additional_fields = {} + else: + self._additional_fields = dict(additional_fields) - res = SynapseError(err.code, errmsg, errcode) - return res + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + **self._additional_fields + ) class ConsentNotGivenError(SynapseError): @@ -361,7 +351,7 @@ class HttpResponseException(CodeMessageException): Represents an HTTP-level failure of an outbound request Attributes: - response (str): body of response + response (bytes): body of response """ def __init__(self, code, msg, response): """ @@ -369,7 +359,39 @@ class HttpResponseException(CodeMessageException): Args: code (int): HTTP status code msg (str): reason phrase from HTTP response status line - response (str): body of response + response (bytes): body of response """ super(HttpResponseException, self).__init__(code, msg) self.response = response + + def to_synapse_error(self): + """Make a SynapseError based on an HTTPResponseException + + This is useful when a proxied request has failed, and we need to + decide how to map the failure onto a matrix error to send back to the + client. + + An attempt is made to parse the body of the http response as a matrix + error. If that succeeds, the errcode and error message from the body + are used as the errcode and error message in the new synapse error. + + Otherwise, the errcode is set to M_UNKNOWN, and the error message is + set to the reason code from the HTTP response. + + Returns: + SynapseError: + """ + # try to parse the body as json, to get better errcode/msg, but + # default to M_UNKNOWN with the HTTP status as the error text + try: + j = json.loads(self.response) + except ValueError: + j = {} + + if not isinstance(j, dict): + j = {} + + errcode = j.pop('errcode', Codes.UNKNOWN) + errmsg = j.pop('error', self.msg) + + return ProxiedRequestError(self.code, errmsg, errcode, j) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 0b09f93ca9..7550e11b6e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -506,7 +506,7 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: if not 500 <= e.code < 600: - raise SynapseError.from_http_response_exception(e) + raise e.to_synapse_error() else: logger.warn( "Failed to %s via %s: %i %s", @@ -682,7 +682,7 @@ class FederationClient(FederationBase): ) except HttpResponseException as e: if e.code == 403: - raise SynapseError.from_http_response_exception(e) + raise e.to_synapse_error() raise pdu_dict = content["event"] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 174ad20123..8fb413d825 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -379,7 +379,7 @@ class MediaRepository(object): logger.warn("HTTP error fetching remote media %s/%s: %s", server_name, media_id, e.response) if e.code == twisted.web.http.NOT_FOUND: - raise SynapseError.from_http_response_exception(e) + raise e.to_synapse_error() raise SynapseError(502, "Failed to fetch remote media") except SynapseError: -- cgit 1.5.1 From 0a65450d044fb580d789013dcdac48b10c930761 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 2 Aug 2018 11:53:52 +0100 Subject: Validation for events/rooms in fed requests When we get a federation request which refers to an event id, make sure that said event is in the room the caller claims it is in. (patch supplied by @turt2live) --- synapse/federation/federation_server.py | 1 + synapse/handlers/federation.py | 35 ++++++++++++++++++++++++++++++++- synapse/storage/event_federation.py | 29 +++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 48f26db67c..10e71c78ce 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -425,6 +425,7 @@ class FederationServer(FederationBase): ret = yield self.handler.on_query_auth( origin, event_id, + room_id, signed_auth, content.get("rejects", []), content.get("missing", []), diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 20fb46fc89..12eeb7c4cd 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1349,6 +1349,9 @@ class FederationHandler(BaseHandler): def get_state_for_pdu(self, room_id, event_id): """Returns the state at the event. i.e. not including said event. """ + + yield self._verify_events_in_room([event_id], room_id) + state_groups = yield self.store.get_state_groups( room_id, [event_id] ) @@ -1391,6 +1394,9 @@ class FederationHandler(BaseHandler): def get_state_ids_for_pdu(self, room_id, event_id): """Returns the state at the event. i.e. not including said event. """ + + yield self._verify_events_in_room([event_id], room_id) + state_groups = yield self.store.get_state_groups_ids( room_id, [event_id] ) @@ -1420,6 +1426,8 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") + yield self._verify_events_in_room(pdu_list, room_id) + events = yield self.store.get_backfill_events( room_id, pdu_list, @@ -1706,8 +1714,17 @@ class FederationHandler(BaseHandler): defer.returnValue(context) @defer.inlineCallbacks - def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, + def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects, missing): + in_room = yield self.auth.check_host_in_room( + room_id, + origin + ) + if not in_room: + raise AuthError(403, "Host not in room.") + + yield self._verify_events_in_room([event_id], room_id) + # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. for e in remote_auth_chain: @@ -2368,3 +2385,19 @@ class FederationHandler(BaseHandler): ) if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") + + @defer.inlineCallbacks + def _verify_events_in_room(self, pdu_ids, room_id): + """Checks whether the given PDU IDs are in the given room or not. + + Args: + pdu_ids (list): list of PDU IDs + room_id (str): the room ID that the PDUs should be in + + Raises: + AuthError: if one or more of the PDUs does not belong to the + given room. + """ + room_ids = yield self.store.get_room_ids_for_events(pdu_ids) + if len(room_ids) != 1 or room_ids[0] != room_id: + raise AuthError(403, "Events must belong to the given room") diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 8d366d1b91..e860fe1a1e 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -295,6 +295,35 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, get_forward_extremeties_for_room_txn ) + def get_room_ids_for_events(self, event_ids): + """Get a list of room IDs for which the given events belong. + + Args: + event_ids (list): the events to look up the room of + + Returns: + list, the room IDs for the events + """ + return self.runInteraction( + "get_room_ids_for_events", + self._get_room_ids_for_events, event_ids + ) + + def _get_room_ids_for_events(self, txn, event_ids): + logger.debug("_get_room_ids_for_events: %s", repr(event_ids)) + + base_sql = ( + "SELECT DISTINCT room_id FROM events" + " WHERE event_id IN (%s)" + ) + + txn.execute( + base_sql % (",".join(["?"] * len(event_ids)),), + event_ids + ) + + return [r[0] for r in txn] + def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` -- cgit 1.5.1