diff options
author | Erik Johnston <erik@matrix.org> | 2018-08-03 09:25:15 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2018-08-03 09:25:15 +0100 |
commit | cb298ff623c1ad375084f7687b7f6e546c4c1c1f (patch) | |
tree | 81854fc4ae0a45bed487a58f2c40974b683adb92 /synapse | |
parent | Newsfile (diff) | |
parent | Merge pull request #3645 from matrix-org/michaelkaye/mention_newsfragment (diff) | |
download | synapse-cb298ff623c1ad375084f7687b7f6e546c4c1c1f.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/refactor_repl_servlet
Diffstat (limited to 'synapse')
40 files changed, 707 insertions, 601 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 5c0f2f83aa..1810cb6fcd 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -17,4 +17,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.33.0" +__version__ = "0.33.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 073229b4c4..5bbbe8e2e7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -252,10 +252,10 @@ class Auth(object): if ip_address not in app_service.ip_range_whitelist: defer.returnValue((None, None)) - if "user_id" not in request.args: + if b"user_id" not in request.args: defer.returnValue((app_service.sender, app_service)) - user_id = request.args["user_id"][0] + user_id = request.args[b"user_id"][0].decode('utf8') if app_service.sender == user_id: defer.returnValue((app_service.sender, app_service)) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 6074df292f..b41d595059 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -55,6 +55,7 @@ class Codes(object): SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" + MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED" class CodeMessageException(RuntimeError): @@ -69,20 +70,6 @@ class CodeMessageException(RuntimeError): self.code = code self.msg = msg - def error_dict(self): - return cs_error(self.msg) - - -class MatrixCodeMessageException(CodeMessageException): - """An error from a general matrix endpoint, eg. from a proxied Matrix API call. - - Attributes: - errcode (str): Matrix error code e.g 'M_FORBIDDEN' - """ - def __init__(self, code, msg, errcode=Codes.UNKNOWN): - super(MatrixCodeMessageException, self).__init__(code, msg) - self.errcode = errcode - class SynapseError(CodeMessageException): """A base exception type for matrix errors which have an errcode and error @@ -108,38 +95,28 @@ class SynapseError(CodeMessageException): self.errcode, ) - @classmethod - def from_http_response_exception(cls, err): - """Make a SynapseError based on an HTTPResponseException - - This is useful when a proxied request has failed, and we need to - decide how to map the failure onto a matrix error to send back to the - client. - - An attempt is made to parse the body of the http response as a matrix - error. If that succeeds, the errcode and error message from the body - are used as the errcode and error message in the new synapse error. - - Otherwise, the errcode is set to M_UNKNOWN, and the error message is - set to the reason code from the HTTP response. - Args: - err (HttpResponseException): +class ProxiedRequestError(SynapseError): + """An error from a general matrix endpoint, eg. from a proxied Matrix API call. - Returns: - SynapseError: - """ - # try to parse the body as json, to get better errcode/msg, but - # default to M_UNKNOWN with the HTTP status as the error text - try: - j = json.loads(err.response) - except ValueError: - j = {} - errcode = j.get('errcode', Codes.UNKNOWN) - errmsg = j.get('error', err.msg) + Attributes: + errcode (str): Matrix error code e.g 'M_FORBIDDEN' + """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): + super(ProxiedRequestError, self).__init__( + code, msg, errcode + ) + if additional_fields is None: + self._additional_fields = {} + else: + self._additional_fields = dict(additional_fields) - res = SynapseError(err.code, errmsg, errcode) - return res + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + **self._additional_fields + ) class ConsentNotGivenError(SynapseError): @@ -308,14 +285,6 @@ class LimitExceededError(SynapseError): ) -def cs_exception(exception): - if isinstance(exception, CodeMessageException): - return exception.error_dict() - else: - logger.error("Unknown exception type: %s", type(exception)) - return {} - - def cs_error(msg, code=Codes.UNKNOWN, **kwargs): """ Utility method for constructing an error response for client-server interactions. @@ -372,7 +341,7 @@ class HttpResponseException(CodeMessageException): Represents an HTTP-level failure of an outbound request Attributes: - response (str): body of response + response (bytes): body of response """ def __init__(self, code, msg, response): """ @@ -380,7 +349,39 @@ class HttpResponseException(CodeMessageException): Args: code (int): HTTP status code msg (str): reason phrase from HTTP response status line - response (str): body of response + response (bytes): body of response """ super(HttpResponseException, self).__init__(code, msg) self.response = response + + def to_synapse_error(self): + """Make a SynapseError based on an HTTPResponseException + + This is useful when a proxied request has failed, and we need to + decide how to map the failure onto a matrix error to send back to the + client. + + An attempt is made to parse the body of the http response as a matrix + error. If that succeeds, the errcode and error message from the body + are used as the errcode and error message in the new synapse error. + + Otherwise, the errcode is set to M_UNKNOWN, and the error message is + set to the reason code from the HTTP response. + + Returns: + SynapseError: + """ + # try to parse the body as json, to get better errcode/msg, but + # default to M_UNKNOWN with the HTTP status as the error text + try: + j = json.loads(self.response) + except ValueError: + j = {} + + if not isinstance(j, dict): + j = {} + + errcode = j.pop('errcode', Codes.UNKNOWN) + errmsg = j.pop('error', self.msg) + + return ProxiedRequestError(self.code, errmsg, errcode, j) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 57b815d777..fba51c26e8 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -20,6 +20,8 @@ import sys from six import iteritems +from prometheus_client import Gauge + from twisted.application import service from twisted.internet import defer, reactor from twisted.web.resource import EncodingResourceWrapper, NoResource @@ -300,6 +302,11 @@ class SynapseHomeServer(HomeServer): quit_with_error(e.message) +# Gauges to expose monthly active user control metrics +current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU") +max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit") + + def setup(config_options): """ Args: @@ -512,6 +519,18 @@ def run(hs): # table will decrease clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) + @defer.inlineCallbacks + def generate_monthly_active_users(): + count = 0 + if hs.config.limit_usage_by_mau: + count = yield hs.get_datastore().count_monthly_users() + current_mau_gauge.set(float(count)) + max_mau_value_gauge.set(float(hs.config.max_mau_value)) + + generate_monthly_active_users() + if hs.config.limit_usage_by_mau: + clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) + if hs.config.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) diff --git a/synapse/config/server.py b/synapse/config/server.py index 18102656b0..6a471a0a5e 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -67,6 +67,14 @@ class ServerConfig(Config): "block_non_admin_invites", False, ) + # Options to control access by tracking MAU + self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) + if self.limit_usage_by_mau: + self.max_mau_value = config.get( + "max_mau_value", 0, + ) + else: + self.max_mau_value = 0 # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( @@ -209,6 +217,8 @@ class ServerConfig(Config): # different cores. See # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/. # + # This setting requires the affinity package to be installed! + # # cpu_affinity: 0xFFFFFFFF # Whether to serve a web client from the HTTP/HTTPS root resource. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 62d7ed13cf..7550e11b6e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t PDU_RETRY_TIME_MS = 1 * 60 * 1000 +class InvalidResponseError(RuntimeError): + """Helper for _try_destination_list: indicates that the server returned a response + we couldn't parse + """ + pass + + class FederationClient(FederationBase): def __init__(self, hs): super(FederationClient, self).__init__(hs) @@ -458,6 +465,61 @@ class FederationClient(FederationBase): defer.returnValue(signed_auth) @defer.inlineCallbacks + def _try_destination_list(self, description, destinations, callback): + """Try an operation on a series of servers, until it succeeds + + Args: + description (unicode): description of the operation we're doing, for logging + + destinations (Iterable[unicode]): list of server_names to try + + callback (callable): Function to run for each server. Passed a single + argument: the server_name to try. May return a deferred. + + If the callback raises a CodeMessageException with a 300/400 code, + attempts to perform the operation stop immediately and the exception is + reraised. + + Otherwise, if the callback raises an Exception the error is logged and the + next server tried. Normally the stacktrace is logged but this is + suppressed if the exception is an InvalidResponseError. + + Returns: + The [Deferred] result of callback, if it succeeds + + Raises: + SynapseError if the chosen remote server returns a 300/400 code. + + RuntimeError if no servers were reachable. + """ + for destination in destinations: + if destination == self.server_name: + continue + + try: + res = yield callback(destination) + defer.returnValue(res) + except InvalidResponseError as e: + logger.warn( + "Failed to %s via %s: %s", + description, destination, e, + ) + except HttpResponseException as e: + if not 500 <= e.code < 600: + raise e.to_synapse_error() + else: + logger.warn( + "Failed to %s via %s: %i %s", + description, destination, e.code, e.message, + ) + except Exception: + logger.warn( + "Failed to %s via %s", + description, destination, exc_info=1, + ) + + raise RuntimeError("Failed to %s via any server", description) + def make_membership_event(self, destinations, room_id, user_id, membership, content={},): """ @@ -481,7 +543,7 @@ class FederationClient(FederationBase): Deferred: resolves to a tuple of (origin (str), event (object)) where origin is the remote homeserver which generated the event. - Fails with a ``CodeMessageException`` if the chosen remote server + Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. @@ -492,50 +554,35 @@ class FederationClient(FederationBase): "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - for destination in destinations: - if destination == self.server_name: - continue - try: - ret = yield self.transport_layer.make_membership_event( - destination, room_id, user_id, membership - ) + @defer.inlineCallbacks + def send_request(destination): + ret = yield self.transport_layer.make_membership_event( + destination, room_id, user_id, membership + ) - pdu_dict = ret["event"] + pdu_dict = ret["event"] - logger.debug("Got response to make_%s: %s", membership, pdu_dict) + logger.debug("Got response to make_%s: %s", membership, pdu_dict) - pdu_dict["content"].update(content) + pdu_dict["content"].update(content) - # The protoevent received over the JSON wire may not have all - # the required fields. Lets just gloss over that because - # there's some we never care about - if "prev_state" not in pdu_dict: - pdu_dict["prev_state"] = [] + # The protoevent received over the JSON wire may not have all + # the required fields. Lets just gloss over that because + # there's some we never care about + if "prev_state" not in pdu_dict: + pdu_dict["prev_state"] = [] - ev = builder.EventBuilder(pdu_dict) + ev = builder.EventBuilder(pdu_dict) - defer.returnValue( - (destination, ev) - ) - break - except CodeMessageException as e: - if not 500 <= e.code < 600: - raise - else: - logger.warn( - "Failed to make_%s via %s: %s", - membership, destination, e.message - ) - except Exception as e: - logger.warn( - "Failed to make_%s via %s: %s", - membership, destination, e.message - ) + defer.returnValue( + (destination, ev) + ) - raise RuntimeError("Failed to send to any server.") + return self._try_destination_list( + "make_" + membership, destinations, send_request, + ) - @defer.inlineCallbacks def send_join(self, destinations, pdu): """Sends a join event to one of a list of homeservers. @@ -552,103 +599,91 @@ class FederationClient(FederationBase): giving the serer the event was sent to, ``state`` (?) and ``auth_chain``. - Fails with a ``CodeMessageException`` if the chosen remote server + Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. """ - for destination in destinations: - if destination == self.server_name: - continue - - try: - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + @defer.inlineCallbacks + def send_request(destination): + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_join( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) - logger.debug("Got content: %s", content) + logger.debug("Got content: %s", content) - state = [ - event_from_pdu_json(p, outlier=True) - for p in content.get("state", []) - ] + state = [ + event_from_pdu_json(p, outlier=True) + for p in content.get("state", []) + ] - auth_chain = [ - event_from_pdu_json(p, outlier=True) - for p in content.get("auth_chain", []) - ] + auth_chain = [ + event_from_pdu_json(p, outlier=True) + for p in content.get("auth_chain", []) + ] - pdus = { - p.event_id: p - for p in itertools.chain(state, auth_chain) - } + pdus = { + p.event_id: p + for p in itertools.chain(state, auth_chain) + } - valid_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, list(pdus.values()), - outlier=True, - ) + valid_pdus = yield self._check_sigs_and_hash_and_fetch( + destination, list(pdus.values()), + outlier=True, + ) - valid_pdus_map = { - p.event_id: p - for p in valid_pdus - } - - # NB: We *need* to copy to ensure that we don't have multiple - # references being passed on, as that causes... issues. - signed_state = [ - copy.copy(valid_pdus_map[p.event_id]) - for p in state - if p.event_id in valid_pdus_map - ] + valid_pdus_map = { + p.event_id: p + for p in valid_pdus + } - signed_auth = [ - valid_pdus_map[p.event_id] - for p in auth_chain - if p.event_id in valid_pdus_map - ] + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + signed_state = [ + copy.copy(valid_pdus_map[p.event_id]) + for p in state + if p.event_id in valid_pdus_map + ] - # NB: We *need* to copy to ensure that we don't have multiple - # references being passed on, as that causes... issues. - for s in signed_state: - s.internal_metadata = copy.deepcopy(s.internal_metadata) + signed_auth = [ + valid_pdus_map[p.event_id] + for p in auth_chain + if p.event_id in valid_pdus_map + ] - auth_chain.sort(key=lambda e: e.depth) + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + for s in signed_state: + s.internal_metadata = copy.deepcopy(s.internal_metadata) - defer.returnValue({ - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - }) - except CodeMessageException as e: - if not 500 <= e.code < 600: - raise - else: - logger.exception( - "Failed to send_join via %s: %s", - destination, e.message - ) - except Exception as e: - logger.exception( - "Failed to send_join via %s: %s", - destination, e.message - ) + auth_chain.sort(key=lambda e: e.depth) - raise RuntimeError("Failed to send to any server.") + defer.returnValue({ + "state": signed_state, + "auth_chain": signed_auth, + "origin": destination, + }) + return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): time_now = self._clock.time_msec() - code, content = yield self.transport_layer.send_invite( - destination=destination, - room_id=room_id, - event_id=event_id, - content=pdu.get_pdu_json(time_now), - ) + try: + code, content = yield self.transport_layer.send_invite( + destination=destination, + room_id=room_id, + event_id=event_id, + content=pdu.get_pdu_json(time_now), + ) + except HttpResponseException as e: + if e.code == 403: + raise e.to_synapse_error() + raise pdu_dict = content["event"] @@ -663,7 +698,6 @@ class FederationClient(FederationBase): defer.returnValue(pdu) - @defer.inlineCallbacks def send_leave(self, destinations, pdu): """Sends a leave event to one of a list of homeservers. @@ -680,35 +714,25 @@ class FederationClient(FederationBase): Return: Deferred: resolves to None. - Fails with a ``CodeMessageException`` if the chosen remote server - returns a non-200 code. + Fails with a ``SynapseError`` if the chosen remote server + returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. """ - for destination in destinations: - if destination == self.server_name: - continue - - try: - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_leave( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + @defer.inlineCallbacks + def send_request(destination): + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_leave( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) - logger.debug("Got content: %s", content) - defer.returnValue(None) - except CodeMessageException: - raise - except Exception as e: - logger.exception( - "Failed to send_leave via %s: %s", - destination, e.message - ) + logger.debug("Got content: %s", content) + defer.returnValue(None) - raise RuntimeError("Failed to send to any server.") + return self._try_destination_list("send_leave", destinations, send_request) def get_public_rooms(self, destination, limit=None, since_token=None, search_filter=None, include_all_networks=False, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e501251b6e..bf89d568af 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -207,10 +207,6 @@ class FederationServer(FederationBase): edu.content ) - pdu_failures = getattr(transaction, "pdu_failures", []) - for fail in pdu_failures: - logger.info("Got failure %r", fail) - response = { "pdus": pdu_results, } @@ -430,6 +426,7 @@ class FederationServer(FederationBase): ret = yield self.handler.on_query_auth( origin, event_id, + room_id, signed_auth, content.get("rejects", []), content.get("missing", []), diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 5157c3860d..0bb468385d 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object): self.edus = SortedDict() # stream position -> Edu - self.failures = SortedDict() # stream position -> (destination, Failure) - self.device_messages = SortedDict() # stream position -> destination self.pos = 1 @@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object): for queue_name in [ "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "failures", "device_messages", "pos_time", + "edus", "device_messages", "pos_time", ]: register(queue_name, getattr(self, queue_name)) @@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object): for key in keys[:i]: del self.edus[key] - # Delete things out of failure map - keys = self.failures.keys() - i = self.failures.bisect_left(position_to_delete) - for key in keys[:i]: - del self.failures[key] - # Delete things out of device map keys = self.device_messages.keys() i = self.device_messages.bisect_left(position_to_delete) @@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object): self.notifier.on_new_replication_data() - def send_failure(self, failure, destination): - """As per TransactionQueue""" - pos = self._next_pos() - - self.failures[pos] = (destination, str(failure)) - self.notifier.on_new_replication_data() - def send_device_messages(self, destination): """As per TransactionQueue""" pos = self._next_pos() @@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object): for (pos, edu) in edus: rows.append((pos, EduRow(edu))) - # Fetch changed failures - i = self.failures.bisect_right(from_token) - j = self.failures.bisect_right(to_token) + 1 - failures = self.failures.items()[i:j] - - for (pos, (destination, failure)) in failures: - rows.append((pos, FailureRow( - destination=destination, - failure=failure, - ))) - # Fetch changed device messages i = self.device_messages.bisect_right(from_token) j = self.device_messages.bisect_right(to_token) + 1 @@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ( buff.edus.setdefault(self.edu.destination, []).append(self.edu) -class FailureRow(BaseFederationRow, namedtuple("FailureRow", ( - "destination", # str - "failure", -))): - """Streams failures to a remote server. Failures are issued when there was - something wrong with a transaction the remote sent us, e.g. it included - an event that was invalid. - """ - - TypeId = "f" - - @staticmethod - def from_data(data): - return FailureRow( - destination=data["destination"], - failure=data["failure"], - ) - - def to_data(self): - return { - "destination": self.destination, - "failure": self.failure, - } - - def add_to_buffer(self, buff): - buff.failures.setdefault(self.destination, []).append(self.failure) - - class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( "destination", # str ))): @@ -471,7 +417,6 @@ TypeToRow = { PresenceRow, KeyedEduRow, EduRow, - FailureRow, DeviceRow, ) } @@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( "presence", # list(UserPresenceState) "keyed_edus", # dict of destination -> { key -> Edu } "edus", # dict of destination -> [Edu] - "failures", # dict of destination -> [failures] "device_destinations", # set of destinations )) @@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows): presence=[], keyed_edus={}, edus={}, - failures={}, device_destinations=set(), ) @@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows): edu.destination, edu.edu_type, edu.content, key=None, ) - for destination, failure_list in iteritems(buff.failures): - for failure in failure_list: - transaction_queue.send_failure(destination, failure) - for destination in buff.device_destinations: transaction_queue.send_device_messages(destination) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 6996d6b695..78f9d40a3a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -116,9 +116,6 @@ class TransactionQueue(object): ), ) - # destination -> list of tuple(failure, deferred) - self.pending_failures_by_dest = {} - # destination -> stream_id of last successfully sent to-device message. # NB: may be a long or an int. self.last_device_stream_id_by_dest = {} @@ -382,19 +379,6 @@ class TransactionQueue(object): self._attempt_new_transaction(destination) - def send_failure(self, failure, destination): - if destination == self.server_name or destination == "localhost": - return - - if not self.can_send_to(destination): - return - - self.pending_failures_by_dest.setdefault( - destination, [] - ).append(failure) - - self._attempt_new_transaction(destination) - def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": return @@ -469,7 +453,6 @@ class TransactionQueue(object): pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_presence = self.pending_presence_by_dest.pop(destination, {}) - pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_edus.extend( self.pending_edus_keyed_by_dest.pop(destination, {}).values() @@ -497,7 +480,7 @@ class TransactionQueue(object): logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", destination, len(pending_pdus)) - if not pending_pdus and not pending_edus and not pending_failures: + if not pending_pdus and not pending_edus: logger.debug("TX [%s] Nothing to send", destination) self.last_device_stream_id_by_dest[destination] = ( device_stream_id @@ -507,7 +490,7 @@ class TransactionQueue(object): # END CRITICAL SECTION success = yield self._send_new_transaction( - destination, pending_pdus, pending_edus, pending_failures, + destination, pending_pdus, pending_edus, ) if success: sent_transactions_counter.inc() @@ -584,14 +567,12 @@ class TransactionQueue(object): @measure_func("_send_new_transaction") @defer.inlineCallbacks - def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures): + def _send_new_transaction(self, destination, pending_pdus, pending_edus): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) pdus = [x[0] for x in pending_pdus] edus = pending_edus - failures = [x.get_dict() for x in pending_failures] success = True @@ -601,11 +582,10 @@ class TransactionQueue(object): logger.debug( "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", + " (pdus: %d, edus: %d)", destination, txn_id, len(pdus), len(edus), - len(failures) ) logger.debug("TX [%s] Persisting transaction...", destination) @@ -617,7 +597,6 @@ class TransactionQueue(object): destination=destination, pdus=pdus, edus=edus, - pdu_failures=failures, ) self._next_txn_id += 1 @@ -627,12 +606,11 @@ class TransactionQueue(object): logger.debug("TX [%s] Persisted transaction", destination) logger.info( "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d, failures: %d)", + " (PDUs: %d, EDUs: %d)", destination, txn_id, transaction.transaction_id, len(pdus), len(edus), - len(failures), ) # Actually send the transaction diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8574898f0c..eae5f2b427 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes): param_dict = dict(kv.split("=") for kv in params) def strip_quotes(value): - if value.startswith(b"\""): + if value.startswith("\""): return value[1:-1] else: return value @@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet): ) logger.info( - "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)", + "Received txn %s from %s. (PDUs: %d, EDUs: %d)", transaction_id, origin, len(transaction_data.get("pdus", [])), len(transaction_data.get("edus", [])), - len(transaction_data.get("failures", [])), ) # We should ideally be getting this from the security layer. diff --git a/synapse/federation/units.py b/synapse/federation/units.py index bb1b3b13f7..c5ab14314e 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject): "previous_ids", "pdus", "edus", - "pdu_failures", ] internal_keys = [ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 402e44cdef..184eef09d0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import unicodedata import attr import bcrypt @@ -519,6 +520,7 @@ class AuthHandler(BaseHandler): """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) + yield self._check_mau_limits() # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -626,6 +628,7 @@ class AuthHandler(BaseHandler): # special case to check for "password" for the check_password interface # for the auth providers password = login_submission.get("password") + if login_type == LoginType.PASSWORD: if not self._password_enabled: raise SynapseError(400, "Password login has been disabled.") @@ -707,9 +710,10 @@ class AuthHandler(BaseHandler): multiple inexact matches. Args: - user_id (str): complete @user:id + user_id (unicode): complete @user:id + password (unicode): the provided password Returns: - (str) the canonical_user_id, or None if unknown user / bad password + (unicode) the canonical_user_id, or None if unknown user / bad password """ lookupres = yield self._find_user_id_and_pwd_hash(user_id) if not lookupres: @@ -728,15 +732,18 @@ class AuthHandler(BaseHandler): device_id) defer.returnValue(access_token) + @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): + yield self._check_mau_limits() auth_api = self.hs.get_auth() + user_id = None try: macaroon = pymacaroons.Macaroon.deserialize(login_token) user_id = auth_api.get_user_id_from_macaroon(macaroon) auth_api.validate_macaroon(macaroon, "login", True, user_id) - return user_id except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) + defer.returnValue(user_id) @defer.inlineCallbacks def delete_access_token(self, access_token): @@ -849,14 +856,19 @@ class AuthHandler(BaseHandler): """Computes a secure hash of password. Args: - password (str): Password to hash. + password (unicode): Password to hash. Returns: - Deferred(str): Hashed password. + Deferred(unicode): Hashed password. """ def _do_hash(): - return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, - bcrypt.gensalt(self.bcrypt_rounds)) + # Normalise the Unicode in the password + pw = unicodedata.normalize("NFKC", password) + + return bcrypt.hashpw( + pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), + bcrypt.gensalt(self.bcrypt_rounds), + ).decode('ascii') return make_deferred_yieldable( threads.deferToThreadPool( @@ -868,16 +880,19 @@ class AuthHandler(BaseHandler): """Validates that self.hash(password) == stored_hash. Args: - password (str): Password to hash. - stored_hash (str): Expected hash value. + password (unicode): Password to hash. + stored_hash (unicode): Expected hash value. Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ def _do_validate_hash(): + # Normalise the Unicode in the password + pw = unicodedata.normalize("NFKC", password) + return bcrypt.checkpw( - password.encode('utf8') + self.hs.config.password_pepper, + pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), stored_hash.encode('utf8') ) @@ -892,6 +907,19 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) + @defer.inlineCallbacks + def _check_mau_limits(self): + """ + Ensure that if mau blocking is enabled that invalid users cannot + log in. + """ + if self.hs.config.limit_usage_by_mau is True: + current_mau = yield self.store.count_monthly_users() + if current_mau >= self.hs.config.max_mau_value: + raise AuthError( + 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) + @attr.s class MacaroonGenerator(object): diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index c3f2d7feff..f772e62c28 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,10 +19,12 @@ import random from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import AuthError from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.types import UserID from synapse.util.logutils import log_function +from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -129,11 +131,13 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): @defer.inlineCallbacks - def get_event(self, user, event_id): + def get_event(self, user, room_id, event_id): """Retrieve a single specified event. Args: user (synapse.types.UserID): The user requesting the event + room_id (str|None): The expected room id. We'll return None if the + event's room does not match. event_id (str): The event ID to obtain. Returns: dict: An event, or None if there is no event matching this ID. @@ -142,13 +146,26 @@ class EventHandler(BaseHandler): AuthError if the user does not have the rights to inspect this event. """ - event = yield self.store.get_event(event_id) + event = yield self.store.get_event(event_id, check_room_id=room_id) if not event: defer.returnValue(None) return - if hasattr(event, "room_id"): - yield self.auth.check_joined_room(event.room_id, user.to_string()) + users = yield self.store.get_users_in_room(event.room_id) + is_peeking = user.to_string() not in users + + filtered = yield filter_events_for_client( + self.store, + user.to_string(), + [event], + is_peeking=is_peeking + ) + + if not filtered: + raise AuthError( + 403, + "You don't have permission to access that event." + ) defer.returnValue(event) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 49068c06d9..533b82c783 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -76,7 +76,7 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() - self.replication_layer = hs.get_federation_client() + self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() @@ -255,7 +255,7 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: state, got_auth_chain = ( - yield self.replication_layer.get_state_for_room( + yield self.federation_client.get_state_for_room( origin, pdu.room_id, p ) ) @@ -338,7 +338,7 @@ class FederationHandler(BaseHandler): # # see https://github.com/matrix-org/synapse/pull/1744 - missing_events = yield self.replication_layer.get_missing_events( + missing_events = yield self.federation_client.get_missing_events( origin, pdu.room_id, earliest_events_ids=list(latest), @@ -400,7 +400,7 @@ class FederationHandler(BaseHandler): ) try: - event_stream_id, max_stream_id = yield self._persist_auth_tree( + yield self._persist_auth_tree( origin, auth_chain, state, event ) except AuthError as e: @@ -444,7 +444,7 @@ class FederationHandler(BaseHandler): yield self._handle_new_events(origin, event_infos) try: - context, event_stream_id, max_stream_id = yield self._handle_new_event( + context = yield self._handle_new_event( origin, event, state=state, @@ -469,17 +469,6 @@ class FederationHandler(BaseHandler): except StoreError: logger.exception("Failed to store room.") - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) - if event.type == EventTypes.Member: if event.membership == Membership.JOIN: # Only fire user_joined_room if the user has acutally @@ -501,7 +490,7 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield user_joined_room(self.distributor, user, event.room_id) + yield self.user_joined_room(user, event.room_id) @log_function @defer.inlineCallbacks @@ -522,7 +511,7 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - events = yield self.replication_layer.backfill( + events = yield self.federation_client.backfill( dest, room_id, limit=limit, @@ -570,7 +559,7 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self.replication_layer.get_state_for_room( + state, auth = yield self.federation_client.get_state_for_room( destination=dest, room_id=room_id, event_id=e_id @@ -612,7 +601,7 @@ class FederationHandler(BaseHandler): results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ logcontext.run_in_background( - self.replication_layer.get_pdu, + self.federation_client.get_pdu, [dest], event_id, outlier=True, @@ -893,7 +882,7 @@ class FederationHandler(BaseHandler): Invites must be signed by the invitee's server before distribution. """ - pdu = yield self.replication_layer.send_invite( + pdu = yield self.federation_client.send_invite( destination=target_host, room_id=event.room_id, event_id=event.event_id, @@ -942,7 +931,7 @@ class FederationHandler(BaseHandler): self.room_queues[room_id] = [] - yield self.store.clean_room_for_join(room_id) + yield self._clean_room_for_join(room_id) handled_events = set() @@ -955,7 +944,7 @@ class FederationHandler(BaseHandler): target_hosts.insert(0, origin) except ValueError: pass - ret = yield self.replication_layer.send_join(target_hosts, event) + ret = yield self.federation_client.send_join(target_hosts, event) origin = ret["origin"] state = ret["state"] @@ -981,15 +970,10 @@ class FederationHandler(BaseHandler): # FIXME pass - event_stream_id, max_stream_id = yield self._persist_auth_tree( + yield self._persist_auth_tree( origin, auth_chain, state, event ) - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[joinee] - ) - logger.debug("Finished joining %s to %s", joinee, room_id) finally: room_queue = self.room_queues[room_id] @@ -1084,7 +1068,7 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin - context, event_stream_id, max_stream_id = yield self._handle_new_event( + context = yield self._handle_new_event( origin, event ) @@ -1094,20 +1078,10 @@ class FederationHandler(BaseHandler): event.signatures, ) - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) - if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: user = UserID.from_string(event.state_key) - yield user_joined_room(self.distributor, user, event.room_id) + yield self.user_joined_room(user, event.room_id) prev_state_ids = yield context.get_prev_state_ids(self.store) @@ -1176,17 +1150,7 @@ class FederationHandler(BaseHandler): ) context = yield self.state_handler.compute_event_context(event) - - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, - ) - - target_user = UserID.from_string(event.state_key) - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[target_user], - ) + yield self._persist_events([(event, context)]) defer.returnValue(event) @@ -1211,30 +1175,20 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.replication_layer.send_leave( + yield self.federation_client.send_leave( target_hosts, event ) context = yield self.state_handler.compute_event_context(event) - - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, - ) - - target_user = UserID.from_string(event.state_key) - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[target_user], - ) + yield self._persist_events([(event, context)]) defer.returnValue(event) @defer.inlineCallbacks def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content={},): - origin, pdu = yield self.replication_layer.make_membership_event( + origin, pdu = yield self.federation_client.make_membership_event( target_hosts, room_id, user_id, @@ -1318,7 +1272,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context, event_stream_id, max_stream_id = yield self._handle_new_event( + yield self._handle_new_event( origin, event ) @@ -1328,22 +1282,17 @@ class FederationHandler(BaseHandler): event.signatures, ) - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) - defer.returnValue(None) @defer.inlineCallbacks def get_state_for_pdu(self, room_id, event_id): """Returns the state at the event. i.e. not including said event. """ + + event = yield self.store.get_event( + event_id, allow_none=False, check_room_id=room_id, + ) + state_groups = yield self.store.get_state_groups( room_id, [event_id] ) @@ -1354,8 +1303,7 @@ class FederationHandler(BaseHandler): (e.type, e.state_key): e for e in state } - event = yield self.store.get_event(event_id) - if event and event.is_state(): + if event.is_state(): # Get previous state if "replaces_state" in event.unsigned: prev_id = event.unsigned["replaces_state"] @@ -1374,6 +1322,10 @@ class FederationHandler(BaseHandler): def get_state_ids_for_pdu(self, room_id, event_id): """Returns the state at the event. i.e. not including said event. """ + event = yield self.store.get_event( + event_id, allow_none=False, check_room_id=room_id, + ) + state_groups = yield self.store.get_state_groups_ids( room_id, [event_id] ) @@ -1382,8 +1334,7 @@ class FederationHandler(BaseHandler): _, state = state_groups.items().pop() results = state - event = yield self.store.get_event(event_id) - if event and event.is_state(): + if event.is_state(): # Get previous state if "replaces_state" in event.unsigned: prev_id = event.unsigned["replaces_state"] @@ -1472,9 +1423,8 @@ class FederationHandler(BaseHandler): event, context ) - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, + yield self._persist_events( + [(event, context)], backfilled=backfilled, ) except: # noqa: E722, as we reraise the exception this is fine. @@ -1487,15 +1437,7 @@ class FederationHandler(BaseHandler): six.reraise(tp, value, tb) - if not backfilled: - # this intentionally does not yield: we don't care about the result - # and don't need to wait for it. - logcontext.run_in_background( - self.pusher_pool.on_new_notifications, - event_stream_id, max_stream_id, - ) - - defer.returnValue((context, event_stream_id, max_stream_id)) + defer.returnValue(context) @defer.inlineCallbacks def _handle_new_events(self, origin, event_infos, backfilled=False): @@ -1503,6 +1445,8 @@ class FederationHandler(BaseHandler): should not depend on one another, e.g. this should be used to persist a bunch of outliers, but not a chunk of individual events that depend on each other for state calculations. + + Notifies about the events where appropriate. """ contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ @@ -1517,7 +1461,7 @@ class FederationHandler(BaseHandler): ], consumeErrors=True, )) - yield self.store.persist_events( + yield self._persist_events( [ (ev_info["event"], context) for ev_info, context in zip(event_infos, contexts) @@ -1529,7 +1473,8 @@ class FederationHandler(BaseHandler): def _persist_auth_tree(self, origin, auth_events, state, event): """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. - Persists the event seperately. + Persists the event separately. Notifies about the persisted events + where appropriate. Will attempt to fetch missing auth events. @@ -1540,8 +1485,7 @@ class FederationHandler(BaseHandler): event (Event) Returns: - 2-tuple of (event_stream_id, max_stream_id) from the persist_event - call for `event` + Deferred """ events_to_context = {} for e in itertools.chain(auth_events, state): @@ -1567,7 +1511,7 @@ class FederationHandler(BaseHandler): missing_auth_events.add(e_id) for e_id in missing_auth_events: - m_ev = yield self.replication_layer.get_pdu( + m_ev = yield self.federation_client.get_pdu( [origin], e_id, outlier=True, @@ -1605,7 +1549,7 @@ class FederationHandler(BaseHandler): raise events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR - yield self.store.persist_events( + yield self._persist_events( [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) @@ -1616,12 +1560,10 @@ class FederationHandler(BaseHandler): event, old_state=state ) - event_stream_id, max_stream_id = yield self.store.persist_event( - event, new_event_context, + yield self._persist_events( + [(event, new_event_context)], ) - defer.returnValue((event_stream_id, max_stream_id)) - @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, auth_events=None): """ @@ -1678,8 +1620,19 @@ class FederationHandler(BaseHandler): defer.returnValue(context) @defer.inlineCallbacks - def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, + def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects, missing): + in_room = yield self.auth.check_host_in_room( + room_id, + origin + ) + if not in_room: + raise AuthError(403, "Host not in room.") + + event = yield self.store.get_event( + event_id, allow_none=False, check_room_id=room_id + ) + # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. for e in remote_auth_chain: @@ -1689,7 +1642,6 @@ class FederationHandler(BaseHandler): pass # Now get the current auth_chain for the event. - event = yield self.store.get_event(event_id) local_auth_chain = yield self.store.get_auth_chain( [auth_id for auth_id, _ in event.auth_events], include_given=True @@ -1777,7 +1729,7 @@ class FederationHandler(BaseHandler): logger.info("Missing auth: %s", missing_auth) # If we don't have all the auth events, we need to get them. try: - remote_auth_chain = yield self.replication_layer.get_event_auth( + remote_auth_chain = yield self.federation_client.get_event_auth( origin, event.room_id, event.event_id ) @@ -1893,7 +1845,7 @@ class FederationHandler(BaseHandler): try: # 2. Get remote difference. - result = yield self.replication_layer.query_auth( + result = yield self.federation_client.query_auth( origin, event.room_id, event.event_id, @@ -2192,7 +2144,7 @@ class FederationHandler(BaseHandler): yield member_handler.send_membership_event(None, event, context) else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) - yield self.replication_layer.forward_third_party_invite( + yield self.federation_client.forward_third_party_invite( destinations, room_id, event_dict, @@ -2347,3 +2299,69 @@ class FederationHandler(BaseHandler): ) if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") + + @defer.inlineCallbacks + def _persist_events(self, event_and_contexts, backfilled=False): + """Persists events and tells the notifier/pushers about them, if + necessary. + + Args: + event_and_contexts(list[tuple[FrozenEvent, EventContext]]) + backfilled (bool): Whether these events are a result of + backfilling or not + + Returns: + Deferred + """ + max_stream_id = yield self.store.persist_events( + event_and_contexts, + backfilled=backfilled, + ) + + if not backfilled: # Never notify for backfilled events + for event, _ in event_and_contexts: + self._notify_persisted_event(event, max_stream_id) + + def _notify_persisted_event(self, event, max_stream_id): + """Checks to see if notifier/pushers should be notified about the + event or not. + + Args: + event (FrozenEvent) + max_stream_id (int): The max_stream_id returned by persist_events + """ + + extra_users = [] + if event.type == EventTypes.Member: + target_user_id = event.state_key + + # We notify for memberships if its an invite for one of our + # users + if event.internal_metadata.is_outlier(): + if event.membership != Membership.INVITE: + if not self.is_mine_id(target_user_id): + return + + target_user = UserID.from_string(target_user_id) + extra_users.append(target_user) + elif event.internal_metadata.is_outlier(): + return + + event_stream_id = event.internal_metadata.stream_ordering + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) + + logcontext.run_in_background( + self.pusher_pool.on_new_notifications, + event_stream_id, max_stream_id, + ) + + def _clean_room_for_join(self, room_id): + return self.store.clean_room_for_join(room_id) + + def user_joined_room(self, user, room_id): + """Called when a new user has joined the room + """ + return user_joined_room(self.distributor, user, room_id) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 8c8aedb2b8..1d36d967c3 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.errors import ( CodeMessageException, Codes, - MatrixCodeMessageException, + HttpResponseException, SynapseError, ) @@ -85,7 +85,6 @@ class IdentityHandler(BaseHandler): ) defer.returnValue(None) - data = {} try: data = yield self.http_client.get_json( "https://%s%s" % ( @@ -94,11 +93,9 @@ class IdentityHandler(BaseHandler): ), {'sid': creds['sid'], 'client_secret': client_secret} ) - except MatrixCodeMessageException as e: + except HttpResponseException as e: logger.info("getValidated3pid failed with Matrix error: %r", e) - raise SynapseError(e.code, e.msg, e.errcode) - except CodeMessageException as e: - data = json.loads(e.msg) + raise e.to_synapse_error() if 'medium' in data: defer.returnValue(data) @@ -136,7 +133,7 @@ class IdentityHandler(BaseHandler): ) logger.debug("bound threepid %r to %s", creds, mxid) except CodeMessageException as e: - data = json.loads(e.msg) + data = json.loads(e.msg) # XXX WAT? defer.returnValue(data) @defer.inlineCallbacks @@ -209,12 +206,9 @@ class IdentityHandler(BaseHandler): params ) defer.returnValue(data) - except MatrixCodeMessageException as e: - logger.info("Proxied requestToken failed with Matrix error: %r", e) - raise SynapseError(e.code, e.msg, e.errcode) - except CodeMessageException as e: + except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) - raise e + raise e.to_synapse_error() @defer.inlineCallbacks def requestMsisdnToken( @@ -244,9 +238,6 @@ class IdentityHandler(BaseHandler): params ) defer.returnValue(data) - except MatrixCodeMessageException as e: - logger.info("Proxied requestToken failed with Matrix error: %r", e) - raise SynapseError(e.code, e.msg, e.errcode) - except CodeMessageException as e: + except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) - raise e + raise e.to_synapse_error() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7caff0cbc8..289704b241 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler): hs (synapse.server.HomeServer): """ super(RegistrationHandler, self).__init__(hs) - + self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() @@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler): Args: localpart : The local part of the user ID to register. If None, one will be generated. - password (str) : The password to assign to this user so they can + password (unicode) : The password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). generate_token (bool): Whether a new access token should be @@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ + yield self._check_mau_limits() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) + yield self._check_mau_limits() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - + yield self._check_mau_limits() need_register = True try: @@ -531,3 +533,16 @@ class RegistrationHandler(BaseHandler): remote_room_hosts=remote_room_hosts, action="join", ) + + @defer.inlineCallbacks + def _check_mau_limits(self): + """ + Do not accept registrations if monthly active user limits exceeded + and limiting is enabled + """ + if self.hs.config.limit_usage_by_mau is True: + current_mau = yield self.store.count_monthly_users() + if current_mau >= self.hs.config.max_mau_value: + raise RegistrationError( + 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED + ) diff --git a/synapse/http/client.py b/synapse/http/client.py index 25b6307884..3771e0b3f6 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -39,12 +39,7 @@ from twisted.web.client import ( from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers -from synapse.api.errors import ( - CodeMessageException, - Codes, - MatrixCodeMessageException, - SynapseError, -) +from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http.endpoint import SpiderEndpoint from synapse.util.async import add_timeout_to_deferred @@ -132,6 +127,11 @@ class SimpleHttpClient(object): Returns: Deferred[object]: parsed json + + Raises: + HttpResponseException: On a non-2xx HTTP response. + + ValueError: if the response was not JSON """ # TODO: Do we ever want to log message contents? @@ -155,7 +155,10 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) - defer.returnValue(json.loads(body)) + if 200 <= response.code < 300: + defer.returnValue(json.loads(body)) + else: + raise HttpResponseException(response.code, response.phrase, body) @defer.inlineCallbacks def post_json_get_json(self, uri, post_json, headers=None): @@ -169,6 +172,11 @@ class SimpleHttpClient(object): Returns: Deferred[object]: parsed json + + Raises: + HttpResponseException: On a non-2xx HTTP response. + + ValueError: if the response was not JSON """ json_str = encode_canonical_json(post_json) @@ -193,9 +201,7 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: defer.returnValue(json.loads(body)) else: - raise self._exceptionFromFailedRequest(response, body) - - defer.returnValue(json.loads(body)) + raise HttpResponseException(response.code, response.phrase, body) @defer.inlineCallbacks def get_json(self, uri, args={}, headers=None): @@ -213,14 +219,12 @@ class SimpleHttpClient(object): Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. Raises: - On a non-2xx HTTP response. The response body will be used as the - error message. + HttpResponseException On a non-2xx HTTP response. + + ValueError: if the response was not JSON """ - try: - body = yield self.get_raw(uri, args, headers=headers) - defer.returnValue(json.loads(body)) - except CodeMessageException as e: - raise self._exceptionFromFailedRequest(e.code, e.msg) + body = yield self.get_raw(uri, args, headers=headers) + defer.returnValue(json.loads(body)) @defer.inlineCallbacks def put_json(self, uri, json_body, args={}, headers=None): @@ -239,7 +243,9 @@ class SimpleHttpClient(object): Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. Raises: - On a non-2xx HTTP response. + HttpResponseException On a non-2xx HTTP response. + + ValueError: if the response was not JSON """ if len(args): query_bytes = urllib.urlencode(args, True) @@ -266,10 +272,7 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: defer.returnValue(json.loads(body)) else: - # NB: This is explicitly not json.loads(body)'d because the contract - # of CodeMessageException is a *string* message. Callers can always - # load it into JSON if they want. - raise CodeMessageException(response.code, body) + raise HttpResponseException(response.code, response.phrase, body) @defer.inlineCallbacks def get_raw(self, uri, args={}, headers=None): @@ -287,8 +290,7 @@ class SimpleHttpClient(object): Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body at text. Raises: - On a non-2xx HTTP response. The response body will be used as the - error message. + HttpResponseException on a non-2xx HTTP response. """ if len(args): query_bytes = urllib.urlencode(args, True) @@ -311,16 +313,7 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: defer.returnValue(body) else: - raise CodeMessageException(response.code, body) - - def _exceptionFromFailedRequest(self, response, body): - try: - jsonBody = json.loads(body) - errcode = jsonBody['errcode'] - error = jsonBody['error'] - return MatrixCodeMessageException(response.code, error, errcode) - except (ValueError, KeyError): - return CodeMessageException(response.code, body) + raise HttpResponseException(response.code, response.phrase, body) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. diff --git a/synapse/http/server.py b/synapse/http/server.py index c70fdbdfd2..6dacb31037 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -13,12 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import cgi import collections import logging -import urllib -from six.moves import http_client +from six import PY3 +from six.moves import http_client, urllib from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json @@ -35,7 +36,6 @@ from synapse.api.errors import ( Codes, SynapseError, UnrecognizedRequestError, - cs_exception, ) from synapse.http.request_metrics import requests_counter from synapse.util.caches import intern_dict @@ -76,16 +76,13 @@ def wrap_json_request_handler(h): def wrapped_request_handler(self, request): try: yield h(self, request) - except CodeMessageException as e: + except SynapseError as e: code = e.code - if isinstance(e, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) - else: - logger.exception(e) + logger.info( + "%s SynapseError: %s - %s", request, code, e.msg + ) respond_with_json( - request, code, cs_exception(e), send_cors=True, + request, code, e.error_dict(), send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -264,6 +261,7 @@ class JsonResource(HttpServer, resource.Resource): self.hs = hs def register_paths(self, method, path_patterns, callback): + method = method.encode("utf-8") # method is bytes on py3 for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( @@ -296,8 +294,19 @@ class JsonResource(HttpServer, resource.Resource): # here. If it throws an exception, that is handled by the wrapper # installed by @request_handler. + def _unquote(s): + if PY3: + # On Python 3, unquote is unicode -> unicode + return urllib.parse.unquote(s) + else: + # On Python 2, unquote is bytes -> bytes We need to encode the + # URL again (as it was decoded by _get_handler_for request), as + # ASCII because it's a URL, and then decode it to get the UTF-8 + # characters that were quoted. + return urllib.parse.unquote(s.encode('ascii')).decode('utf8') + kwargs = intern_dict({ - name: urllib.unquote(value).decode("UTF-8") if value else value + name: _unquote(value) if value else value for name, value in group_dict.items() }) @@ -313,9 +322,9 @@ class JsonResource(HttpServer, resource.Resource): request (twisted.web.http.Request): Returns: - Tuple[Callable, dict[str, str]]: callback method, and the dict - mapping keys to path components as specified in the handler's - path match regexp. + Tuple[Callable, dict[unicode, unicode]]: callback method, and the + dict mapping keys to path components as specified in the + handler's path match regexp. The callback will normally be a method registered via register_paths, so will return (possibly via Deferred) either @@ -327,7 +336,7 @@ class JsonResource(HttpServer, resource.Resource): # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path) + m = path_entry.pattern.match(request.path.decode('ascii')) if m: # We found a match! return path_entry.callback, m.groupdict() @@ -383,7 +392,7 @@ class RootRedirect(resource.Resource): self.url = path def render_GET(self, request): - return redirectTo(self.url, request) + return redirectTo(self.url.encode('ascii'), request) def getChild(self, name, request): if len(name) == 0: @@ -404,12 +413,14 @@ def respond_with_json(request, code, json_object, send_cors=False, return if pretty_print: - json_bytes = encode_pretty_printed_json(json_object) + "\n" + json_bytes = (encode_pretty_printed_json(json_object) + "\n" + ).encode("utf-8") else: if canonical_json or synapse.events.USE_FROZEN_DICTS: + # canonicaljson already encodes to bytes json_bytes = encode_canonical_json(json_object) else: - json_bytes = json.dumps(json_object) + json_bytes = json.dumps(json_object).encode("utf-8") return respond_with_json_bytes( request, code, json_bytes, diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 882816dc8f..69f7085291 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False): if not content_bytes and allow_empty_body: return None + # Decode to Unicode so that simplejson will return Unicode strings on + # Python 2 try: - content = json.loads(content_bytes) + content_unicode = content_bytes.decode('utf8') + except UnicodeDecodeError: + logger.warn("Unable to decode UTF-8") + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + try: + content = json.loads(content_unicode) except Exception as e: logger.warn("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 24f00d95fc..a7d1b2dabe 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -23,8 +23,7 @@ from twisted.internet import defer from synapse.api.errors import ( CodeMessageException, - MatrixCodeMessageException, - SynapseError, + HttpResponseException, ) from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string @@ -160,11 +159,11 @@ class ReplicationEndpoint(object): # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. yield clock.sleep(1) - except MatrixCodeMessageException as e: + except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And # importantly, not stack traces everywhere) - raise SynapseError(e.code, e.msg, e.errcode) + raise e.to_synapse_error() defer.returnValue(result) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 99f6c6e3c3..80d625eecc 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -18,6 +18,7 @@ import hashlib import hmac import logging +from six import text_type from six.moves import http_client from twisted.internet import defer @@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet): 400, "username must be specified", errcode=Codes.BAD_JSON, ) else: - if (not isinstance(body['username'], str) or len(body['username']) > 512): + if ( + not isinstance(body['username'], text_type) + or len(body['username']) > 512 + ): raise SynapseError(400, "Invalid username") username = body["username"].encode("utf-8") @@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet): 400, "password must be specified", errcode=Codes.BAD_JSON, ) else: - if (not isinstance(body['password'], str) or len(body['password']) > 512): + if ( + not isinstance(body['password'], text_type) + or len(body['password']) > 512 + ): raise SynapseError(400, "Invalid password") password = body["password"].encode("utf-8") @@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet): want_mac.update(b"admin" if admin else b"notadmin") want_mac = want_mac.hexdigest() - if not hmac.compare_digest(want_mac, got_mac): - raise SynapseError( - 403, "HMAC incorrect", - ) + if not hmac.compare_digest(want_mac, got_mac.encode('ascii')): + raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication from synapse.rest.client.v2_alpha.register import RegisterRestServlet + register = RegisterRestServlet(self.hs) (user_id, _) = yield register.registration_handler.register( - localpart=username.lower(), password=password, admin=bool(admin), + localpart=body['username'].lower(), + password=body["password"], + admin=bool(admin), generate_token=False, ) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 69dcd618cb..97733f3026 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,7 +18,7 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import parse_json_object_from_request from synapse.types import RoomAlias @@ -159,7 +159,7 @@ class ClientDirectoryListServer(ClientV1RestServlet): def on_GET(self, request, room_id): room = yield self.store.get_room(room_id) if room is None: - raise SynapseError(400, "Unknown room") + raise NotFoundError("Unknown room") defer.returnValue((200, { "visibility": "public" if room["is_public"] else "private" diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index b70c9c2806..0f3a2e8b51 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -88,7 +88,7 @@ class EventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, event_id): requester = yield self.auth.get_user_by_req(request) - event = yield self.event_handler.get_event(requester.user, event_id) + event = yield self.event_handler.get_event(requester.user, None, event_id) time_now = self.clock.time_msec() if event: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 13c331550b..fa5989e74e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -506,7 +506,7 @@ class RoomEventServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): requester = yield self.auth.get_user_by_req(request) - event = yield self.event_handler.get_event(requester.user, event_id) + event = yield self.event_handler.get_event(requester.user, room_id, event_id) time_now = self.clock.time_msec() if event: diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index d6cf915d86..2f64155d13 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - kind = "user" - if "kind" in request.args: - kind = request.args["kind"][0] + kind = b"user" + if b"kind" in request.args: + kind = request.args[b"kind"][0] - if kind == "guest": + if kind == b"guest": ret = yield self._do_guest_registration(body) defer.returnValue(ret) return - elif kind != "user": + elif kind != b"user": raise UnrecognizedRequestError( "Do not understand membership kind: %s" % (kind,) ) @@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet): assert_params_in_dict(params, ["password"]) desired_username = params.get("username", None) - new_password = params.get("password", None) guest_access_token = params.get("guest_access_token", None) + new_password = params.get("password", None) if desired_username is not None: desired_username = desired_username.lower() diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 174ad20123..8fb413d825 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -379,7 +379,7 @@ class MediaRepository(object): logger.warn("HTTP error fetching remote media %s/%s: %s", server_name, media_id, e.response) if e.code == twisted.web.http.NOT_FOUND: - raise SynapseError.from_http_response_exception(e) + raise e.to_synapse_error() raise SynapseError(502, "Failed to fetch remote media") except SynapseError: diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index b25993fcb5..a6189224ee 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -177,7 +177,7 @@ class MediaStorage(object): if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "w"), self.hs.get_reactor()) + open(local_path, "wb"), self.hs.get_reactor()) yield res.write_to_consumer(consumer) yield consumer.wait() defer.returnValue(local_path) diff --git a/synapse/state.py b/synapse/state.py index 033f55d967..e1092b97a9 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -577,7 +577,7 @@ def _make_state_cache_entry( def _ordered_events(events): def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest() + return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest() return sorted(events, key=key_func) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index ba88a54979..134e4a80f1 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore, PresenceStore, TransactionStore, DirectoryStore, KeyStore, StateStore, SignatureStore, ApplicationServiceStore, + EventsStore, EventFederationStore, MediaRepositoryStore, RejectionsStore, @@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore, PusherStore, PushRuleStore, ApplicationServiceTransactionStore, - EventsStore, ReceiptsStore, EndToEndKeyStore, SearchStore, @@ -94,6 +94,7 @@ class DataStore(RoomMemberStore, RoomStore, self._clock = hs.get_clock() self.database_engine = hs.database_engine + self.db_conn = db_conn self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", extra_tables=[("local_invites", "stream_id")] @@ -266,6 +267,31 @@ class DataStore(RoomMemberStore, RoomStore, return self.runInteraction("count_users", _count_users) + def count_monthly_users(self): + """Counts the number of users who used this homeserver in the last 30 days + + This method should be refactored with count_daily_users - the only + reason not to is waiting on definition of mau + + Returns: + Defered[int] + """ + def _count_monthly_users(txn): + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + + txn.execute(sql, (thirty_days_ago,)) + count, = txn.fetchone() + return count + + return self.runInteraction("count_monthly_users", _count_monthly_users) + def count_r30_users(self): """ Counts the number of 30 day retained users, defined as:- diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 9f12b360bc..31248d5e06 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -22,7 +22,7 @@ from twisted.internet import defer from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage.events import EventsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore from ._base import SQLBaseStore diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 5d3ee90017..24345b20a6 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.events import EventsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.signatures import SignatureWorkerStore from synapse.util.caches.descriptors import cached @@ -343,6 +343,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, table="events", keyvalues={ "event_id": event_id, + "room_id": room_id, }, retcol="depth", allow_none=True, diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2f482af3a1..e8e5a0fe44 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util.async import ObservableDeferred @@ -65,7 +67,13 @@ state_delta_reuse_delta_counter = Counter( def encode_json(json_object): - return frozendict_json_encoder.encode(json_object) + """ + Encode a Python object as JSON and return it in a Unicode string. + """ + out = frozendict_json_encoder.encode(json_object) + if isinstance(out, bytes): + out = out.decode('utf8') + return out class _EventPeristenceQueue(object): @@ -193,7 +201,9 @@ def _retry_on_integrity_error(func): return f -class EventsStore(EventsWorkerStore): +# inherits from EventFederationStore so that we can call _update_backward_extremities +# and _handle_mult_prev_events (though arguably those could both be moved in here) +class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" @@ -231,12 +241,18 @@ class EventsStore(EventsWorkerStore): self._state_resolution_handler = hs.get_state_resolution_handler() + @defer.inlineCallbacks def persist_events(self, events_and_contexts, backfilled=False): """ Write events to the database Args: events_and_contexts: list of tuples of (event, context) - backfilled: ? + backfilled (bool): Whether the results are retrieved from federation + via backfill or not. Used to determine if they're "new" events + which might update the current state etc. + + Returns: + Deferred[int]: the stream ordering of the latest persisted event """ partitioned = {} for event, ctx in events_and_contexts: @@ -253,10 +269,14 @@ class EventsStore(EventsWorkerStore): for room_id in partitioned: self._maybe_start_persisting(room_id) - return make_deferred_yieldable( + yield make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) + max_persisted_id = yield self._stream_id_gen.get_current_token() + + defer.returnValue(max_persisted_id) + @defer.inlineCallbacks @log_function def persist_event(self, event, context, backfilled=False): @@ -1054,7 +1074,7 @@ class EventsStore(EventsWorkerStore): metadata_json = encode_json( event.internal_metadata.get_dict() - ).decode("UTF-8") + ) sql = ( "UPDATE event_json SET internal_metadata = ?" @@ -1168,8 +1188,8 @@ class EventsStore(EventsWorkerStore): "room_id": event.room_id, "internal_metadata": encode_json( event.internal_metadata.get_dict() - ).decode("UTF-8"), - "json": encode_json(event_dict(event)).decode("UTF-8"), + ), + "json": encode_json(event_dict(event)), } for event, _ in events_and_contexts ], diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index f28239a808..9b4cfeb899 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -19,7 +19,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import NotFoundError # these are only included to make the type annotations work from synapse.events import EventBase # noqa: F401 from synapse.events import FrozenEvent @@ -77,7 +77,7 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, - allow_none=False): + allow_none=False, check_room_id=None): """Get an event from the database by event_id. Args: @@ -88,7 +88,9 @@ class EventsWorkerStore(SQLBaseStore): include the previous states content in the unsigned field. allow_rejected (bool): If True return rejected events. allow_none (bool): If True, return None if no event found, if - False throw an exception. + False throw a NotFoundError + check_room_id (str|None): if not None, check the room of the found event. + If there is a mismatch, behave as per allow_none. Returns: Deferred : A FrozenEvent. @@ -100,10 +102,16 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected=allow_rejected, ) - if not events and not allow_none: - raise SynapseError(404, "Could not find event %s" % (event_id,)) + event = events[0] if events else None - defer.returnValue(events[0] if events else None) + if event is not None and check_room_id is not None: + if event.room_id != check_room_id: + event = None + + if event is None and not allow_none: + raise NotFoundError("Could not find event %s" % (event_id,)) + + defer.returnValue(event) @defer.inlineCallbacks def get_events(self, event_ids, check_redacted=True, diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 027bf8c85e..10dce21cea 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -24,7 +24,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.storage.events import EventsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id from synapse.util.async import Linearizer from synapse.util.caches import intern_string diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py index 7d27342e39..6dd467b6c5 100644 --- a/synapse/storage/schema/delta/50/make_event_content_nullable.py +++ b/synapse/storage/schema/delta/50/make_event_content_nullable.py @@ -88,5 +88,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs): "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", (sql, ), ) - cur.execute("PRAGMA schema_version=%i" % (oldver+1,)) + cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 470212aa2a..5623391f6e 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore): txn (cursor): event_id (str): Id for the Event. Returns: - A dict of algorithm -> hash. + A dict[unicode, bytes] of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 25d0097b58..b9f2b74ac6 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -43,7 +43,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine -from synapse.storage.events import EventsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background diff --git a/synapse/types.py b/synapse/types.py index 08f058f714..41afb27a74 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -137,7 +137,7 @@ class DomainSpecificString( @classmethod def from_string(cls, s): """Parse the string given by 's' into a structure object.""" - if len(s) < 1 or s[0] != cls.SIGIL: + if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError(400, "Expected %s string to start with '%s'" % ( cls.__name__, cls.SIGIL, )) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index f8a07df6b8..861c24809c 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase): @functools.wraps(self.orig) def wrapped(*args, **kwargs): - # If we're passed a cache_context then we'll want to call its invalidate() - # whenever we are invalidated + # If we're passed a cache_context then we'll want to call its + # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] - # cached is a dict arg -> deferred, where deferred results in a - # 2-tuple (`arg`, `result`) results = {} - cached_defers = {} - missing = [] + + def update_results_dict(res, arg): + results[arg] = res + + # list of deferreds to wait for + cached_defers = [] + + missing = set() # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: - def cache_get(arg): - return cache.get(arg, callback=invalidate_callback) + def arg_to_cache_key(arg): + return arg else: - key = list(keyargs) + keylist = list(keyargs) - def cache_get(arg): - key[self.list_pos] = arg - return cache.get(tuple(key), callback=invalidate_callback) + def arg_to_cache_key(arg): + keylist[self.list_pos] = arg + return tuple(keylist) for arg in list_args: try: - res = cache_get(arg) - + res = cache.get(arg_to_cache_key(arg), + callback=invalidate_callback) if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): res = res.observe() - res.addCallback(lambda r, arg: (arg, r), arg) - cached_defers[arg] = res + res.addCallback(update_results_dict, arg) + cached_defers.append(res) else: results[arg] = res.get_result() except KeyError: - missing.append(arg) + missing.add(arg) if missing: + # we need an observable deferred for each entry in the list, + # which we put in the cache. Each deferred resolves with the + # relevant result for that key. + deferreds_map = {} + for arg in missing: + deferred = defer.Deferred() + deferreds_map[arg] = deferred + key = arg_to_cache_key(arg) + observable = ObservableDeferred(deferred) + cache.set(key, observable, callback=invalidate_callback) + + def complete_all(res): + # the wrapped function has completed. It returns a + # a dict. We can now resolve the observable deferreds in + # the cache and update our own result map. + for e in missing: + val = res.get(e, None) + deferreds_map[e].callback(val) + results[e] = val + + def errback(f): + # the wrapped function has failed. Invalidate any cache + # entries we're supposed to be populating, and fail + # their deferreds. + for e in missing: + key = arg_to_cache_key(e) + cache.invalidate(key) + deferreds_map[e].errback(f) + + # return the failure, to propagate to our caller. + return f + args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing + args_to_call[self.list_name] = list(missing) - ret_d = defer.maybeDeferred( + cached_defers.append(defer.maybeDeferred( logcontext.preserve_fn(self.function_to_call), **args_to_call - ) - - ret_d = ObservableDeferred(ret_d) - - # We need to create deferreds for each arg in the list so that - # we can insert the new deferred into the cache. - for arg in missing: - observer = ret_d.observe() - observer.addCallback(lambda r, arg: r.get(arg, None), arg) - - observer = ObservableDeferred(observer) - - if num_args == 1: - cache.set( - arg, observer, - callback=invalidate_callback - ) - - def invalidate(f, key): - cache.invalidate(key) - return f - observer.addErrback(invalidate, arg) - else: - key = list(keyargs) - key[self.list_pos] = arg - cache.set( - tuple(key), observer, - callback=invalidate_callback - ) - - def invalidate(f, key): - cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) - - res = observer.observe() - res.addCallback(lambda r, arg: (arg, r), arg) - - cached_defers[arg] = res + ).addCallbacks(complete_all, errback)) if cached_defers: - def update_results_dict(res): - results.update(res) - return results - - return logcontext.make_deferred_yieldable(defer.gatherResults( - list(cached_defers.values()), + d = defer.gatherResults( + cached_defers, consumeErrors=True, - ).addCallback(update_results_dict).addErrback( + ).addCallbacks( + lambda _: results, unwrapFirstError - )) + ) + return logcontext.make_deferred_yieldable(d) else: return results @@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal cache. Args: - cache (Cache): The underlying cache to use. + cached_method_name (str): The name of the single-item lookup method. + This is only used to find the cache to use. list_name (str): The name of the argument that is the list to use to do batch lookups in the cache. num_args (int): Number of arguments to use as the key in the cache diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 581c6052ac..014edea971 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import string_types +from six import binary_type, text_type from canonicaljson import json from frozendict import frozendict @@ -26,7 +26,7 @@ def freeze(o): if isinstance(o, frozendict): return o - if isinstance(o, string_types): + if isinstance(o, (binary_type, text_type)): return o try: @@ -41,7 +41,7 @@ def unfreeze(o): if isinstance(o, (dict, frozendict)): return dict({k: unfreeze(v) for k, v in o.items()}) - if isinstance(o, string_types): + if isinstance(o, (binary_type, text_type)): return o try: |