summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2018-08-06 13:33:54 +0100
committerErik Johnston <erik@matrix.org>2018-08-06 13:33:54 +0100
commit49a316395831a0676160833c39daa0bb63ceea09 (patch)
treef1c4db4121d7890a7ba390c01b0e81e329f09b22 /synapse
parentMerge branch 'release-v0.33.1' of github.com:matrix-org/synapse into matrix-o... (diff)
parentReturn M_NOT_FOUND when a profile could not be found. (#3596) (diff)
downloadsynapse-49a316395831a0676160833c39daa0bb63ceea09.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py4
-rw-r--r--synapse/api/errors.py107
-rw-r--r--synapse/api/filtering.py20
-rwxr-xr-xsynapse/app/homeserver.py32
-rw-r--r--synapse/config/server.py10
-rw-r--r--synapse/federation/federation_client.py300
-rw-r--r--synapse/federation/federation_server.py4
-rw-r--r--synapse/federation/send_queue.py63
-rw-r--r--synapse/federation/transaction_queue.py32
-rw-r--r--synapse/federation/transport/server.py9
-rw-r--r--synapse/federation/units.py1
-rw-r--r--synapse/groups/attestations.py6
-rw-r--r--synapse/handlers/auth.py48
-rw-r--r--synapse/handlers/federation.py205
-rw-r--r--synapse/handlers/identity.py25
-rw-r--r--synapse/handlers/profile.py98
-rw-r--r--synapse/handlers/register.py21
-rw-r--r--synapse/handlers/room.py24
-rw-r--r--synapse/handlers/room_member.py4
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/sync.py174
-rw-r--r--synapse/http/client.py61
-rw-r--r--synapse/http/server.py49
-rw-r--r--synapse/http/servlet.py10
-rw-r--r--synapse/metrics/background_process_metrics.py10
-rw-r--r--synapse/replication/http/membership.py18
-rw-r--r--synapse/replication/http/send_event.py10
-rw-r--r--synapse/replication/tcp/client.py2
-rw-r--r--synapse/replication/tcp/resource.py14
-rw-r--r--synapse/rest/client/v1/admin.py22
-rw-r--r--synapse/rest/client/v1/directory.py4
-rw-r--r--synapse/rest/client/v1/room.py9
-rw-r--r--synapse/rest/client/v2_alpha/register.py12
-rw-r--r--synapse/rest/media/v1/media_repository.py10
-rw-r--r--synapse/rest/media/v1/media_storage.py2
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py8
-rw-r--r--synapse/secrets.py7
-rw-r--r--synapse/state.py2
-rw-r--r--synapse/storage/__init__.py28
-rw-r--r--synapse/storage/_base.py6
-rw-r--r--synapse/storage/appservice.py2
-rw-r--r--synapse/storage/client_ips.py2
-rw-r--r--synapse/storage/devices.py8
-rw-r--r--synapse/storage/event_federation.py24
-rw-r--r--synapse/storage/event_push_actions.py13
-rw-r--r--synapse/storage/events.py61
-rw-r--r--synapse/storage/roommember.py2
-rw-r--r--synapse/storage/schema/delta/50/make_event_content_nullable.py92
-rw-r--r--synapse/storage/schema/full_schemas/16/event_edges.sql3
-rw-r--r--synapse/storage/schema/full_schemas/16/im.sql7
-rw-r--r--synapse/storage/signatures.py2
-rw-r--r--synapse/storage/state.py147
-rw-r--r--synapse/storage/stream.py16
-rw-r--r--synapse/storage/transactions.py8
-rw-r--r--synapse/types.py2
-rw-r--r--synapse/util/caches/descriptors.py131
-rw-r--r--synapse/util/caches/expiringcache.py2
-rw-r--r--synapse/util/frozenutils.py6
59 files changed, 1241 insertions, 762 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py

index 5c0f2f83aa..1810cb6fcd 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -17,4 +17,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.33.0" +__version__ = "0.33.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 073229b4c4..5bbbe8e2e7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py
@@ -252,10 +252,10 @@ class Auth(object): if ip_address not in app_service.ip_range_whitelist: defer.returnValue((None, None)) - if "user_id" not in request.args: + if b"user_id" not in request.args: defer.returnValue((app_service.sender, app_service)) - user_id = request.args["user_id"][0] + user_id = request.args[b"user_id"][0].decode('utf8') if app_service.sender == user_id: defer.returnValue((app_service.sender, app_service)) diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 6074df292f..b41d595059 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py
@@ -55,6 +55,7 @@ class Codes(object): SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" + MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED" class CodeMessageException(RuntimeError): @@ -69,20 +70,6 @@ class CodeMessageException(RuntimeError): self.code = code self.msg = msg - def error_dict(self): - return cs_error(self.msg) - - -class MatrixCodeMessageException(CodeMessageException): - """An error from a general matrix endpoint, eg. from a proxied Matrix API call. - - Attributes: - errcode (str): Matrix error code e.g 'M_FORBIDDEN' - """ - def __init__(self, code, msg, errcode=Codes.UNKNOWN): - super(MatrixCodeMessageException, self).__init__(code, msg) - self.errcode = errcode - class SynapseError(CodeMessageException): """A base exception type for matrix errors which have an errcode and error @@ -108,38 +95,28 @@ class SynapseError(CodeMessageException): self.errcode, ) - @classmethod - def from_http_response_exception(cls, err): - """Make a SynapseError based on an HTTPResponseException - - This is useful when a proxied request has failed, and we need to - decide how to map the failure onto a matrix error to send back to the - client. - - An attempt is made to parse the body of the http response as a matrix - error. If that succeeds, the errcode and error message from the body - are used as the errcode and error message in the new synapse error. - - Otherwise, the errcode is set to M_UNKNOWN, and the error message is - set to the reason code from the HTTP response. - Args: - err (HttpResponseException): +class ProxiedRequestError(SynapseError): + """An error from a general matrix endpoint, eg. from a proxied Matrix API call. - Returns: - SynapseError: - """ - # try to parse the body as json, to get better errcode/msg, but - # default to M_UNKNOWN with the HTTP status as the error text - try: - j = json.loads(err.response) - except ValueError: - j = {} - errcode = j.get('errcode', Codes.UNKNOWN) - errmsg = j.get('error', err.msg) + Attributes: + errcode (str): Matrix error code e.g 'M_FORBIDDEN' + """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): + super(ProxiedRequestError, self).__init__( + code, msg, errcode + ) + if additional_fields is None: + self._additional_fields = {} + else: + self._additional_fields = dict(additional_fields) - res = SynapseError(err.code, errmsg, errcode) - return res + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + **self._additional_fields + ) class ConsentNotGivenError(SynapseError): @@ -308,14 +285,6 @@ class LimitExceededError(SynapseError): ) -def cs_exception(exception): - if isinstance(exception, CodeMessageException): - return exception.error_dict() - else: - logger.error("Unknown exception type: %s", type(exception)) - return {} - - def cs_error(msg, code=Codes.UNKNOWN, **kwargs): """ Utility method for constructing an error response for client-server interactions. @@ -372,7 +341,7 @@ class HttpResponseException(CodeMessageException): Represents an HTTP-level failure of an outbound request Attributes: - response (str): body of response + response (bytes): body of response """ def __init__(self, code, msg, response): """ @@ -380,7 +349,39 @@ class HttpResponseException(CodeMessageException): Args: code (int): HTTP status code msg (str): reason phrase from HTTP response status line - response (str): body of response + response (bytes): body of response """ super(HttpResponseException, self).__init__(code, msg) self.response = response + + def to_synapse_error(self): + """Make a SynapseError based on an HTTPResponseException + + This is useful when a proxied request has failed, and we need to + decide how to map the failure onto a matrix error to send back to the + client. + + An attempt is made to parse the body of the http response as a matrix + error. If that succeeds, the errcode and error message from the body + are used as the errcode and error message in the new synapse error. + + Otherwise, the errcode is set to M_UNKNOWN, and the error message is + set to the reason code from the HTTP response. + + Returns: + SynapseError: + """ + # try to parse the body as json, to get better errcode/msg, but + # default to M_UNKNOWN with the HTTP status as the error text + try: + j = json.loads(self.response) + except ValueError: + j = {} + + if not isinstance(j, dict): + j = {} + + errcode = j.pop('errcode', Codes.UNKNOWN) + errmsg = j.pop('error', self.msg) + + return ProxiedRequestError(self.code, errmsg, errcode, j) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 25346baa87..186831e118 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py
@@ -113,7 +113,13 @@ ROOM_EVENT_FILTER_SCHEMA = { }, "contains_url": { "type": "boolean" - } + }, + "lazy_load_members": { + "type": "boolean" + }, + "include_redundant_members": { + "type": "boolean" + }, } } @@ -261,6 +267,12 @@ class FilterCollection(object): def ephemeral_limit(self): return self._room_ephemeral_filter.limit() + def lazy_load_members(self): + return self._room_state_filter.lazy_load_members() + + def include_redundant_members(self): + return self._room_state_filter.include_redundant_members() + def filter_presence(self, events): return self._presence_filter.filter(events) @@ -417,6 +429,12 @@ class Filter(object): def limit(self): return self.filter_json.get("limit", 10) + def lazy_load_members(self): + return self.filter_json.get("lazy_load_members", False) + + def include_redundant_members(self): + return self.filter_json.get("include_redundant_members", False) + def _matches_wildcard(actual_value, filter_value): if filter_value.endswith("*"): diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2ad1beb8d8..fba51c26e8 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py
@@ -20,6 +20,8 @@ import sys from six import iteritems +from prometheus_client import Gauge + from twisted.application import service from twisted.internet import defer, reactor from twisted.web.resource import EncodingResourceWrapper, NoResource @@ -49,6 +51,7 @@ from synapse.http.additional_resource import AdditionalResource from synapse.http.server import RootRedirect from synapse.http.site import SynapseSite from synapse.metrics import RegistryProxy +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.module_api import ModuleApi from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, check_requirements @@ -299,6 +302,11 @@ class SynapseHomeServer(HomeServer): quit_with_error(e.message) +# Gauges to expose monthly active user control metrics +current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU") +max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit") + + def setup(config_options): """ Args: @@ -427,6 +435,9 @@ def run(hs): # currently either 0 or 1 stats_process = [] + def start_phone_stats_home(): + return run_as_background_process("phone_stats_home", phone_stats_home) + @defer.inlineCallbacks def phone_stats_home(): logger.info("Gathering stats for reporting") @@ -498,16 +509,31 @@ def run(hs): ) def generate_user_daily_visit_stats(): - hs.get_datastore().generate_user_daily_visits() + return run_as_background_process( + "generate_user_daily_visits", + hs.get_datastore().generate_user_daily_visits, + ) # Rather than update on per session basis, batch up the requests. # If you increase the loop period, the accuracy of user_daily_visits # table will decrease clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) + @defer.inlineCallbacks + def generate_monthly_active_users(): + count = 0 + if hs.config.limit_usage_by_mau: + count = yield hs.get_datastore().count_monthly_users() + current_mau_gauge.set(float(count)) + max_mau_value_gauge.set(float(hs.config.max_mau_value)) + + generate_monthly_active_users() + if hs.config.limit_usage_by_mau: + clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) + if hs.config.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") - clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000) + clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) # We need to defer this init for the cases that we daemonize # otherwise the process ID we get is that of the non-daemon process @@ -515,7 +541,7 @@ def run(hs): # We wait 5 minutes to send the first set of stats as the server can # be quite busy the first few minutes - clock.call_later(5 * 60, phone_stats_home) + clock.call_later(5 * 60, start_phone_stats_home) if hs.config.daemonize and hs.config.print_pidfile: print (hs.config.pid_file) diff --git a/synapse/config/server.py b/synapse/config/server.py
index 18102656b0..6a471a0a5e 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -67,6 +67,14 @@ class ServerConfig(Config): "block_non_admin_invites", False, ) + # Options to control access by tracking MAU + self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) + if self.limit_usage_by_mau: + self.max_mau_value = config.get( + "max_mau_value", 0, + ) + else: + self.max_mau_value = 0 # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( @@ -209,6 +217,8 @@ class ServerConfig(Config): # different cores. See # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/. # + # This setting requires the affinity package to be installed! + # # cpu_affinity: 0xFFFFFFFF # Whether to serve a web client from the HTTP/HTTPS root resource. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 62d7ed13cf..7550e11b6e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t PDU_RETRY_TIME_MS = 1 * 60 * 1000 +class InvalidResponseError(RuntimeError): + """Helper for _try_destination_list: indicates that the server returned a response + we couldn't parse + """ + pass + + class FederationClient(FederationBase): def __init__(self, hs): super(FederationClient, self).__init__(hs) @@ -458,6 +465,61 @@ class FederationClient(FederationBase): defer.returnValue(signed_auth) @defer.inlineCallbacks + def _try_destination_list(self, description, destinations, callback): + """Try an operation on a series of servers, until it succeeds + + Args: + description (unicode): description of the operation we're doing, for logging + + destinations (Iterable[unicode]): list of server_names to try + + callback (callable): Function to run for each server. Passed a single + argument: the server_name to try. May return a deferred. + + If the callback raises a CodeMessageException with a 300/400 code, + attempts to perform the operation stop immediately and the exception is + reraised. + + Otherwise, if the callback raises an Exception the error is logged and the + next server tried. Normally the stacktrace is logged but this is + suppressed if the exception is an InvalidResponseError. + + Returns: + The [Deferred] result of callback, if it succeeds + + Raises: + SynapseError if the chosen remote server returns a 300/400 code. + + RuntimeError if no servers were reachable. + """ + for destination in destinations: + if destination == self.server_name: + continue + + try: + res = yield callback(destination) + defer.returnValue(res) + except InvalidResponseError as e: + logger.warn( + "Failed to %s via %s: %s", + description, destination, e, + ) + except HttpResponseException as e: + if not 500 <= e.code < 600: + raise e.to_synapse_error() + else: + logger.warn( + "Failed to %s via %s: %i %s", + description, destination, e.code, e.message, + ) + except Exception: + logger.warn( + "Failed to %s via %s", + description, destination, exc_info=1, + ) + + raise RuntimeError("Failed to %s via any server", description) + def make_membership_event(self, destinations, room_id, user_id, membership, content={},): """ @@ -481,7 +543,7 @@ class FederationClient(FederationBase): Deferred: resolves to a tuple of (origin (str), event (object)) where origin is the remote homeserver which generated the event. - Fails with a ``CodeMessageException`` if the chosen remote server + Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. @@ -492,50 +554,35 @@ class FederationClient(FederationBase): "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - for destination in destinations: - if destination == self.server_name: - continue - try: - ret = yield self.transport_layer.make_membership_event( - destination, room_id, user_id, membership - ) + @defer.inlineCallbacks + def send_request(destination): + ret = yield self.transport_layer.make_membership_event( + destination, room_id, user_id, membership + ) - pdu_dict = ret["event"] + pdu_dict = ret["event"] - logger.debug("Got response to make_%s: %s", membership, pdu_dict) + logger.debug("Got response to make_%s: %s", membership, pdu_dict) - pdu_dict["content"].update(content) + pdu_dict["content"].update(content) - # The protoevent received over the JSON wire may not have all - # the required fields. Lets just gloss over that because - # there's some we never care about - if "prev_state" not in pdu_dict: - pdu_dict["prev_state"] = [] + # The protoevent received over the JSON wire may not have all + # the required fields. Lets just gloss over that because + # there's some we never care about + if "prev_state" not in pdu_dict: + pdu_dict["prev_state"] = [] - ev = builder.EventBuilder(pdu_dict) + ev = builder.EventBuilder(pdu_dict) - defer.returnValue( - (destination, ev) - ) - break - except CodeMessageException as e: - if not 500 <= e.code < 600: - raise - else: - logger.warn( - "Failed to make_%s via %s: %s", - membership, destination, e.message - ) - except Exception as e: - logger.warn( - "Failed to make_%s via %s: %s", - membership, destination, e.message - ) + defer.returnValue( + (destination, ev) + ) - raise RuntimeError("Failed to send to any server.") + return self._try_destination_list( + "make_" + membership, destinations, send_request, + ) - @defer.inlineCallbacks def send_join(self, destinations, pdu): """Sends a join event to one of a list of homeservers. @@ -552,103 +599,91 @@ class FederationClient(FederationBase): giving the serer the event was sent to, ``state`` (?) and ``auth_chain``. - Fails with a ``CodeMessageException`` if the chosen remote server + Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. """ - for destination in destinations: - if destination == self.server_name: - continue - - try: - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + @defer.inlineCallbacks + def send_request(destination): + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_join( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) - logger.debug("Got content: %s", content) + logger.debug("Got content: %s", content) - state = [ - event_from_pdu_json(p, outlier=True) - for p in content.get("state", []) - ] + state = [ + event_from_pdu_json(p, outlier=True) + for p in content.get("state", []) + ] - auth_chain = [ - event_from_pdu_json(p, outlier=True) - for p in content.get("auth_chain", []) - ] + auth_chain = [ + event_from_pdu_json(p, outlier=True) + for p in content.get("auth_chain", []) + ] - pdus = { - p.event_id: p - for p in itertools.chain(state, auth_chain) - } + pdus = { + p.event_id: p + for p in itertools.chain(state, auth_chain) + } - valid_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, list(pdus.values()), - outlier=True, - ) + valid_pdus = yield self._check_sigs_and_hash_and_fetch( + destination, list(pdus.values()), + outlier=True, + ) - valid_pdus_map = { - p.event_id: p - for p in valid_pdus - } - - # NB: We *need* to copy to ensure that we don't have multiple - # references being passed on, as that causes... issues. - signed_state = [ - copy.copy(valid_pdus_map[p.event_id]) - for p in state - if p.event_id in valid_pdus_map - ] + valid_pdus_map = { + p.event_id: p + for p in valid_pdus + } - signed_auth = [ - valid_pdus_map[p.event_id] - for p in auth_chain - if p.event_id in valid_pdus_map - ] + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + signed_state = [ + copy.copy(valid_pdus_map[p.event_id]) + for p in state + if p.event_id in valid_pdus_map + ] - # NB: We *need* to copy to ensure that we don't have multiple - # references being passed on, as that causes... issues. - for s in signed_state: - s.internal_metadata = copy.deepcopy(s.internal_metadata) + signed_auth = [ + valid_pdus_map[p.event_id] + for p in auth_chain + if p.event_id in valid_pdus_map + ] - auth_chain.sort(key=lambda e: e.depth) + # NB: We *need* to copy to ensure that we don't have multiple + # references being passed on, as that causes... issues. + for s in signed_state: + s.internal_metadata = copy.deepcopy(s.internal_metadata) - defer.returnValue({ - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - }) - except CodeMessageException as e: - if not 500 <= e.code < 600: - raise - else: - logger.exception( - "Failed to send_join via %s: %s", - destination, e.message - ) - except Exception as e: - logger.exception( - "Failed to send_join via %s: %s", - destination, e.message - ) + auth_chain.sort(key=lambda e: e.depth) - raise RuntimeError("Failed to send to any server.") + defer.returnValue({ + "state": signed_state, + "auth_chain": signed_auth, + "origin": destination, + }) + return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): time_now = self._clock.time_msec() - code, content = yield self.transport_layer.send_invite( - destination=destination, - room_id=room_id, - event_id=event_id, - content=pdu.get_pdu_json(time_now), - ) + try: + code, content = yield self.transport_layer.send_invite( + destination=destination, + room_id=room_id, + event_id=event_id, + content=pdu.get_pdu_json(time_now), + ) + except HttpResponseException as e: + if e.code == 403: + raise e.to_synapse_error() + raise pdu_dict = content["event"] @@ -663,7 +698,6 @@ class FederationClient(FederationBase): defer.returnValue(pdu) - @defer.inlineCallbacks def send_leave(self, destinations, pdu): """Sends a leave event to one of a list of homeservers. @@ -680,35 +714,25 @@ class FederationClient(FederationBase): Return: Deferred: resolves to None. - Fails with a ``CodeMessageException`` if the chosen remote server - returns a non-200 code. + Fails with a ``SynapseError`` if the chosen remote server + returns a 300/400 code. Fails with a ``RuntimeError`` if no servers were reachable. """ - for destination in destinations: - if destination == self.server_name: - continue - - try: - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_leave( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + @defer.inlineCallbacks + def send_request(destination): + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_leave( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) - logger.debug("Got content: %s", content) - defer.returnValue(None) - except CodeMessageException: - raise - except Exception as e: - logger.exception( - "Failed to send_leave via %s: %s", - destination, e.message - ) + logger.debug("Got content: %s", content) + defer.returnValue(None) - raise RuntimeError("Failed to send to any server.") + return self._try_destination_list("send_leave", destinations, send_request) def get_public_rooms(self, destination, limit=None, since_token=None, search_filter=None, include_all_networks=False, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 015ef6334a..bf89d568af 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -207,10 +207,6 @@ class FederationServer(FederationBase): edu.content ) - pdu_failures = getattr(transaction, "pdu_failures", []) - for fail in pdu_failures: - logger.info("Got failure %r", fail) - response = { "pdus": pdu_results, } diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5157c3860d..0bb468385d 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py
@@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object): self.edus = SortedDict() # stream position -> Edu - self.failures = SortedDict() # stream position -> (destination, Failure) - self.device_messages = SortedDict() # stream position -> destination self.pos = 1 @@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object): for queue_name in [ "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "failures", "device_messages", "pos_time", + "edus", "device_messages", "pos_time", ]: register(queue_name, getattr(self, queue_name)) @@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object): for key in keys[:i]: del self.edus[key] - # Delete things out of failure map - keys = self.failures.keys() - i = self.failures.bisect_left(position_to_delete) - for key in keys[:i]: - del self.failures[key] - # Delete things out of device map keys = self.device_messages.keys() i = self.device_messages.bisect_left(position_to_delete) @@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object): self.notifier.on_new_replication_data() - def send_failure(self, failure, destination): - """As per TransactionQueue""" - pos = self._next_pos() - - self.failures[pos] = (destination, str(failure)) - self.notifier.on_new_replication_data() - def send_device_messages(self, destination): """As per TransactionQueue""" pos = self._next_pos() @@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object): for (pos, edu) in edus: rows.append((pos, EduRow(edu))) - # Fetch changed failures - i = self.failures.bisect_right(from_token) - j = self.failures.bisect_right(to_token) + 1 - failures = self.failures.items()[i:j] - - for (pos, (destination, failure)) in failures: - rows.append((pos, FailureRow( - destination=destination, - failure=failure, - ))) - # Fetch changed device messages i = self.device_messages.bisect_right(from_token) j = self.device_messages.bisect_right(to_token) + 1 @@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ( buff.edus.setdefault(self.edu.destination, []).append(self.edu) -class FailureRow(BaseFederationRow, namedtuple("FailureRow", ( - "destination", # str - "failure", -))): - """Streams failures to a remote server. Failures are issued when there was - something wrong with a transaction the remote sent us, e.g. it included - an event that was invalid. - """ - - TypeId = "f" - - @staticmethod - def from_data(data): - return FailureRow( - destination=data["destination"], - failure=data["failure"], - ) - - def to_data(self): - return { - "destination": self.destination, - "failure": self.failure, - } - - def add_to_buffer(self, buff): - buff.failures.setdefault(self.destination, []).append(self.failure) - - class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( "destination", # str ))): @@ -471,7 +417,6 @@ TypeToRow = { PresenceRow, KeyedEduRow, EduRow, - FailureRow, DeviceRow, ) } @@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( "presence", # list(UserPresenceState) "keyed_edus", # dict of destination -> { key -> Edu } "edus", # dict of destination -> [Edu] - "failures", # dict of destination -> [failures] "device_destinations", # set of destinations )) @@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows): presence=[], keyed_edus={}, edus={}, - failures={}, device_destinations=set(), ) @@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows): edu.destination, edu.edu_type, edu.content, key=None, ) - for destination, failure_list in iteritems(buff.failures): - for failure in failure_list: - transaction_queue.send_failure(destination, failure) - for destination in buff.device_destinations: transaction_queue.send_device_messages(destination) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 2e5faeb96a..3ffcd76c70 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py
@@ -116,9 +116,6 @@ class TransactionQueue(object): ), ) - # destination -> list of tuple(failure, deferred) - self.pending_failures_by_dest = {} - # destination -> stream_id of last successfully sent to-device message. # NB: may be a long or an int. self.last_device_stream_id_by_dest = {} @@ -383,19 +380,6 @@ class TransactionQueue(object): self._attempt_new_transaction(destination) - def send_failure(self, failure, destination): - if destination == self.server_name or destination == "localhost": - return - - if not self.can_send_to(destination): - return - - self.pending_failures_by_dest.setdefault( - destination, [] - ).append(failure) - - self._attempt_new_transaction(destination) - def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": return @@ -470,7 +454,6 @@ class TransactionQueue(object): pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_presence = self.pending_presence_by_dest.pop(destination, {}) - pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_edus.extend( self.pending_edus_keyed_by_dest.pop(destination, {}).values() @@ -498,7 +481,7 @@ class TransactionQueue(object): logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", destination, len(pending_pdus)) - if not pending_pdus and not pending_edus and not pending_failures: + if not pending_pdus and not pending_edus: logger.debug("TX [%s] Nothing to send", destination) self.last_device_stream_id_by_dest[destination] = ( device_stream_id @@ -508,7 +491,7 @@ class TransactionQueue(object): # END CRITICAL SECTION success = yield self._send_new_transaction( - destination, pending_pdus, pending_edus, pending_failures, + destination, pending_pdus, pending_edus, ) if success: sent_transactions_counter.inc() @@ -585,14 +568,12 @@ class TransactionQueue(object): @measure_func("_send_new_transaction") @defer.inlineCallbacks - def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures): + def _send_new_transaction(self, destination, pending_pdus, pending_edus): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) pdus = [x[0] for x in pending_pdus] edus = pending_edus - failures = [x.get_dict() for x in pending_failures] success = True @@ -602,11 +583,10 @@ class TransactionQueue(object): logger.debug( "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", + " (pdus: %d, edus: %d)", destination, txn_id, len(pdus), len(edus), - len(failures) ) logger.debug("TX [%s] Persisting transaction...", destination) @@ -618,7 +598,6 @@ class TransactionQueue(object): destination=destination, pdus=pdus, edus=edus, - pdu_failures=failures, ) self._next_txn_id += 1 @@ -628,12 +607,11 @@ class TransactionQueue(object): logger.debug("TX [%s] Persisted transaction", destination) logger.info( "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d, failures: %d)", + " (PDUs: %d, EDUs: %d)", destination, txn_id, transaction.transaction_id, len(pdus), len(edus), - len(failures), ) # Actually send the transaction diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index c9beca27c2..eae5f2b427 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes): param_dict = dict(kv.split("=") for kv in params) def strip_quotes(value): - if value.startswith(b"\""): + if value.startswith("\""): return value[1:-1] else: return value @@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet): ) logger.info( - "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)", + "Received txn %s from %s. (PDUs: %d, EDUs: %d)", transaction_id, origin, len(transaction_data.get("pdus", [])), len(transaction_data.get("edus", [])), - len(transaction_data.get("failures", [])), ) # We should ideally be getting this from the security layer. @@ -404,10 +403,10 @@ class FederationMakeLeaveServlet(BaseFederationServlet): class FederationSendLeaveServlet(BaseFederationServlet): - PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)" + PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" @defer.inlineCallbacks - def on_PUT(self, origin, content, query, room_id, txid): + def on_PUT(self, origin, content, query, room_id, event_id): content = yield self.handler.on_send_leave_request(origin, content) defer.returnValue((200, content)) diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index bb1b3b13f7..c5ab14314e 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py
@@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject): "previous_ids", "pdus", "edus", - "pdu_failures", ] internal_keys = [ diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 47452700a8..b04f4234ca 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py
@@ -43,6 +43,7 @@ from signedjson.sign import sign_json from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id from synapse.util.logcontext import run_in_background @@ -129,7 +130,7 @@ class GroupAttestionRenewer(object): self.attestations = hs.get_groups_attestation_signing() self._renew_attestations_loop = self.clock.looping_call( - self._renew_attestations, 30 * 60 * 1000, + self._start_renew_attestations, 30 * 60 * 1000, ) @defer.inlineCallbacks @@ -151,6 +152,9 @@ class GroupAttestionRenewer(object): defer.returnValue({}) + def _start_renew_attestations(self): + return run_as_background_process("renew_attestations", self._renew_attestations) + @defer.inlineCallbacks def _renew_attestations(self): """Called periodically to check if we need to update any of our attestations diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 402e44cdef..184eef09d0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -15,6 +15,7 @@ # limitations under the License. import logging +import unicodedata import attr import bcrypt @@ -519,6 +520,7 @@ class AuthHandler(BaseHandler): """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) + yield self._check_mau_limits() # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -626,6 +628,7 @@ class AuthHandler(BaseHandler): # special case to check for "password" for the check_password interface # for the auth providers password = login_submission.get("password") + if login_type == LoginType.PASSWORD: if not self._password_enabled: raise SynapseError(400, "Password login has been disabled.") @@ -707,9 +710,10 @@ class AuthHandler(BaseHandler): multiple inexact matches. Args: - user_id (str): complete @user:id + user_id (unicode): complete @user:id + password (unicode): the provided password Returns: - (str) the canonical_user_id, or None if unknown user / bad password + (unicode) the canonical_user_id, or None if unknown user / bad password """ lookupres = yield self._find_user_id_and_pwd_hash(user_id) if not lookupres: @@ -728,15 +732,18 @@ class AuthHandler(BaseHandler): device_id) defer.returnValue(access_token) + @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): + yield self._check_mau_limits() auth_api = self.hs.get_auth() + user_id = None try: macaroon = pymacaroons.Macaroon.deserialize(login_token) user_id = auth_api.get_user_id_from_macaroon(macaroon) auth_api.validate_macaroon(macaroon, "login", True, user_id) - return user_id except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) + defer.returnValue(user_id) @defer.inlineCallbacks def delete_access_token(self, access_token): @@ -849,14 +856,19 @@ class AuthHandler(BaseHandler): """Computes a secure hash of password. Args: - password (str): Password to hash. + password (unicode): Password to hash. Returns: - Deferred(str): Hashed password. + Deferred(unicode): Hashed password. """ def _do_hash(): - return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, - bcrypt.gensalt(self.bcrypt_rounds)) + # Normalise the Unicode in the password + pw = unicodedata.normalize("NFKC", password) + + return bcrypt.hashpw( + pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), + bcrypt.gensalt(self.bcrypt_rounds), + ).decode('ascii') return make_deferred_yieldable( threads.deferToThreadPool( @@ -868,16 +880,19 @@ class AuthHandler(BaseHandler): """Validates that self.hash(password) == stored_hash. Args: - password (str): Password to hash. - stored_hash (str): Expected hash value. + password (unicode): Password to hash. + stored_hash (unicode): Expected hash value. Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ def _do_validate_hash(): + # Normalise the Unicode in the password + pw = unicodedata.normalize("NFKC", password) + return bcrypt.checkpw( - password.encode('utf8') + self.hs.config.password_pepper, + pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), stored_hash.encode('utf8') ) @@ -892,6 +907,19 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) + @defer.inlineCallbacks + def _check_mau_limits(self): + """ + Ensure that if mau blocking is enabled that invalid users cannot + log in. + """ + if self.hs.config.limit_usage_by_mau is True: + current_mau = yield self.store.count_monthly_users() + if current_mau >= self.hs.config.max_mau_value: + raise AuthError( + 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) + @attr.s class MacaroonGenerator(object): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0dd468940..533b82c783 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -76,7 +76,7 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() - self.replication_layer = hs.get_federation_client() + self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() @@ -255,7 +255,7 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: state, got_auth_chain = ( - yield self.replication_layer.get_state_for_room( + yield self.federation_client.get_state_for_room( origin, pdu.room_id, p ) ) @@ -338,7 +338,7 @@ class FederationHandler(BaseHandler): # # see https://github.com/matrix-org/synapse/pull/1744 - missing_events = yield self.replication_layer.get_missing_events( + missing_events = yield self.federation_client.get_missing_events( origin, pdu.room_id, earliest_events_ids=list(latest), @@ -400,7 +400,7 @@ class FederationHandler(BaseHandler): ) try: - event_stream_id, max_stream_id = yield self._persist_auth_tree( + yield self._persist_auth_tree( origin, auth_chain, state, event ) except AuthError as e: @@ -444,7 +444,7 @@ class FederationHandler(BaseHandler): yield self._handle_new_events(origin, event_infos) try: - context, event_stream_id, max_stream_id = yield self._handle_new_event( + context = yield self._handle_new_event( origin, event, state=state, @@ -469,17 +469,6 @@ class FederationHandler(BaseHandler): except StoreError: logger.exception("Failed to store room.") - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) - if event.type == EventTypes.Member: if event.membership == Membership.JOIN: # Only fire user_joined_room if the user has acutally @@ -501,7 +490,7 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield user_joined_room(self.distributor, user, event.room_id) + yield self.user_joined_room(user, event.room_id) @log_function @defer.inlineCallbacks @@ -522,7 +511,7 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - events = yield self.replication_layer.backfill( + events = yield self.federation_client.backfill( dest, room_id, limit=limit, @@ -570,7 +559,7 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self.replication_layer.get_state_for_room( + state, auth = yield self.federation_client.get_state_for_room( destination=dest, room_id=room_id, event_id=e_id @@ -612,7 +601,7 @@ class FederationHandler(BaseHandler): results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ logcontext.run_in_background( - self.replication_layer.get_pdu, + self.federation_client.get_pdu, [dest], event_id, outlier=True, @@ -893,7 +882,7 @@ class FederationHandler(BaseHandler): Invites must be signed by the invitee's server before distribution. """ - pdu = yield self.replication_layer.send_invite( + pdu = yield self.federation_client.send_invite( destination=target_host, room_id=event.room_id, event_id=event.event_id, @@ -942,7 +931,7 @@ class FederationHandler(BaseHandler): self.room_queues[room_id] = [] - yield self.store.clean_room_for_join(room_id) + yield self._clean_room_for_join(room_id) handled_events = set() @@ -955,7 +944,7 @@ class FederationHandler(BaseHandler): target_hosts.insert(0, origin) except ValueError: pass - ret = yield self.replication_layer.send_join(target_hosts, event) + ret = yield self.federation_client.send_join(target_hosts, event) origin = ret["origin"] state = ret["state"] @@ -981,15 +970,10 @@ class FederationHandler(BaseHandler): # FIXME pass - event_stream_id, max_stream_id = yield self._persist_auth_tree( + yield self._persist_auth_tree( origin, auth_chain, state, event ) - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[joinee] - ) - logger.debug("Finished joining %s to %s", joinee, room_id) finally: room_queue = self.room_queues[room_id] @@ -1084,7 +1068,7 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin - context, event_stream_id, max_stream_id = yield self._handle_new_event( + context = yield self._handle_new_event( origin, event ) @@ -1094,20 +1078,10 @@ class FederationHandler(BaseHandler): event.signatures, ) - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) - if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: user = UserID.from_string(event.state_key) - yield user_joined_room(self.distributor, user, event.room_id) + yield self.user_joined_room(user, event.room_id) prev_state_ids = yield context.get_prev_state_ids(self.store) @@ -1176,17 +1150,7 @@ class FederationHandler(BaseHandler): ) context = yield self.state_handler.compute_event_context(event) - - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, - ) - - target_user = UserID.from_string(event.state_key) - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[target_user], - ) + yield self._persist_events([(event, context)]) defer.returnValue(event) @@ -1211,30 +1175,20 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.replication_layer.send_leave( + yield self.federation_client.send_leave( target_hosts, event ) context = yield self.state_handler.compute_event_context(event) - - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, - ) - - target_user = UserID.from_string(event.state_key) - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[target_user], - ) + yield self._persist_events([(event, context)]) defer.returnValue(event) @defer.inlineCallbacks def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content={},): - origin, pdu = yield self.replication_layer.make_membership_event( + origin, pdu = yield self.federation_client.make_membership_event( target_hosts, room_id, user_id, @@ -1279,7 +1233,7 @@ class FederationHandler(BaseHandler): @log_function def on_make_leave_request(self, room_id, user_id): """ We've received a /make_leave/ request, so we create a partial - join event for the room and return that. We do *not* persist or + leave event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. """ builder = self.event_builder_factory.new({ @@ -1318,7 +1272,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context, event_stream_id, max_stream_id = yield self._handle_new_event( + yield self._handle_new_event( origin, event ) @@ -1328,16 +1282,6 @@ class FederationHandler(BaseHandler): event.signatures, ) - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) - defer.returnValue(None) @defer.inlineCallbacks @@ -1479,9 +1423,8 @@ class FederationHandler(BaseHandler): event, context ) - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, + yield self._persist_events( + [(event, context)], backfilled=backfilled, ) except: # noqa: E722, as we reraise the exception this is fine. @@ -1494,15 +1437,7 @@ class FederationHandler(BaseHandler): six.reraise(tp, value, tb) - if not backfilled: - # this intentionally does not yield: we don't care about the result - # and don't need to wait for it. - logcontext.run_in_background( - self.pusher_pool.on_new_notifications, - event_stream_id, max_stream_id, - ) - - defer.returnValue((context, event_stream_id, max_stream_id)) + defer.returnValue(context) @defer.inlineCallbacks def _handle_new_events(self, origin, event_infos, backfilled=False): @@ -1510,6 +1445,8 @@ class FederationHandler(BaseHandler): should not depend on one another, e.g. this should be used to persist a bunch of outliers, but not a chunk of individual events that depend on each other for state calculations. + + Notifies about the events where appropriate. """ contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ @@ -1524,7 +1461,7 @@ class FederationHandler(BaseHandler): ], consumeErrors=True, )) - yield self.store.persist_events( + yield self._persist_events( [ (ev_info["event"], context) for ev_info, context in zip(event_infos, contexts) @@ -1536,7 +1473,8 @@ class FederationHandler(BaseHandler): def _persist_auth_tree(self, origin, auth_events, state, event): """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. - Persists the event seperately. + Persists the event separately. Notifies about the persisted events + where appropriate. Will attempt to fetch missing auth events. @@ -1547,8 +1485,7 @@ class FederationHandler(BaseHandler): event (Event) Returns: - 2-tuple of (event_stream_id, max_stream_id) from the persist_event - call for `event` + Deferred """ events_to_context = {} for e in itertools.chain(auth_events, state): @@ -1574,7 +1511,7 @@ class FederationHandler(BaseHandler): missing_auth_events.add(e_id) for e_id in missing_auth_events: - m_ev = yield self.replication_layer.get_pdu( + m_ev = yield self.federation_client.get_pdu( [origin], e_id, outlier=True, @@ -1612,7 +1549,7 @@ class FederationHandler(BaseHandler): raise events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR - yield self.store.persist_events( + yield self._persist_events( [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) @@ -1623,12 +1560,10 @@ class FederationHandler(BaseHandler): event, old_state=state ) - event_stream_id, max_stream_id = yield self.store.persist_event( - event, new_event_context, + yield self._persist_events( + [(event, new_event_context)], ) - defer.returnValue((event_stream_id, max_stream_id)) - @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, auth_events=None): """ @@ -1794,7 +1729,7 @@ class FederationHandler(BaseHandler): logger.info("Missing auth: %s", missing_auth) # If we don't have all the auth events, we need to get them. try: - remote_auth_chain = yield self.replication_layer.get_event_auth( + remote_auth_chain = yield self.federation_client.get_event_auth( origin, event.room_id, event.event_id ) @@ -1910,7 +1845,7 @@ class FederationHandler(BaseHandler): try: # 2. Get remote difference. - result = yield self.replication_layer.query_auth( + result = yield self.federation_client.query_auth( origin, event.room_id, event.event_id, @@ -2209,7 +2144,7 @@ class FederationHandler(BaseHandler): yield member_handler.send_membership_event(None, event, context) else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) - yield self.replication_layer.forward_third_party_invite( + yield self.federation_client.forward_third_party_invite( destinations, room_id, event_dict, @@ -2364,3 +2299,69 @@ class FederationHandler(BaseHandler): ) if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") + + @defer.inlineCallbacks + def _persist_events(self, event_and_contexts, backfilled=False): + """Persists events and tells the notifier/pushers about them, if + necessary. + + Args: + event_and_contexts(list[tuple[FrozenEvent, EventContext]]) + backfilled (bool): Whether these events are a result of + backfilling or not + + Returns: + Deferred + """ + max_stream_id = yield self.store.persist_events( + event_and_contexts, + backfilled=backfilled, + ) + + if not backfilled: # Never notify for backfilled events + for event, _ in event_and_contexts: + self._notify_persisted_event(event, max_stream_id) + + def _notify_persisted_event(self, event, max_stream_id): + """Checks to see if notifier/pushers should be notified about the + event or not. + + Args: + event (FrozenEvent) + max_stream_id (int): The max_stream_id returned by persist_events + """ + + extra_users = [] + if event.type == EventTypes.Member: + target_user_id = event.state_key + + # We notify for memberships if its an invite for one of our + # users + if event.internal_metadata.is_outlier(): + if event.membership != Membership.INVITE: + if not self.is_mine_id(target_user_id): + return + + target_user = UserID.from_string(target_user_id) + extra_users.append(target_user) + elif event.internal_metadata.is_outlier(): + return + + event_stream_id = event.internal_metadata.stream_ordering + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) + + logcontext.run_in_background( + self.pusher_pool.on_new_notifications, + event_stream_id, max_stream_id, + ) + + def _clean_room_for_join(self, room_id): + return self.store.clean_room_for_join(room_id) + + def user_joined_room(self, user, room_id): + """Called when a new user has joined the room + """ + return user_joined_room(self.distributor, user, room_id) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 8c8aedb2b8..1d36d967c3 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.errors import ( CodeMessageException, Codes, - MatrixCodeMessageException, + HttpResponseException, SynapseError, ) @@ -85,7 +85,6 @@ class IdentityHandler(BaseHandler): ) defer.returnValue(None) - data = {} try: data = yield self.http_client.get_json( "https://%s%s" % ( @@ -94,11 +93,9 @@ class IdentityHandler(BaseHandler): ), {'sid': creds['sid'], 'client_secret': client_secret} ) - except MatrixCodeMessageException as e: + except HttpResponseException as e: logger.info("getValidated3pid failed with Matrix error: %r", e) - raise SynapseError(e.code, e.msg, e.errcode) - except CodeMessageException as e: - data = json.loads(e.msg) + raise e.to_synapse_error() if 'medium' in data: defer.returnValue(data) @@ -136,7 +133,7 @@ class IdentityHandler(BaseHandler): ) logger.debug("bound threepid %r to %s", creds, mxid) except CodeMessageException as e: - data = json.loads(e.msg) + data = json.loads(e.msg) # XXX WAT? defer.returnValue(data) @defer.inlineCallbacks @@ -209,12 +206,9 @@ class IdentityHandler(BaseHandler): params ) defer.returnValue(data) - except MatrixCodeMessageException as e: - logger.info("Proxied requestToken failed with Matrix error: %r", e) - raise SynapseError(e.code, e.msg, e.errcode) - except CodeMessageException as e: + except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) - raise e + raise e.to_synapse_error() @defer.inlineCallbacks def requestMsisdnToken( @@ -244,9 +238,6 @@ class IdentityHandler(BaseHandler): params ) defer.returnValue(data) - except MatrixCodeMessageException as e: - logger.info("Proxied requestToken failed with Matrix error: %r", e) - raise SynapseError(e.code, e.msg, e.errcode) - except CodeMessageException as e: + except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) - raise e + raise e.to_synapse_error() diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 859f6d2b2e..9af2e8f869 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -17,7 +17,14 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, CodeMessageException, SynapseError +from synapse.api.errors import ( + AuthError, + CodeMessageException, + Codes, + StoreError, + SynapseError, +) +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID, get_domain_from_id from ._base import BaseHandler @@ -41,19 +48,24 @@ class ProfileHandler(BaseHandler): if hs.config.worker_app is None: self.clock.looping_call( - self._update_remote_profile_cache, self.PROFILE_UPDATE_MS, + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, ) @defer.inlineCallbacks def get_profile(self, user_id): target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): - displayname = yield self.store.get_profile_displayname( - target_user.localpart - ) - avatar_url = yield self.store.get_profile_avatar_url( - target_user.localpart - ) + try: + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) + raise defer.returnValue({ "displayname": displayname, @@ -73,7 +85,6 @@ class ProfileHandler(BaseHandler): except CodeMessageException as e: if e.code != 404: logger.exception("Failed to get displayname") - raise @defer.inlineCallbacks @@ -84,12 +95,17 @@ class ProfileHandler(BaseHandler): """ target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): - displayname = yield self.store.get_profile_displayname( - target_user.localpart - ) - avatar_url = yield self.store.get_profile_avatar_url( - target_user.localpart - ) + try: + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) + raise defer.returnValue({ "displayname": displayname, @@ -102,9 +118,14 @@ class ProfileHandler(BaseHandler): @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): - displayname = yield self.store.get_profile_displayname( - target_user.localpart - ) + try: + displayname = yield self.store.get_profile_displayname( + target_user.localpart + ) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) + raise defer.returnValue(displayname) else: @@ -121,7 +142,6 @@ class ProfileHandler(BaseHandler): except CodeMessageException as e: if e.code != 404: logger.exception("Failed to get displayname") - raise except Exception: logger.exception("Failed to get displayname") @@ -156,10 +176,14 @@ class ProfileHandler(BaseHandler): @defer.inlineCallbacks def get_avatar_url(self, target_user): if self.hs.is_mine(target_user): - avatar_url = yield self.store.get_profile_avatar_url( - target_user.localpart - ) - + try: + avatar_url = yield self.store.get_profile_avatar_url( + target_user.localpart + ) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) + raise defer.returnValue(avatar_url) else: try: @@ -212,16 +236,20 @@ class ProfileHandler(BaseHandler): just_field = args.get("field", None) response = {} + try: + if just_field is None or just_field == "displayname": + response["displayname"] = yield self.store.get_profile_displayname( + user.localpart + ) - if just_field is None or just_field == "displayname": - response["displayname"] = yield self.store.get_profile_displayname( - user.localpart - ) - - if just_field is None or just_field == "avatar_url": - response["avatar_url"] = yield self.store.get_profile_avatar_url( - user.localpart - ) + if just_field is None or just_field == "avatar_url": + response["avatar_url"] = yield self.store.get_profile_avatar_url( + user.localpart + ) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) + raise defer.returnValue(response) @@ -254,6 +282,12 @@ class ProfileHandler(BaseHandler): room_id, str(e.message) ) + def _start_update_remote_profile_cache(self): + return run_as_background_process( + "Update remote profile", self._update_remote_profile_cache, + ) + + @defer.inlineCallbacks def _update_remote_profile_cache(self): """Called periodically to check profiles of remote users we haven't checked in a while. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7caff0cbc8..289704b241 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler): hs (synapse.server.HomeServer): """ super(RegistrationHandler, self).__init__(hs) - + self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() @@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler): Args: localpart : The local part of the user ID to register. If None, one will be generated. - password (str) : The password to assign to this user so they can + password (unicode) : The password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). generate_token (bool): Whether a new access token should be @@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ + yield self._check_mau_limits() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) + yield self._check_mau_limits() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - + yield self._check_mau_limits() need_register = True try: @@ -531,3 +533,16 @@ class RegistrationHandler(BaseHandler): remote_room_hosts=remote_room_hosts, action="join", ) + + @defer.inlineCallbacks + def _check_mau_limits(self): + """ + Do not accept registrations if monthly active user limits exceeded + and limiting is enabled + """ + if self.hs.config.limit_usage_by_mau is True: + current_mau = yield self.store.count_monthly_users() + if current_mau >= self.hs.config.max_mau_value: + raise RegistrationError( + 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED + ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 003b848c00..7b7804d9b2 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -15,6 +15,7 @@ # limitations under the License. """Contains functions for performing events on rooms.""" +import itertools import logging import math import string @@ -401,7 +402,7 @@ class RoomContextHandler(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_event_context(self, user, room_id, event_id, limit): + def get_event_context(self, user, room_id, event_id, limit, event_filter): """Retrieves events, pagination tokens and state around a given event in a room. @@ -411,6 +412,8 @@ class RoomContextHandler(object): event_id (str) limit (int): The maximum number of events to return in total (excluding state). + event_filter (Filter|None): the filter to apply to the events returned + (excluding the target event_id) Returns: dict, or None if the event isn't found @@ -443,7 +446,7 @@ class RoomContextHandler(object): ) results = yield self.store.get_events_around( - room_id, event_id, before_limit, after_limit + room_id, event_id, before_limit, after_limit, event_filter ) results["events_before"] = yield filter_evts(results["events_before"]) @@ -455,8 +458,23 @@ class RoomContextHandler(object): else: last_event_id = event_id + types = None + filtered_types = None + if event_filter and event_filter.lazy_load_members(): + members = set(ev.sender for ev in itertools.chain( + results["events_before"], + (results["event"],), + results["events_after"], + )) + filtered_types = [EventTypes.Member] + types = [(EventTypes.Member, member) for member in members] + + # XXX: why do we return the state as of the last event rather than the + # first? Shouldn't we be consistent with /sync? + # https://github.com/matrix-org/matrix-doc/issues/687 + state = yield self.store.get_state_for_events( - [last_event_id], None + [last_event_id], types, filtered_types=filtered_types, ) results["state"] = list(state[last_event_id].values()) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6184737cd4..e13551d569 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -728,6 +728,10 @@ class RoomMemberHandler(object): inviter_display_name = member_event.content.get("displayname", "") inviter_avatar_url = member_event.content.get("avatar_url", "") + # if user has no display name, default to their MXID + if not inviter_display_name: + inviter_display_name = user.to_string() + canonical_room_alias = "" canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) if canonical_alias_event: diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 69ae9731d5..c464adbd0b 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py
@@ -287,7 +287,7 @@ class SearchHandler(BaseHandler): contexts = {} for event in allowed_events: res = yield self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit + event.room_id, event.event_id, before_limit, after_limit, ) logger.info( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a006c952f2..ea148866e7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2015 - 2016 OpenMarket Ltd +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +26,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user from synapse.types import RoomStreamToken from synapse.util.async import concurrently_execute +from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.caches.lrucache import LruCache from synapse.util.caches.response_cache import ResponseCache from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure, measure_func @@ -34,6 +37,15 @@ logger = logging.getLogger(__name__) SYNC_RESPONSE_CACHE_MS = 2 * 60 * 1000 +# Store the cache that tracks which lazy-loaded members have been sent to a given +# client for no more than 30 minutes. +LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 + +# Remember the last 100 members we sent to a client for the purposes of +# avoiding redundantly sending the same lazy-loaded members to the client +LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 + + SyncConfig = collections.namedtuple("SyncConfig", [ "user", "filter_collection", @@ -184,6 +196,12 @@ class SyncHandler(object): ) self.state = hs.get_state_handler() + # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) + self.lazy_loaded_members_cache = ExpiringCache( + "lazy_loaded_members_cache", self.clock, + max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, + ) + def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): """Get the sync for a client if we have new data for it now. Otherwise @@ -419,29 +437,44 @@ class SyncHandler(object): )) @defer.inlineCallbacks - def get_state_after_event(self, event): + def get_state_after_event(self, event, types=None, filtered_types=None): """ Get the room state after the given event Args: event(synapse.events.EventBase): event of interest + types(list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + May be None, which matches any key. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.store.get_state_ids_for_event(event.event_id) + state_ids = yield self.store.get_state_ids_for_event( + event.event_id, types, filtered_types=filtered_types, + ) if event.is_state(): state_ids = state_ids.copy() state_ids[(event.type, event.state_key)] = event.event_id defer.returnValue(state_ids) @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position): + def get_state_at(self, room_id, stream_position, types=None, filtered_types=None): """ Get the room state at a particular stream position Args: room_id(str): room for which to get state stream_position(StreamToken): point at which to get state + types(list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. Returns: A Deferred map from ((type, state_key)->Event) @@ -456,7 +489,9 @@ class SyncHandler(object): if last_events: last_event = last_events[-1] - state = yield self.get_state_after_event(last_event) + state = yield self.get_state_after_event( + last_event, types, filtered_types=filtered_types, + ) else: # no events in this room - so presumably no state @@ -488,59 +523,129 @@ class SyncHandler(object): # TODO(mjark) Check for new redactions in the state events. with Measure(self.clock, "compute_state_delta"): + + types = None + filtered_types = None + + lazy_load_members = sync_config.filter_collection.lazy_load_members() + include_redundant_members = ( + sync_config.filter_collection.include_redundant_members() + ) + + if lazy_load_members: + # We only request state for the members needed to display the + # timeline: + + types = [ + (EventTypes.Member, state_key) + for state_key in set( + event.sender # FIXME: we also care about invite targets etc. + for event in batch.events + ) + ] + + # only apply the filtering to room members + filtered_types = [EventTypes.Member] + + timeline_state = { + (event.type, event.state_key): event.event_id + for event in batch.events if event.is_state() + } + if full_state: if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id + batch.events[-1].event_id, types=types, + filtered_types=filtered_types, ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id + batch.events[0].event_id, types=types, + filtered_types=filtered_types, ) + else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token + room_id, stream_position=now_token, types=types, + filtered_types=filtered_types, ) state_ids = current_state_ids - timeline_state = { - (event.type, event.state_key): event.event_id - for event in batch.events if event.is_state() - } - state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_ids, previous={}, current=current_state_ids, + lazy_load_members=lazy_load_members, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token + room_id, stream_position=since_token, types=types, + filtered_types=filtered_types, ) current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id + batch.events[-1].event_id, types=types, + filtered_types=filtered_types, ) state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id + batch.events[0].event_id, types=types, + filtered_types=filtered_types, ) - timeline_state = { - (event.type, event.state_key): event.event_id - for event in batch.events if event.is_state() - } - state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, current=current_state_ids, + lazy_load_members=lazy_load_members, ) else: state_ids = {} + if lazy_load_members: + if types: + state_ids = yield self.store.get_state_ids_for_event( + batch.events[0].event_id, types=types, + filtered_types=filtered_types, + ) + + if lazy_load_members and not include_redundant_members: + cache_key = (sync_config.user.to_string(), sync_config.device_id) + cache = self.lazy_loaded_members_cache.get(cache_key) + if cache is None: + logger.debug("creating LruCache for %r", cache_key) + cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) + self.lazy_loaded_members_cache[cache_key] = cache + else: + logger.debug("found LruCache for %r", cache_key) + + # if it's a new sync sequence, then assume the client has had + # amnesia and doesn't want any recent lazy-loaded members + # de-duplicated. + if since_token is None: + logger.debug("clearing LruCache for %r", cache_key) + cache.clear() + else: + # only send members which aren't in our LruCache (either + # because they're new to this client or have been pushed out + # of the cache) + logger.debug("filtering state from %r...", state_ids) + state_ids = { + t: event_id + for t, event_id in state_ids.iteritems() + if cache.get(t[1]) != event_id + } + logger.debug("...to %r", state_ids) + + # add any member IDs we are about to send into our LruCache + for t, event_id in itertools.chain( + state_ids.items(), + timeline_state.items(), + ): + if t[0] == EventTypes.Member: + cache.set(t[1], event_id) state = {} if state_ids: @@ -1451,7 +1556,9 @@ def _action_has_highlight(actions): return False -def _calculate_state(timeline_contains, timeline_start, previous, current): +def _calculate_state( + timeline_contains, timeline_start, previous, current, lazy_load_members, +): """Works out what state to include in a sync response. Args: @@ -1460,6 +1567,9 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): previous (dict): state at the end of the previous sync (or empty dict if this is an initial sync) current (dict): state at the end of the timeline + lazy_load_members (bool): whether to return members from timeline_start + or not. assumes that timeline_start has already been filtered to + include only the members the client needs to know about. Returns: dict @@ -1475,9 +1585,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): } c_ids = set(e for e in current.values()) - tc_ids = set(e for e in timeline_contains.values()) - p_ids = set(e for e in previous.values()) ts_ids = set(e for e in timeline_start.values()) + p_ids = set(e for e in previous.values()) + tc_ids = set(e for e in timeline_contains.values()) + + # If we are lazyloading room members, we explicitly add the membership events + # for the senders in the timeline into the state block returned by /sync, + # as we may not have sent them to the client before. We find these membership + # events by filtering them out of timeline_start, which has already been filtered + # to only include membership events for the senders in the timeline. + # In practice, we can do this by removing them from the p_ids list, + # which is the list of relevant state we know we have already sent to the client. + # see https://github.com/matrix-org/synapse/pull/2970 + # /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809 + + if lazy_load_members: + p_ids.difference_update( + e for t, e in timeline_start.iteritems() + if t[0] == EventTypes.Member + ) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids diff --git a/synapse/http/client.py b/synapse/http/client.py
index 25b6307884..3771e0b3f6 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -39,12 +39,7 @@ from twisted.web.client import ( from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers -from synapse.api.errors import ( - CodeMessageException, - Codes, - MatrixCodeMessageException, - SynapseError, -) +from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http.endpoint import SpiderEndpoint from synapse.util.async import add_timeout_to_deferred @@ -132,6 +127,11 @@ class SimpleHttpClient(object): Returns: Deferred[object]: parsed json + + Raises: + HttpResponseException: On a non-2xx HTTP response. + + ValueError: if the response was not JSON """ # TODO: Do we ever want to log message contents? @@ -155,7 +155,10 @@ class SimpleHttpClient(object): body = yield make_deferred_yieldable(readBody(response)) - defer.returnValue(json.loads(body)) + if 200 <= response.code < 300: + defer.returnValue(json.loads(body)) + else: + raise HttpResponseException(response.code, response.phrase, body) @defer.inlineCallbacks def post_json_get_json(self, uri, post_json, headers=None): @@ -169,6 +172,11 @@ class SimpleHttpClient(object): Returns: Deferred[object]: parsed json + + Raises: + HttpResponseException: On a non-2xx HTTP response. + + ValueError: if the response was not JSON """ json_str = encode_canonical_json(post_json) @@ -193,9 +201,7 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: defer.returnValue(json.loads(body)) else: - raise self._exceptionFromFailedRequest(response, body) - - defer.returnValue(json.loads(body)) + raise HttpResponseException(response.code, response.phrase, body) @defer.inlineCallbacks def get_json(self, uri, args={}, headers=None): @@ -213,14 +219,12 @@ class SimpleHttpClient(object): Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. Raises: - On a non-2xx HTTP response. The response body will be used as the - error message. + HttpResponseException On a non-2xx HTTP response. + + ValueError: if the response was not JSON """ - try: - body = yield self.get_raw(uri, args, headers=headers) - defer.returnValue(json.loads(body)) - except CodeMessageException as e: - raise self._exceptionFromFailedRequest(e.code, e.msg) + body = yield self.get_raw(uri, args, headers=headers) + defer.returnValue(json.loads(body)) @defer.inlineCallbacks def put_json(self, uri, json_body, args={}, headers=None): @@ -239,7 +243,9 @@ class SimpleHttpClient(object): Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body as JSON. Raises: - On a non-2xx HTTP response. + HttpResponseException On a non-2xx HTTP response. + + ValueError: if the response was not JSON """ if len(args): query_bytes = urllib.urlencode(args, True) @@ -266,10 +272,7 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: defer.returnValue(json.loads(body)) else: - # NB: This is explicitly not json.loads(body)'d because the contract - # of CodeMessageException is a *string* message. Callers can always - # load it into JSON if they want. - raise CodeMessageException(response.code, body) + raise HttpResponseException(response.code, response.phrase, body) @defer.inlineCallbacks def get_raw(self, uri, args={}, headers=None): @@ -287,8 +290,7 @@ class SimpleHttpClient(object): Deferred: Succeeds when we get *any* 2xx HTTP response, with the HTTP body at text. Raises: - On a non-2xx HTTP response. The response body will be used as the - error message. + HttpResponseException on a non-2xx HTTP response. """ if len(args): query_bytes = urllib.urlencode(args, True) @@ -311,16 +313,7 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: defer.returnValue(body) else: - raise CodeMessageException(response.code, body) - - def _exceptionFromFailedRequest(self, response, body): - try: - jsonBody = json.loads(body) - errcode = jsonBody['errcode'] - error = jsonBody['error'] - return MatrixCodeMessageException(response.code, error, errcode) - except (ValueError, KeyError): - return CodeMessageException(response.code, body) + raise HttpResponseException(response.code, response.phrase, body) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. diff --git a/synapse/http/server.py b/synapse/http/server.py
index c70fdbdfd2..6dacb31037 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -13,12 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import cgi import collections import logging -import urllib -from six.moves import http_client +from six import PY3 +from six.moves import http_client, urllib from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json @@ -35,7 +36,6 @@ from synapse.api.errors import ( Codes, SynapseError, UnrecognizedRequestError, - cs_exception, ) from synapse.http.request_metrics import requests_counter from synapse.util.caches import intern_dict @@ -76,16 +76,13 @@ def wrap_json_request_handler(h): def wrapped_request_handler(self, request): try: yield h(self, request) - except CodeMessageException as e: + except SynapseError as e: code = e.code - if isinstance(e, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) - else: - logger.exception(e) + logger.info( + "%s SynapseError: %s - %s", request, code, e.msg + ) respond_with_json( - request, code, cs_exception(e), send_cors=True, + request, code, e.error_dict(), send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -264,6 +261,7 @@ class JsonResource(HttpServer, resource.Resource): self.hs = hs def register_paths(self, method, path_patterns, callback): + method = method.encode("utf-8") # method is bytes on py3 for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( @@ -296,8 +294,19 @@ class JsonResource(HttpServer, resource.Resource): # here. If it throws an exception, that is handled by the wrapper # installed by @request_handler. + def _unquote(s): + if PY3: + # On Python 3, unquote is unicode -> unicode + return urllib.parse.unquote(s) + else: + # On Python 2, unquote is bytes -> bytes We need to encode the + # URL again (as it was decoded by _get_handler_for request), as + # ASCII because it's a URL, and then decode it to get the UTF-8 + # characters that were quoted. + return urllib.parse.unquote(s.encode('ascii')).decode('utf8') + kwargs = intern_dict({ - name: urllib.unquote(value).decode("UTF-8") if value else value + name: _unquote(value) if value else value for name, value in group_dict.items() }) @@ -313,9 +322,9 @@ class JsonResource(HttpServer, resource.Resource): request (twisted.web.http.Request): Returns: - Tuple[Callable, dict[str, str]]: callback method, and the dict - mapping keys to path components as specified in the handler's - path match regexp. + Tuple[Callable, dict[unicode, unicode]]: callback method, and the + dict mapping keys to path components as specified in the + handler's path match regexp. The callback will normally be a method registered via register_paths, so will return (possibly via Deferred) either @@ -327,7 +336,7 @@ class JsonResource(HttpServer, resource.Resource): # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path) + m = path_entry.pattern.match(request.path.decode('ascii')) if m: # We found a match! return path_entry.callback, m.groupdict() @@ -383,7 +392,7 @@ class RootRedirect(resource.Resource): self.url = path def render_GET(self, request): - return redirectTo(self.url, request) + return redirectTo(self.url.encode('ascii'), request) def getChild(self, name, request): if len(name) == 0: @@ -404,12 +413,14 @@ def respond_with_json(request, code, json_object, send_cors=False, return if pretty_print: - json_bytes = encode_pretty_printed_json(json_object) + "\n" + json_bytes = (encode_pretty_printed_json(json_object) + "\n" + ).encode("utf-8") else: if canonical_json or synapse.events.USE_FROZEN_DICTS: + # canonicaljson already encodes to bytes json_bytes = encode_canonical_json(json_object) else: - json_bytes = json.dumps(json_object) + json_bytes = json.dumps(json_object).encode("utf-8") return respond_with_json_bytes( request, code, json_bytes, diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 882816dc8f..69f7085291 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False): if not content_bytes and allow_empty_body: return None + # Decode to Unicode so that simplejson will return Unicode strings on + # Python 2 try: - content = json.loads(content_bytes) + content_unicode = content_bytes.decode('utf8') + except UnicodeDecodeError: + logger.warn("Unable to decode UTF-8") + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + try: + content = json.loads(content_unicode) except Exception as e: logger.warn("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 9d820e44a6..ce678d5f75 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py
@@ -151,13 +151,19 @@ def run_as_background_process(desc, func, *args, **kwargs): This should be used to wrap processes which are fired off to run in the background, instead of being associated with a particular request. + It returns a Deferred which completes when the function completes, but it doesn't + follow the synapse logcontext rules, which makes it appropriate for passing to + clock.looping_call and friends (or for firing-and-forgetting in the middle of a + normal synapse inlineCallbacks function). + Args: desc (str): a description for this background process type func: a function, which may return a Deferred args: positional args for func kwargs: keyword args for func - Returns: None + Returns: Deferred which returns the result of func, but note that it does not + follow the synapse logcontext rules. """ @defer.inlineCallbacks def run(): @@ -176,4 +182,4 @@ def run_as_background_process(desc, func, *args, **kwargs): _background_processes[desc].remove(proc) with PreserveLoggingContext(): - run() + return run() diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 6bfc8a5b89..7a3cfb159c 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py
@@ -18,7 +18,7 @@ import re from twisted.internet import defer -from synapse.api.errors import MatrixCodeMessageException, SynapseError +from synapse.api.errors import HttpResponseException from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import Requester, UserID from synapse.util.distributor import user_joined_room, user_left_room @@ -56,11 +56,11 @@ def remote_join(client, host, port, requester, remote_room_hosts, try: result = yield client.post_json_get_json(uri, payload) - except MatrixCodeMessageException as e: + except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And # importantly, not stack traces everywhere) - raise SynapseError(e.code, e.msg, e.errcode) + raise e.to_synapse_error() defer.returnValue(result) @@ -92,11 +92,11 @@ def remote_reject_invite(client, host, port, requester, remote_room_hosts, try: result = yield client.post_json_get_json(uri, payload) - except MatrixCodeMessageException as e: + except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And # importantly, not stack traces everywhere) - raise SynapseError(e.code, e.msg, e.errcode) + raise e.to_synapse_error() defer.returnValue(result) @@ -131,11 +131,11 @@ def get_or_register_3pid_guest(client, host, port, requester, try: result = yield client.post_json_get_json(uri, payload) - except MatrixCodeMessageException as e: + except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And # importantly, not stack traces everywhere) - raise SynapseError(e.code, e.msg, e.errcode) + raise e.to_synapse_error() defer.returnValue(result) @@ -165,11 +165,11 @@ def notify_user_membership_change(client, host, port, user_id, room_id, change): try: result = yield client.post_json_get_json(uri, payload) - except MatrixCodeMessageException as e: + except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And # importantly, not stack traces everywhere) - raise SynapseError(e.code, e.msg, e.errcode) + raise e.to_synapse_error() defer.returnValue(result) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 5227bc333d..d3509dc288 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py
@@ -18,11 +18,7 @@ import re from twisted.internet import defer -from synapse.api.errors import ( - CodeMessageException, - MatrixCodeMessageException, - SynapseError, -) +from synapse.api.errors import CodeMessageException, HttpResponseException from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -83,11 +79,11 @@ def send_event_to_master(clock, store, client, host, port, requester, event, con # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. yield clock.sleep(1) - except MatrixCodeMessageException as e: + except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And # importantly, not stack traces everywhere) - raise SynapseError(e.code, e.msg, e.errcode) + raise e.to_synapse_error() defer.returnValue(result) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e592ab57bf..970e94313e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -192,7 +192,7 @@ class ReplicationClientHandler(object): """Returns a deferred that is resolved when we receive a SYNC command with given data. - Used by tests. + [Not currently] used by tests. """ return self.awaiting_syncs.setdefault(data, defer.Deferred()) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 611fb66e1d..fd59f1595f 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -25,6 +25,7 @@ from twisted.internet import defer from twisted.internet.protocol import Factory from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.metrics import Measure, measure_func from .protocol import ServerReplicationStreamProtocol @@ -117,7 +118,6 @@ class ReplicationStreamer(object): for conn in self.connections: conn.send_error("server shutting down") - @defer.inlineCallbacks def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. @@ -132,14 +132,16 @@ class ReplicationStreamer(object): stream.discard_updates_and_advance() return - # If we're in the process of checking for new updates, mark that fact - # and return + self.pending_updates = True + if self.is_looping: - logger.debug("Noitifier poke loop already running") - self.pending_updates = True + logger.debug("Notifier poke loop already running") return - self.pending_updates = True + run_as_background_process("replication_notifier", self._run_notifier_loop) + + @defer.inlineCallbacks + def _run_notifier_loop(self): self.is_looping = True try: diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 6d276a109a..698e34dda3 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py
@@ -18,6 +18,7 @@ import hashlib import hmac import logging +from six import text_type from six.moves import http_client from twisted.internet import defer @@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet): 400, "username must be specified", errcode=Codes.BAD_JSON, ) else: - if (not isinstance(body['username'], str) or len(body['username']) > 512): + if ( + not isinstance(body['username'], text_type) + or len(body['username']) > 512 + ): raise SynapseError(400, "Invalid username") username = body["username"].encode("utf-8") @@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet): 400, "password must be specified", errcode=Codes.BAD_JSON, ) else: - if (not isinstance(body['password'], str) or len(body['password']) > 512): + if ( + not isinstance(body['password'], text_type) + or len(body['password']) > 512 + ): raise SynapseError(400, "Invalid password") password = body["password"].encode("utf-8") @@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet): want_mac.update(b"admin" if admin else b"notadmin") want_mac = want_mac.hexdigest() - if not hmac.compare_digest(want_mac, got_mac): - raise SynapseError( - 403, "HMAC incorrect", - ) + if not hmac.compare_digest(want_mac, got_mac.encode('ascii')): + raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication from synapse.rest.client.v2_alpha.register import RegisterRestServlet + register = RegisterRestServlet(self.hs) (user_id, _) = yield register.registration_handler.register( - localpart=username.lower(), password=password, admin=bool(admin), + localpart=body['username'].lower(), + password=body["password"], + admin=bool(admin), generate_token=False, ) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 69dcd618cb..97733f3026 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py
@@ -18,7 +18,7 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import parse_json_object_from_request from synapse.types import RoomAlias @@ -159,7 +159,7 @@ class ClientDirectoryListServer(ClientV1RestServlet): def on_GET(self, request, room_id): room = yield self.store.get_room(room_id) if room is None: - raise SynapseError(400, "Unknown room") + raise NotFoundError("Unknown room") defer.returnValue((200, { "visibility": "public" if room["is_public"] else "private" diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index ee1feeec69..fa5989e74e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py
@@ -531,11 +531,20 @@ class RoomEventContextServlet(ClientV1RestServlet): limit = parse_integer(request, "limit", default=10) + # picking the API shape for symmetry with /messages + filter_bytes = parse_string(request, "filter") + if filter_bytes: + filter_json = urlparse.unquote(filter_bytes).decode("UTF-8") + event_filter = Filter(json.loads(filter_json)) + else: + event_filter = None + results = yield self.room_context_handler.get_event_context( requester.user, room_id, event_id, limit, + event_filter, ) if not results: diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index d6cf915d86..2f64155d13 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py
@@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - kind = "user" - if "kind" in request.args: - kind = request.args["kind"][0] + kind = b"user" + if b"kind" in request.args: + kind = request.args[b"kind"][0] - if kind == "guest": + if kind == b"guest": ret = yield self._do_guest_registration(body) defer.returnValue(ret) return - elif kind != "user": + elif kind != b"user": raise UnrecognizedRequestError( "Do not understand membership kind: %s" % (kind,) ) @@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet): assert_params_in_dict(params, ["password"]) desired_username = params.get("username", None) - new_password = params.get("password", None) guest_access_token = params.get("guest_access_token", None) + new_password = params.get("password", None) if desired_username is not None: desired_username = desired_username.lower() diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 30242c525a..8fb413d825 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py
@@ -35,6 +35,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.async import Linearizer from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination @@ -100,10 +101,15 @@ class MediaRepository(object): ) self.clock.looping_call( - self._update_recently_accessed, + self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS, ) + def _start_update_recently_accessed(self): + return run_as_background_process( + "update_recently_accessed_media", self._update_recently_accessed, + ) + @defer.inlineCallbacks def _update_recently_accessed(self): remote_media = self.recently_accessed_remotes @@ -373,7 +379,7 @@ class MediaRepository(object): logger.warn("HTTP error fetching remote media %s/%s: %s", server_name, media_id, e.response) if e.code == twisted.web.http.NOT_FOUND: - raise SynapseError.from_http_response_exception(e) + raise e.to_synapse_error() raise SynapseError(502, "Failed to fetch remote media") except SynapseError: diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index b25993fcb5..a6189224ee 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py
@@ -177,7 +177,7 @@ class MediaStorage(object): if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "w"), self.hs.get_reactor()) + open(local_path, "wb"), self.hs.get_reactor()) yield res.write_to_consumer(consumer) yield consumer.wait() defer.returnValue(local_path) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b70b15c4c2..27aa0def2f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -41,6 +41,7 @@ from synapse.http.server import ( wrap_json_request_handler, ) from synapse.http.servlet import parse_integer, parse_string +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.async import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background @@ -81,7 +82,7 @@ class PreviewUrlResource(Resource): self._cache.start() self._cleaner_loop = self.clock.looping_call( - self._expire_url_cache_data, 10 * 1000 + self._start_expire_url_cache_data, 10 * 1000, ) def render_OPTIONS(self, request): @@ -371,6 +372,11 @@ class PreviewUrlResource(Resource): "etag": headers["ETag"][0] if "ETag" in headers else None, }) + def _start_expire_url_cache_data(self): + return run_as_background_process( + "expire_url_cache_data", self._expire_url_cache_data, + ) + @defer.inlineCallbacks def _expire_url_cache_data(self): """Clean up expired url cache content, media and thumbnails. diff --git a/synapse/secrets.py b/synapse/secrets.py
index f397daaa5e..f05e9ea535 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py
@@ -20,17 +20,16 @@ See https://docs.python.org/3/library/secrets.html#module-secrets for the API used in Python 3.6, and the API emulated in Python 2.7. """ -import six +import sys -if six.PY3: +# secrets is available since python 3.6 +if sys.version_info[0:2] >= (3, 6): import secrets def Secrets(): return secrets - else: - import os import binascii diff --git a/synapse/state.py b/synapse/state.py
index 033f55d967..e1092b97a9 100644 --- a/synapse/state.py +++ b/synapse/state.py
@@ -577,7 +577,7 @@ def _make_state_cache_entry( def _ordered_events(events): def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest() + return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest() return sorted(events, key=key_func) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index ba88a54979..134e4a80f1 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py
@@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore, PresenceStore, TransactionStore, DirectoryStore, KeyStore, StateStore, SignatureStore, ApplicationServiceStore, + EventsStore, EventFederationStore, MediaRepositoryStore, RejectionsStore, @@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore, PusherStore, PushRuleStore, ApplicationServiceTransactionStore, - EventsStore, ReceiptsStore, EndToEndKeyStore, SearchStore, @@ -94,6 +94,7 @@ class DataStore(RoomMemberStore, RoomStore, self._clock = hs.get_clock() self.database_engine = hs.database_engine + self.db_conn = db_conn self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", extra_tables=[("local_invites", "stream_id")] @@ -266,6 +267,31 @@ class DataStore(RoomMemberStore, RoomStore, return self.runInteraction("count_users", _count_users) + def count_monthly_users(self): + """Counts the number of users who used this homeserver in the last 30 days + + This method should be refactored with count_daily_users - the only + reason not to is waiting on definition of mau + + Returns: + Defered[int] + """ + def _count_monthly_users(txn): + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + + txn.execute(sql, (thirty_days_ago,)) + count, = txn.fetchone() + return count + + return self.runInteraction("count_monthly_users", _count_monthly_users) + def count_r30_users(self): """ Counts the number of 30 day retained users, defined as:- diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 1d41d8d445..44f37b4c1e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -311,6 +311,12 @@ class SQLBaseStore(object): after_callbacks = [] exception_callbacks = [] + if LoggingContext.current_context() == LoggingContext.sentinel: + logger.warn( + "Starting db txn '%s' from sentinel context", + desc, + ) + try: result = yield self.runWithConnection( self._new_transaction, diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 9f12b360bc..31248d5e06 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py
@@ -22,7 +22,7 @@ from twisted.internet import defer from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage.events import EventsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore from ._base import SQLBaseStore diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index cf796242b8..d146ae79aa 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py
@@ -102,7 +102,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): to_update, ) - run_as_background_process( + return run_as_background_process( "update_client_ips", update, ) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index cc3cdf2ebc..c0943ecf91 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py
@@ -21,6 +21,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from ._base import Cache, SQLBaseStore @@ -711,6 +712,9 @@ class DeviceStore(SQLBaseStore): logger.info("Pruned %d device list outbound pokes", txn.rowcount) - return self.runInteraction( - "_prune_old_outbound_device_pokes", _prune_txn + return run_as_background_process( + "prune_old_outbound_device_pokes", + self.runInteraction, + "_prune_old_outbound_device_pokes", + _prune_txn, ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 7cd77c1c29..24345b20a6 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py
@@ -23,8 +23,9 @@ from unpaddedbase64 import encode_base64 from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.events import EventsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.signatures import SignatureWorkerStore from synapse.util.caches.descriptors import cached @@ -113,9 +114,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, sql = ( "SELECT b.event_id, MAX(e.depth) FROM events as e" " INNER JOIN event_edges as g" - " ON g.event_id = e.event_id AND g.room_id = e.room_id" + " ON g.event_id = e.event_id" " INNER JOIN event_backward_extremities as b" - " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id" + " ON g.prev_event_id = b.event_id" " WHERE b.room_id = ? AND g.is_state is ?" " GROUP BY b.event_id" ) @@ -329,8 +330,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, "SELECT depth, prev_event_id FROM event_edges" " INNER JOIN events" " ON prev_event_id = events.event_id" - " AND event_edges.room_id = events.room_id" - " WHERE event_edges.room_id = ? AND event_edges.event_id = ?" + " WHERE event_edges.event_id = ?" " AND event_edges.is_state = ?" " LIMIT ?" ) @@ -365,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, txn.execute( query, - (room_id, event_id, False, limit - len(event_results)) + (event_id, False, limit - len(event_results)) ) for row in txn: @@ -402,7 +402,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, query = ( "SELECT prev_event_id FROM event_edges " - "WHERE room_id = ? AND event_id = ? AND is_state = ? " + "WHERE event_id = ? AND is_state = ? " "LIMIT ?" ) @@ -411,7 +411,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, for event_id in front: txn.execute( query, - (room_id, event_id, False, limit - len(event_results)) + (event_id, False, limit - len(event_results)) ) for e_id, in txn: @@ -447,7 +447,7 @@ class EventFederationStore(EventFederationWorkerStore): ) hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + self._delete_old_forward_extrem_cache, 60 * 60 * 1000, ) def _update_min_depth_for_room_txn(self, txn, room_id, depth): @@ -549,9 +549,11 @@ class EventFederationStore(EventFederationWorkerStore): sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) ) - return self.runInteraction( + return run_as_background_process( + "delete_old_forward_extrem_cache", + self.runInteraction, "_delete_old_forward_extrem_cache", - _delete_old_forward_extrem_cache_txn + _delete_old_forward_extrem_cache_txn, ) def clean_room_for_join(self, room_id): diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 29b511ae5e..6840320641 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py
@@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks @@ -458,11 +459,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): "Error removing push actions after event persistence failure", ) - @defer.inlineCallbacks def _find_stream_orderings_for_times(self): - yield self.runInteraction( + return run_as_background_process( + "event_push_action_stream_orderings", + self.runInteraction, "_find_stream_orderings_for_times", - self._find_stream_orderings_for_times_txn + self._find_stream_orderings_for_times_txn, ) def _find_stream_orderings_for_times_txn(self, txn): @@ -604,7 +606,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): self._doing_notif_rotation = False self._rotate_notif_loop = self._clock.looping_call( - self._rotate_notifs, 30 * 60 * 1000 + self._start_rotate_notifs, 30 * 60 * 1000, ) def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts, @@ -787,6 +789,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore): WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? """, (room_id, user_id, stream_ordering)) + def _start_rotate_notifs(self): + return run_as_background_process("rotate_notifs", self._rotate_notifs) + @defer.inlineCallbacks def _rotate_notifs(self): if self._doing_notif_rotation or self.stream_ordering_day_ago is None: diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 906a405031..e8e5a0fe44 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py
@@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util.async import ObservableDeferred @@ -65,7 +67,13 @@ state_delta_reuse_delta_counter = Counter( def encode_json(json_object): - return frozendict_json_encoder.encode(json_object) + """ + Encode a Python object as JSON and return it in a Unicode string. + """ + out = frozendict_json_encoder.encode(json_object) + if isinstance(out, bytes): + out = out.decode('utf8') + return out class _EventPeristenceQueue(object): @@ -142,15 +150,14 @@ class _EventPeristenceQueue(object): try: queue = self._get_drainining_queue(room_id) for item in queue: - # handle_queue_loop runs in the sentinel logcontext, so - # there is no need to preserve_fn when running the - # callbacks on the deferred. try: ret = yield per_item_callback(item) + except Exception: + with PreserveLoggingContext(): + item.deferred.errback() + else: with PreserveLoggingContext(): item.deferred.callback(ret) - except Exception: - item.deferred.errback() finally: queue = self._event_persist_queues.pop(room_id, None) if queue: @@ -194,7 +201,9 @@ def _retry_on_integrity_error(func): return f -class EventsStore(EventsWorkerStore): +# inherits from EventFederationStore so that we can call _update_backward_extremities +# and _handle_mult_prev_events (though arguably those could both be moved in here) +class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" @@ -232,12 +241,18 @@ class EventsStore(EventsWorkerStore): self._state_resolution_handler = hs.get_state_resolution_handler() + @defer.inlineCallbacks def persist_events(self, events_and_contexts, backfilled=False): """ Write events to the database Args: events_and_contexts: list of tuples of (event, context) - backfilled: ? + backfilled (bool): Whether the results are retrieved from federation + via backfill or not. Used to determine if they're "new" events + which might update the current state etc. + + Returns: + Deferred[int]: the stream ordering of the latest persisted event """ partitioned = {} for event, ctx in events_and_contexts: @@ -254,10 +269,14 @@ class EventsStore(EventsWorkerStore): for room_id in partitioned: self._maybe_start_persisting(room_id) - return make_deferred_yieldable( + yield make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) + max_persisted_id = yield self._stream_id_gen.get_current_token() + + defer.returnValue(max_persisted_id) + @defer.inlineCallbacks @log_function def persist_event(self, event, context, backfilled=False): @@ -521,7 +540,6 @@ class EventsStore(EventsWorkerStore): iterable=list(new_latest_event_ids), retcols=["prev_event_id"], keyvalues={ - "room_id": room_id, "is_state": False, }, desc="_calculate_new_extremeties", @@ -575,11 +593,13 @@ class EventsStore(EventsWorkerStore): for ev, ctx in events_context: if ctx.state_group is None: - # I don't think this can happen, but let's double-check - raise Exception( - "Context for new extremity event %s has no state " - "group" % (ev.event_id, ), - ) + # This should only happen for outlier events. + if not ev.internal_metadata.is_outlier(): + raise Exception( + "Context for new event %s has no state " + "group" % (ev.event_id, ), + ) + continue if ctx.state_group in state_groups_map: continue @@ -607,7 +627,7 @@ class EventsStore(EventsWorkerStore): for event_id in new_latest_event_ids: # First search in the list of new events we're adding. for ev, ctx in events_context: - if event_id == ev.event_id: + if event_id == ev.event_id and ctx.state_group is not None: event_id_to_state_group[event_id] = ctx.state_group break else: @@ -1054,7 +1074,7 @@ class EventsStore(EventsWorkerStore): metadata_json = encode_json( event.internal_metadata.get_dict() - ).decode("UTF-8") + ) sql = ( "UPDATE event_json SET internal_metadata = ?" @@ -1137,7 +1157,7 @@ class EventsStore(EventsWorkerStore): ): txn.executemany( "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts] + [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts] ) def _store_event_txn(self, txn, events_and_contexts): @@ -1168,8 +1188,8 @@ class EventsStore(EventsWorkerStore): "room_id": event.room_id, "internal_metadata": encode_json( event.internal_metadata.get_dict() - ).decode("UTF-8"), - "json": encode_json(event_dict(event)).decode("UTF-8"), + ), + "json": encode_json(event_dict(event)), } for event, _ in events_and_contexts ], @@ -1188,7 +1208,6 @@ class EventsStore(EventsWorkerStore): "type": event.type, "processed": True, "outlier": event.internal_metadata.is_outlier(), - "content": encode_json(event.content).decode("UTF-8"), "origin_server_ts": int(event.origin_server_ts), "received_ts": self._clock.time_msec(), "sender": event.sender, diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 009b91dd52..beb99f65fc 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py
@@ -24,7 +24,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.storage.events import EventsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id from synapse.util.async import Linearizer from synapse.util.caches import intern_string diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py new file mode 100644
index 0000000000..6dd467b6c5 --- /dev/null +++ b/synapse/storage/schema/delta/50/make_event_content_nullable.py
@@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +We want to stop populating 'event.content', so we need to make it nullable. + +If this has to be rolled back, then the following should populate the missing data: + +Postgres: + + UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej + WHERE ej.event_id = events.event_id AND + stream_ordering < ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering LIMIT 1 + ); + + UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej + WHERE ej.event_id = events.event_id AND + stream_ordering > ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering DESC LIMIT 1 + ); + +SQLite: + + UPDATE events SET content=( + SELECT json_extract(json,'$.content') FROM event_json ej + WHERE ej.event_id = events.event_id + ) + WHERE + stream_ordering < ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering LIMIT 1 + ) + OR stream_ordering > ( + SELECT stream_ordering FROM events WHERE content IS NOT NULL + ORDER BY stream_ordering DESC LIMIT 1 + ); + +""" + +import logging + +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + pass + + +def run_upgrade(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute(""" + ALTER TABLE events ALTER COLUMN content DROP NOT NULL; + """) + return + + # sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html + + cur.execute("SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'") + (oldsql,) = cur.fetchone() + + sql = oldsql.replace("content TEXT NOT NULL", "content TEXT") + if sql == oldsql: + raise Exception("Couldn't find null constraint to drop in %s" % oldsql) + + logger.info("Replacing definition of 'events' with: %s", sql) + + cur.execute("PRAGMA schema_version") + (oldver,) = cur.fetchone() + cur.execute("PRAGMA writable_schema=ON") + cur.execute( + "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", + (sql, ), + ) + cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) + cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/schema/full_schemas/16/event_edges.sql
index 52eec88357..6b5a5a88fa 100644 --- a/synapse/storage/schema/full_schemas/16/event_edges.sql +++ b/synapse/storage/schema/full_schemas/16/event_edges.sql
@@ -37,7 +37,8 @@ CREATE TABLE IF NOT EXISTS event_edges( event_id TEXT NOT NULL, prev_event_id TEXT NOT NULL, room_id TEXT NOT NULL, - is_state BOOL NOT NULL, + is_state BOOL NOT NULL, -- true if this is a prev_state edge rather than a regular + -- event dag edge. UNIQUE (event_id, prev_event_id, room_id, is_state) ); diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/schema/full_schemas/16/im.sql
index ba5346806e..5f5cb8d01d 100644 --- a/synapse/storage/schema/full_schemas/16/im.sql +++ b/synapse/storage/schema/full_schemas/16/im.sql
@@ -19,7 +19,12 @@ CREATE TABLE IF NOT EXISTS events( event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, - content TEXT NOT NULL, + + -- 'content' used to be created NULLable, but as of delta 50 we drop that constraint. + -- the hack we use to drop the constraint doesn't work for an in-memory sqlite + -- database, which breaks the sytests. Hence, we no longer make it nullable. + content TEXT, + unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 470212aa2a..5623391f6e 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py
@@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore): txn (cursor): event_id (str): Id for the Event. Returns: - A dict of algorithm -> hash. + A dict[unicode, bytes] of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 89a05c4618..b27b3ae144 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py
@@ -186,7 +186,17 @@ class StateGroupWorkerStore(SQLBaseStore): @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, types): - """Returns dictionary state_group -> (dict of (type, state_key) -> event id) + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + types (Iterable[str, str|None]|None): list of 2-tuples of the form + (`type`, `state_key`), where a `state_key` of `None` matches all + state_keys for the `type`. If None, all types are returned. + + Returns: + dictionary state_group -> (dict of (type, state_key) -> event id) """ results = {} @@ -200,8 +210,11 @@ class StateGroupWorkerStore(SQLBaseStore): defer.returnValue(results) - def _get_state_groups_from_groups_txn(self, txn, groups, types=None): + def _get_state_groups_from_groups_txn( + self, txn, groups, types=None, + ): results = {group: {} for group in groups} + if types is not None: types = list(set(types)) # deduplicate types list @@ -239,7 +252,7 @@ class StateGroupWorkerStore(SQLBaseStore): # Turns out that postgres doesn't like doing a list of OR's and # is about 1000x slower, so we just issue a query for each specific # type seperately. - if types: + if types is not None: clause_to_args = [ ( "AND type = ? AND state_key = ?", @@ -278,6 +291,7 @@ class StateGroupWorkerStore(SQLBaseStore): else: where_clauses.append("(type = ? AND state_key = ?)") where_args.extend([typ[0], typ[1]]) + where_clause = "AND (%s)" % (" OR ".join(where_clauses)) else: where_clause = "" @@ -332,16 +346,20 @@ class StateGroupWorkerStore(SQLBaseStore): return results @defer.inlineCallbacks - def get_state_for_events(self, event_ids, types): + def get_state_for_events(self, event_ids, types, filtered_types=None): """Given a list of event_ids and type tuples, return a list of state dicts for each event. The state dicts will only have the type/state_keys that are in the `types` list. Args: - event_ids (list) - types (list): List of (type, state_key) tuples which are used to - filter the state fetched. `state_key` may be None, which matches - any `state_key` + event_ids (list[string]) + types (list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + May be None, which matches any key. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. Returns: deferred: A list of dicts corresponding to the event_ids given. @@ -352,7 +370,7 @@ class StateGroupWorkerStore(SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types) + group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) state_event_map = yield self.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], @@ -371,15 +389,19 @@ class StateGroupWorkerStore(SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, types=None): + def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): """ Get the state dicts corresponding to a list of events Args: event_ids(list(str)): events whose state should be returned - types(list[(str, str)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. May be None, which - matches any key + types(list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + May be None, which matches any key. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. Returns: A deferred dict from event_id -> (type, state_key) -> state_event @@ -389,7 +411,7 @@ class StateGroupWorkerStore(SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types) + group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) event_to_state = { event_id: group_to_state[group] @@ -399,37 +421,45 @@ class StateGroupWorkerStore(SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_for_event(self, event_id, types=None): + def get_state_for_event(self, event_id, types=None, filtered_types=None): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. May be None, which - matches any key + types(list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + May be None, which matches any key. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], types) + state_map = yield self.get_state_for_events([event_id], types, filtered_types) defer.returnValue(state_map[event_id]) @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, types=None): + def get_state_ids_for_event(self, event_id, types=None, filtered_types=None): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. May be None, which - matches any key + types(list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + May be None, which matches any key. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], types) + state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types) defer.returnValue(state_map[event_id]) @cached(max_entries=50000) @@ -460,56 +490,73 @@ class StateGroupWorkerStore(SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, group, types): + def _get_some_state_from_cache(self, group, types, filtered_types=None): """Checks if group is in cache. See `_get_state_for_groups` - Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). - `missing_types` is the list of types that aren't in the cache for that - group. `got_all` is a bool indicating if we successfully retrieved all + Args: + group(int): The state group to lookup + types(list[str, str|None]): List of 2-tuples of the form + (`type`, `state_key`), where a `state_key` of `None` matches all + state_keys for the `type`. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all requests state from the cache, if False we need to query the DB for the missing state. - - Args: - group: The state group to lookup - types (list): List of 2-tuples of the form (`type`, `state_key`), - where a `state_key` of `None` matches all state_keys for the - `type`. """ is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} - missing_types = set() + + # tracks whether any of ourrequested types are missing from the cache + missing_types = False for typ, state_key in types: key = (typ, state_key) - if state_key is None: + + if ( + state_key is None or + (filtered_types is not None and typ not in filtered_types) + ): type_to_key[typ] = None - missing_types.add(key) + # we mark the type as missing from the cache because + # when the cache was populated it might have been done with a + # restricted set of state_keys, so the wildcard will not work + # and the cache may be incomplete. + missing_types = True else: if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) if key not in state_dict_ids and key not in known_absent: - missing_types.add(key) + missing_types = True sentinel = object() def include(typ, state_key): valid_state_keys = type_to_key.get(typ, sentinel) if valid_state_keys is sentinel: - return False + return filtered_types is not None and typ not in filtered_types if valid_state_keys is None: return True if state_key in valid_state_keys: return True return False - got_all = is_all or not missing_types + got_all = is_all + if not got_all: + # the cache is incomplete. We may still have got all the results we need, if + # we don't have any wildcards in the match list. + if not missing_types and filtered_types is None: + got_all = True return { k: v for k, v in iteritems(state_dict_ids) if include(k[0], k[1]) - }, missing_types, got_all + }, got_all def _get_all_state_from_cache(self, group): """Checks if group is in cache. See `_get_state_for_groups` @@ -526,7 +573,7 @@ class StateGroupWorkerStore(SQLBaseStore): return state_dict_ids, is_all @defer.inlineCallbacks - def _get_state_for_groups(self, groups, types=None): + def _get_state_for_groups(self, groups, types=None, filtered_types=None): """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -540,6 +587,9 @@ class StateGroupWorkerStore(SQLBaseStore): Otherwise, each entry should be a `(type, state_key)` tuple to include in the response. A `state_key` of None is a wildcard meaning that we require all state with that type. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. Returns: Deferred[dict[int, dict[(type, state_key), EventBase]]] @@ -551,8 +601,8 @@ class StateGroupWorkerStore(SQLBaseStore): missing_groups = [] if types is not None: for group in set(groups): - state_dict_ids, _, got_all = self._get_some_state_from_cache( - group, types, + state_dict_ids, got_all = self._get_some_state_from_cache( + group, types, filtered_types ) results[group] = state_dict_ids @@ -579,13 +629,13 @@ class StateGroupWorkerStore(SQLBaseStore): # cache. Hence, if we are doing a wildcard lookup, populate the # cache fully so that we can do an efficient lookup next time. - if types and any(k is None for (t, k) in types): + if filtered_types or (types and any(k is None for (t, k) in types)): types_to_fetch = None else: types_to_fetch = types group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch, + missing_groups, types_to_fetch ) for group, group_state_dict in iteritems(group_to_state_dict): @@ -595,7 +645,10 @@ class StateGroupWorkerStore(SQLBaseStore): if types: for k, v in iteritems(group_state_dict): (typ, _) = k - if k in types or (typ, None) in types: + if ( + (k in types or (typ, None) in types) or + (filtered_types and typ not in filtered_types) + ): state_dict[k] = v else: state_dict.update(group_state_dict) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 66856342f0..b9f2b74ac6 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py
@@ -43,7 +43,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine -from synapse.storage.events import EventsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background @@ -527,7 +527,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) @defer.inlineCallbacks - def get_events_around(self, room_id, event_id, before_limit, after_limit): + def get_events_around( + self, room_id, event_id, before_limit, after_limit, event_filter=None, + ): """Retrieve events and pagination tokens around a given event in a room. @@ -536,6 +538,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_id (str) before_limit (int) after_limit (int) + event_filter (Filter|None) Returns: dict @@ -543,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): results = yield self.runInteraction( "get_events_around", self._get_events_around_txn, - room_id, event_id, before_limit, after_limit + room_id, event_id, before_limit, after_limit, event_filter, ) events_before = yield self._get_events( @@ -563,7 +566,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "end": results["after"]["token"], }) - def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit): + def _get_events_around_txn( + self, txn, room_id, event_id, before_limit, after_limit, event_filter, + ): """Retrieves event_ids and pagination tokens around a given event in a room. @@ -572,6 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_id (str) before_limit (int) after_limit (int) + event_filter (Filter|None) Returns: dict @@ -601,11 +607,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows, start_token = self._paginate_room_events_txn( txn, room_id, before_token, direction='b', limit=before_limit, + event_filter=event_filter, ) events_before = [r.event_id for r in rows] rows, end_token = self._paginate_room_events_txn( txn, room_id, after_token, direction='f', limit=after_limit, + event_filter=event_filter, ) events_after = [r.event_id for r in rows] diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index c3bc94f56d..428e7fa36e 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py
@@ -22,6 +22,7 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches.descriptors import cached from ._base import SQLBaseStore @@ -57,7 +58,7 @@ class TransactionStore(SQLBaseStore): def __init__(self, db_conn, hs): super(TransactionStore, self).__init__(db_conn, hs) - self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) + self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) def get_received_txn_response(self, transaction_id, origin): """For an incoming transaction from a given origin, check if we have @@ -271,6 +272,11 @@ class TransactionStore(SQLBaseStore): txn.execute(query, (self._clock.time_msec(),)) return self.cursor_to_dict(txn) + def _start_cleanup_transactions(self): + return run_as_background_process( + "cleanup_transactions", self._cleanup_transactions, + ) + def _cleanup_transactions(self): now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 diff --git a/synapse/types.py b/synapse/types.py
index 08f058f714..41afb27a74 100644 --- a/synapse/types.py +++ b/synapse/types.py
@@ -137,7 +137,7 @@ class DomainSpecificString( @classmethod def from_string(cls, s): """Parse the string given by 's' into a structure object.""" - if len(s) < 1 or s[0] != cls.SIGIL: + if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError(400, "Expected %s string to start with '%s'" % ( cls.__name__, cls.SIGIL, )) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f8a07df6b8..861c24809c 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase): @functools.wraps(self.orig) def wrapped(*args, **kwargs): - # If we're passed a cache_context then we'll want to call its invalidate() - # whenever we are invalidated + # If we're passed a cache_context then we'll want to call its + # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] - # cached is a dict arg -> deferred, where deferred results in a - # 2-tuple (`arg`, `result`) results = {} - cached_defers = {} - missing = [] + + def update_results_dict(res, arg): + results[arg] = res + + # list of deferreds to wait for + cached_defers = [] + + missing = set() # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: - def cache_get(arg): - return cache.get(arg, callback=invalidate_callback) + def arg_to_cache_key(arg): + return arg else: - key = list(keyargs) + keylist = list(keyargs) - def cache_get(arg): - key[self.list_pos] = arg - return cache.get(tuple(key), callback=invalidate_callback) + def arg_to_cache_key(arg): + keylist[self.list_pos] = arg + return tuple(keylist) for arg in list_args: try: - res = cache_get(arg) - + res = cache.get(arg_to_cache_key(arg), + callback=invalidate_callback) if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): res = res.observe() - res.addCallback(lambda r, arg: (arg, r), arg) - cached_defers[arg] = res + res.addCallback(update_results_dict, arg) + cached_defers.append(res) else: results[arg] = res.get_result() except KeyError: - missing.append(arg) + missing.add(arg) if missing: + # we need an observable deferred for each entry in the list, + # which we put in the cache. Each deferred resolves with the + # relevant result for that key. + deferreds_map = {} + for arg in missing: + deferred = defer.Deferred() + deferreds_map[arg] = deferred + key = arg_to_cache_key(arg) + observable = ObservableDeferred(deferred) + cache.set(key, observable, callback=invalidate_callback) + + def complete_all(res): + # the wrapped function has completed. It returns a + # a dict. We can now resolve the observable deferreds in + # the cache and update our own result map. + for e in missing: + val = res.get(e, None) + deferreds_map[e].callback(val) + results[e] = val + + def errback(f): + # the wrapped function has failed. Invalidate any cache + # entries we're supposed to be populating, and fail + # their deferreds. + for e in missing: + key = arg_to_cache_key(e) + cache.invalidate(key) + deferreds_map[e].errback(f) + + # return the failure, to propagate to our caller. + return f + args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing + args_to_call[self.list_name] = list(missing) - ret_d = defer.maybeDeferred( + cached_defers.append(defer.maybeDeferred( logcontext.preserve_fn(self.function_to_call), **args_to_call - ) - - ret_d = ObservableDeferred(ret_d) - - # We need to create deferreds for each arg in the list so that - # we can insert the new deferred into the cache. - for arg in missing: - observer = ret_d.observe() - observer.addCallback(lambda r, arg: r.get(arg, None), arg) - - observer = ObservableDeferred(observer) - - if num_args == 1: - cache.set( - arg, observer, - callback=invalidate_callback - ) - - def invalidate(f, key): - cache.invalidate(key) - return f - observer.addErrback(invalidate, arg) - else: - key = list(keyargs) - key[self.list_pos] = arg - cache.set( - tuple(key), observer, - callback=invalidate_callback - ) - - def invalidate(f, key): - cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) - - res = observer.observe() - res.addCallback(lambda r, arg: (arg, r), arg) - - cached_defers[arg] = res + ).addCallbacks(complete_all, errback)) if cached_defers: - def update_results_dict(res): - results.update(res) - return results - - return logcontext.make_deferred_yieldable(defer.gatherResults( - list(cached_defers.values()), + d = defer.gatherResults( + cached_defers, consumeErrors=True, - ).addCallback(update_results_dict).addErrback( + ).addCallbacks( + lambda _: results, unwrapFirstError - )) + ) + return logcontext.make_deferred_yieldable(d) else: return results @@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal cache. Args: - cache (Cache): The underlying cache to use. + cached_method_name (str): The name of the single-item lookup method. + This is only used to find the cache to use. list_name (str): The name of the argument that is the list to use to do batch lookups in the cache. num_args (int): Number of arguments to use as the key in the cache diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 465adc54a8..ce85b2ae11 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py
@@ -64,7 +64,7 @@ class ExpiringCache(object): return def f(): - run_as_background_process( + return run_as_background_process( "prune_cache_%s" % self._cache_name, self._prune_cache, ) diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 581c6052ac..014edea971 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import string_types +from six import binary_type, text_type from canonicaljson import json from frozendict import frozendict @@ -26,7 +26,7 @@ def freeze(o): if isinstance(o, frozendict): return o - if isinstance(o, string_types): + if isinstance(o, (binary_type, text_type)): return o try: @@ -41,7 +41,7 @@ def unfreeze(o): if isinstance(o, (dict, frozendict)): return dict({k: unfreeze(v) for k, v in o.items()}) - if isinstance(o, string_types): + if isinstance(o, (binary_type, text_type)): return o try: