diff options
Diffstat (limited to 'synapse')
40 files changed, 1886 insertions, 979 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index c227265190..fa43211415 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -32,9 +32,9 @@ from synapse.server import HomeServer from twisted.internet import reactor from twisted.application import service from twisted.enterprise import adbapi -from twisted.web.resource import Resource +from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.static import File -from twisted.web.server import Site +from twisted.web.server import Site, GzipEncoderFactory from twisted.web.http import proxiedLogFormatter, combinedLogFormatter from synapse.http.server import JsonResource, RootRedirect from synapse.rest.media.v0.content_repository import ContentRepoResource @@ -69,16 +69,26 @@ import subprocess logger = logging.getLogger("synapse.app.homeserver") +class GzipFile(File): + def getChild(self, path, request): + child = File.getChild(self, path, request) + return EncodingResourceWrapper(child, [GzipEncoderFactory()]) + + +def gz_wrap(r): + return EncodingResourceWrapper(r, [GzipEncoderFactory()]) + + class SynapseHomeServer(HomeServer): def build_http_client(self): return MatrixFederationHttpClient(self) def build_resource_for_client(self): - return ClientV1RestResource(self) + return gz_wrap(ClientV1RestResource(self)) def build_resource_for_client_v2_alpha(self): - return ClientV2AlphaRestResource(self) + return gz_wrap(ClientV2AlphaRestResource(self)) def build_resource_for_federation(self): return JsonResource(self) @@ -87,9 +97,16 @@ class SynapseHomeServer(HomeServer): import syweb syweb_path = os.path.dirname(syweb.__file__) webclient_path = os.path.join(syweb_path, "webclient") + # GZip is disabled here due to + # https://twistedmatrix.com/trac/ticket/7678 + # (It can stay enabled for the API resources: they call + # write() with the whole body and then finish() straight + # after and so do not trigger the bug. + # return GzipFile(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable? def build_resource_for_static_content(self): + # This is old and should go away: not going to bother adding gzip return File("static") def build_resource_for_content_repo(self): diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 4911f0896b..24f15f3154 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient from twisted.internet.protocol import Factory from twisted.internet import defer, reactor from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + preserve_context_over_fn, preserve_context_over_deferred +) import simplejson as json import logging @@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): for i in range(5): try: - with PreserveLoggingContext(): - protocol = yield endpoint.connect(factory) - server_response, server_certificate = yield protocol.remote_key - defer.returnValue((server_response, server_certificate)) - return + protocol = yield preserve_context_over_fn( + endpoint.connect, factory + ) + server_response, server_certificate = yield preserve_context_over_deferred( + protocol.remote_key + ) + defer.returnValue((server_response, server_certificate)) + return except SynapseKeyClientError as e: logger.exception("Error getting key for %r" % (server_name,)) if e.status.startswith("4"): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 8709394b97..aff69c5f83 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes from synapse.util.retryutils import get_retry_limiter -from synapse.util.async import create_observer +from synapse.util.async import ObservableDeferred from OpenSSL import crypto @@ -111,6 +111,10 @@ class Keyring(object): if download is None: download = self._get_server_verify_key_impl(server_name, key_ids) + download = ObservableDeferred( + download, + consumeErrors=True + ) self.key_downloads[server_name] = download @download.addBoth @@ -118,30 +122,31 @@ class Keyring(object): del self.key_downloads[server_name] return ret - r = yield create_observer(download) + r = yield download.observe() defer.returnValue(r) @defer.inlineCallbacks def _get_server_verify_key_impl(self, server_name, key_ids): keys = None - perspective_results = [] - for perspective_name, perspective_keys in self.perspective_servers.items(): - @defer.inlineCallbacks - def get_key(): - try: - result = yield self.get_server_verify_key_v2_indirect( - server_name, key_ids, perspective_name, perspective_keys - ) - defer.returnValue(result) - except: - logging.info( - "Unable to getting key %r for %r from %r", - key_ids, server_name, perspective_name, - ) - perspective_results.append(get_key()) + @defer.inlineCallbacks + def get_key(perspective_name, perspective_keys): + try: + result = yield self.get_server_verify_key_v2_indirect( + server_name, key_ids, perspective_name, perspective_keys + ) + defer.returnValue(result) + except Exception as e: + logging.info( + "Unable to getting key %r for %r from %r: %s %s", + key_ids, server_name, perspective_name, + type(e).__name__, str(e.message), + ) - perspective_results = yield defer.gatherResults(perspective_results) + perspective_results = yield defer.gatherResults([ + get_key(p_name, p_keys) + for p_name, p_keys in self.perspective_servers.items() + ]) for results in perspective_results: if results is not None: @@ -154,17 +159,22 @@ class Keyring(object): ) with limiter: - if keys is None: + if not keys: try: keys = yield self.get_server_verify_key_v2_direct( server_name, key_ids ) - except: - pass + except Exception as e: + logging.info( + "Unable to getting key %r for %r directly: %s %s", + key_ids, server_name, + type(e).__name__, str(e.message), + ) - keys = yield self.get_server_verify_key_v1_direct( - server_name, key_ids - ) + if not keys: + keys = yield self.get_server_verify_key_v1_direct( + server_name, key_ids + ) for key_id in key_ids: if key_id in keys: @@ -184,7 +194,7 @@ class Keyring(object): # TODO(mark): Set the minimum_valid_until_ts to that needed by # the events being validated or the current time if validating # an incoming request. - responses = yield self.client.post_json( + query_response = yield self.client.post_json( destination=perspective_name, path=b"/_matrix/key/v2/query", data={ @@ -200,6 +210,8 @@ class Keyring(object): keys = {} + responses = query_response["server_keys"] + for response in responses: if (u"signatures" not in response or perspective_name not in response[u"signatures"]): @@ -323,7 +335,7 @@ class Keyring(object): verify_key.time_added = time_now_ms old_verify_keys[key_id] = verify_key - for key_id in response_json["signatures"][server_name]: + for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: raise ValueError( "Key response must include verification keys for all" diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 21a763214b..5217d91aab 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -24,6 +24,8 @@ from synapse.crypto.event_signing import check_event_content_hash from synapse.api.errors import SynapseError +from synapse.util import unwrapFirstError + import logging @@ -94,7 +96,7 @@ class FederationBase(object): yield defer.gatherResults( [do(pdu) for pdu in pdus], consumeErrors=True - ) + ).addErrback(unwrapFirstError) defer.returnValue(signed_pdus) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 904c7c0945..3a7bc0c9a7 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -222,7 +222,7 @@ class FederationClient(FederationBase): for p in transaction_data["pdus"] ] - if pdu_list: + if pdu_list and pdu_list[0]: pdu = pdu_list[0] # Check signatures are correct. @@ -255,7 +255,7 @@ class FederationClient(FederationBase): ) continue - if self._get_pdu_cache is not None: + if self._get_pdu_cache is not None and pdu: self._get_pdu_cache[event_id] = pdu defer.returnValue(pdu) @@ -561,7 +561,7 @@ class FederationClient(FederationBase): res = yield defer.DeferredList(deferreds, consumeErrors=True) for (result, val), (e_id, _) in zip(res, ordered_missing): - if result: + if result and val: signed_events.append(val) else: failed_to_fetch.add(e_id) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b46188c91..cd79e23f4b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -20,7 +20,6 @@ from .federation_base import FederationBase from .units import Transaction, Edu from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext from synapse.events import FrozenEvent import synapse.metrics @@ -123,29 +122,28 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) - with PreserveLoggingContext(): - results = [] - - for pdu in pdu_list: - d = self._handle_new_pdu(transaction.origin, pdu) - - try: - yield d - results.append({}) - except FederationError as e: - self.send_failure(e, transaction.origin) - results.append({"error": str(e)}) - except Exception as e: - results.append({"error": str(e)}) - logger.exception("Failed to handle PDU") - - if hasattr(transaction, "edus"): - for edu in [Edu(**x) for x in transaction.edus]: - self.received_edu( - transaction.origin, - edu.edu_type, - edu.content - ) + results = [] + + for pdu in pdu_list: + d = self._handle_new_pdu(transaction.origin, pdu) + + try: + yield d + results.append({}) + except FederationError as e: + self.send_failure(e, transaction.origin) + results.append({"error": str(e)}) + except Exception as e: + results.append({"error": str(e)}) + logger.exception("Failed to handle PDU") + + if hasattr(transaction, "edus"): + for edu in [Edu(**x) for x in transaction.edus]: + self.received_edu( + transaction.origin, + edu.edu_type, + edu.content + ) for failure in getattr(transaction, "pdu_failures", []): logger.info("Got failure %r", failure) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 4b3f4eadab..833ff41377 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -20,6 +20,8 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.api.constants import Membership, EventTypes from synapse.types import UserID +from synapse.util.logcontext import PreserveLoggingContext + import logging @@ -103,7 +105,9 @@ class BaseHandler(object): if not suppress_auth: self.auth.check(event, auth_events=context.current_state) - yield self.store.persist_event(event, context=context) + (event_stream_id, max_stream_id) = yield self.store.persist_event( + event, context=context + ) federation_handler = self.hs.get_handlers().federation_handler @@ -137,10 +141,12 @@ class BaseHandler(object): "Failed to get destination from event %s", s.event_id ) - # Don't block waiting on waking up all the listeners. - notify_d = self.notifier.on_new_room_event( - event, extra_users=extra_users - ) + with PreserveLoggingContext(): + # Don't block waiting on waking up all the listeners. + notify_d = self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) def log_failure(f): logger.warn( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index f76febee8f..e41a688836 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes from synapse.types import RoomAlias import logging +import string logger = logging.getLogger(__name__) @@ -40,6 +41,10 @@ class DirectoryHandler(BaseHandler): def _create_association(self, room_alias, room_id, servers=None): # general association creation for both human users and app services + for wchar in string.whitespace: + if wchar in room_alias.localpart: + raise SynapseError(400, "Invalid characters in room alias") + if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") # TODO(erikj): Change this. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index f9f855213b..993d33ba47 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -15,7 +15,6 @@ from twisted.internet import defer -from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event @@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler): # thundering herds on restart. timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) - with PreserveLoggingContext(): - events, tokens = yield self.notifier.get_events_for( - auth_user, room_ids, pagin_config, timeout - ) + events, tokens = yield self.notifier.get_events_for( + auth_user, room_ids, pagin_config, timeout + ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 85e2757227..d35d9f603c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -18,9 +18,11 @@ from ._base import BaseHandler from synapse.api.errors import ( - AuthError, FederationError, StoreError, + AuthError, FederationError, StoreError, CodeMessageException, SynapseError, ) from synapse.api.constants import EventTypes, Membership, RejectedReason +from synapse.util import unwrapFirstError +from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze @@ -29,6 +31,8 @@ from synapse.crypto.event_signing import ( ) from synapse.types import UserID +from synapse.util.retryutils import NotRetryingDestination + from twisted.internet import defer import itertools @@ -156,7 +160,7 @@ class FederationHandler(BaseHandler): ) try: - yield self._handle_new_event( + _, event_stream_id, max_stream_id = yield self._handle_new_event( origin, event, state=state, @@ -197,9 +201,11 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - d = self.notifier.on_new_room_event( - event, extra_users=extra_users - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) def log_failure(f): logger.warn( @@ -218,10 +224,11 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def backfill(self, dest, room_id, limit): + def backfill(self, dest, room_id, limit, extremities=[]): """ Trigger a backfill request to `dest` for the given `room_id` """ - extremities = yield self.store.get_oldest_events_in_room(room_id) + if not extremities: + extremities = yield self.store.get_oldest_events_in_room(room_id) pdus = yield self.replication_layer.backfill( dest, @@ -249,6 +256,140 @@ class FederationHandler(BaseHandler): defer.returnValue(events) @defer.inlineCallbacks + def maybe_backfill(self, room_id, current_depth): + """Checks the database to see if we should backfill before paginating, + and if so do. + """ + extremities = yield self.store.get_oldest_events_with_depth_in_room( + room_id + ) + + if not extremities: + logger.debug("Not backfilling as no extremeties found.") + return + + # Check if we reached a point where we should start backfilling. + sorted_extremeties_tuple = sorted( + extremities.items(), + key=lambda e: -int(e[1]) + ) + max_depth = sorted_extremeties_tuple[0][1] + + if current_depth > max_depth: + logger.debug( + "Not backfilling as we don't need to. %d < %d", + max_depth, current_depth, + ) + return + + # Now we need to decide which hosts to hit first. + + # First we try hosts that are already in the room + # TODO: HEURISTIC ALERT. + + curr_state = yield self.state_handler.get_current_state(room_id) + + def get_domains_from_state(state): + joined_users = [ + (state_key, int(event.depth)) + for (e_type, state_key), event in state.items() + if e_type == EventTypes.Member + and event.membership == Membership.JOIN + ] + + joined_domains = {} + for u, d in joined_users: + try: + dom = UserID.from_string(u).domain + old_d = joined_domains.get(dom) + if old_d: + joined_domains[dom] = min(d, old_d) + else: + joined_domains[dom] = d + except: + pass + + return sorted(joined_domains.items(), key=lambda d: d[1]) + + curr_domains = get_domains_from_state(curr_state) + + likely_domains = [ + domain for domain, depth in curr_domains + if domain is not self.server_name + ] + + @defer.inlineCallbacks + def try_backfill(domains): + # TODO: Should we try multiple of these at a time? + for dom in domains: + try: + events = yield self.backfill( + dom, room_id, + limit=100, + extremities=[e for e in extremities.keys()] + ) + except SynapseError: + logger.info( + "Failed to backfill from %s because %s", + dom, e, + ) + continue + except CodeMessageException as e: + if 400 <= e.code < 500: + raise + + logger.info( + "Failed to backfill from %s because %s", + dom, e, + ) + continue + except NotRetryingDestination as e: + logger.info(e.message) + continue + except Exception as e: + logger.warn( + "Failed to backfill from %s because %s", + dom, e, + ) + continue + + if events: + defer.returnValue(True) + defer.returnValue(False) + + success = yield try_backfill(likely_domains) + if success: + defer.returnValue(True) + + # Huh, well *those* domains didn't work out. Lets try some domains + # from the time. + + tried_domains = set(likely_domains) + tried_domains.add(self.server_name) + + event_ids = list(extremities.keys()) + + states = yield defer.gatherResults([ + self.state_handler.resolve_state_groups([e]) + for e in event_ids + ]) + states = dict(zip(event_ids, [s[1] for s in states])) + + for e_id, _ in sorted_extremeties_tuple: + likely_domains = get_domains_from_state(states[e_id]) + + success = yield try_backfill([ + dom for dom in likely_domains + if dom not in tried_domains + ]) + if success: + defer.returnValue(True) + + tried_domains.update(likely_domains) + + defer.returnValue(False) + + @defer.inlineCallbacks def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. @@ -423,7 +564,7 @@ class FederationHandler(BaseHandler): if e.event_id in auth_ids } - yield self._handle_new_event( + _, event_stream_id, max_stream_id = yield self._handle_new_event( origin, new_event, state=state, @@ -431,9 +572,11 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) - d = self.notifier.on_new_room_event( - new_event, extra_users=[joinee] - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + new_event, event_stream_id, max_stream_id, + extra_users=[joinee] + ) def log_failure(f): logger.warn( @@ -498,7 +641,9 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = yield self._handle_new_event(origin, event) + context, event_stream_id, max_stream_id = yield self._handle_new_event( + origin, event + ) logger.debug( "on_send_join_request: After _handle_new_event: %s, sigs: %s", @@ -512,9 +657,10 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - d = self.notifier.on_new_room_event( - event, extra_users=extra_users - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) def log_failure(f): logger.warn( @@ -587,16 +733,18 @@ class FederationHandler(BaseHandler): context = yield self.state_handler.compute_event_context(event) - yield self.store.persist_event( + event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, backfilled=False, ) target_user = UserID.from_string(event.state_key) - d = self.notifier.on_new_room_event( - event, extra_users=[target_user], - ) + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=[target_user], + ) def log_failure(f): logger.warn( @@ -773,7 +921,7 @@ class FederationHandler(BaseHandler): ) raise - yield self.store.persist_event( + event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, backfilled=backfilled, @@ -781,7 +929,7 @@ class FederationHandler(BaseHandler): current_state=current_state, ) - defer.returnValue(context) + defer.returnValue((context, event_stream_id, max_stream_id)) @defer.inlineCallbacks def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, @@ -921,7 +1069,7 @@ class FederationHandler(BaseHandler): if d in have_events and not have_events[d] ], consumeErrors=True - ) + ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 22e19af17f..867fdbefb0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -20,8 +20,9 @@ from synapse.api.errors import RoomError, SynapseError from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator +from synapse.util import unwrapFirstError from synapse.util.logcontext import PreserveLoggingContext -from synapse.types import UserID +from synapse.types import UserID, RoomStreamToken from ._base import BaseHandler @@ -89,9 +90,19 @@ class MessageHandler(BaseHandler): if not pagin_config.from_token: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token() + yield self.hs.get_event_sources().get_current_token( + direction='b' + ) ) + room_token = RoomStreamToken.parse(pagin_config.from_token.room_key) + if room_token.topological is None: + raise SynapseError(400, "Invalid token") + + yield self.hs.get_handlers().federation_handler.maybe_backfill( + room_id, room_token.topological + ) + user = UserID.from_string(user_id) events, next_key = yield data_source.get_pagination_rows( @@ -303,7 +314,7 @@ class MessageHandler(BaseHandler): event.room_id ), ] - ) + ).addErrback(unwrapFirstError) start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1]) @@ -328,7 +339,7 @@ class MessageHandler(BaseHandler): yield defer.gatherResults( [handle_room(e) for e in room_list], consumeErrors=True - ) + ).addErrback(unwrapFirstError) ret = { "rooms": rooms_ret, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9e15610401..40794187b1 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -18,8 +18,8 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError from synapse.api.constants import PresenceState -from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logutils import log_function from synapse.types import UserID import synapse.metrics @@ -278,15 +278,14 @@ class PresenceHandler(BaseHandler): now_online = state["presence"] != PresenceState.OFFLINE was_polling = target_user in self._user_cachemap - with PreserveLoggingContext(): - if now_online and not was_polling: - self.start_polling_presence(target_user, state=state) - elif not now_online and was_polling: - self.stop_polling_presence(target_user) + if now_online and not was_polling: + self.start_polling_presence(target_user, state=state) + elif not now_online and was_polling: + self.stop_polling_presence(target_user) - # TODO(paul): perform a presence push as part of start/stop poll so - # we don't have to do this all the time - self.changed_presencelike_data(target_user, state) + # TODO(paul): perform a presence push as part of start/stop poll so + # we don't have to do this all the time + self.changed_presencelike_data(target_user, state) def bump_presence_active_time(self, user, now=None): if now is None: @@ -318,6 +317,14 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def user_joined_room(self, user, room_id): + """Called via the distributor whenever a user joins a room. + Notifies the new member of the presence of the current members. + Notifies the current members of the room of the new member's presence. + + Args: + user(UserID): The user who joined the room. + room_id(str): The room id the user joined. + """ if self.hs.is_mine(user): statuscache = self._get_or_make_usercache(user) @@ -337,6 +344,8 @@ class PresenceHandler(BaseHandler): curr_users = yield rm_handler.get_room_members(room_id) for local_user in [c for c in curr_users if self.hs.is_mine(c)]: + statuscache = self._get_or_offline_usercache(local_user) + statuscache.update({}, serial=self._user_cachemap_latest_serial) self.push_update_to_local_and_remote( observed_user=local_user, users_to_push=[user], @@ -345,6 +354,7 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def send_invite(self, observer_user, observed_user): + """Request the presence of a local or remote user for a local user""" if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") @@ -379,6 +389,15 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def invite_presence(self, observed_user, observer_user): + """Handles a m.presence_invite EDU. A remote or local user has + requested presence updates for a local user. If the invite is accepted + then allow the local or remote user to see the presence of the local + user. + + Args: + observed_user(UserID): The local user whose presence is requested. + observer_user(UserID): The remote or local user requesting presence. + """ accept = yield self._should_accept_invite(observed_user, observer_user) if accept: @@ -405,16 +424,34 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def accept_presence(self, observed_user, observer_user): + """Handles a m.presence_accept EDU. Mark a presence invite from a + local or remote user as accepted in a local user's presence list. + Starts polling for presence updates from the local or remote user. + + Args: + observed_user(UserID): The user to update in the presence list. + observer_user(UserID): The owner of the presence list to update. + """ yield self.store.set_presence_list_accepted( observer_user.localpart, observed_user.to_string() ) - with PreserveLoggingContext(): - self.start_polling_presence( - observer_user, target_user=observed_user - ) + + self.start_polling_presence( + observer_user, target_user=observed_user + ) @defer.inlineCallbacks def deny_presence(self, observed_user, observer_user): + """Handle a m.presence_deny EDU. Removes a local or remote user from a + local user's presence list. + + Args: + observed_user(UserID): The local or remote user to remove from the + list. + observer_user(UserID): The local owner of the presence list. + Returns: + A Deferred. + """ yield self.store.del_presence_list( observer_user.localpart, observed_user.to_string() ) @@ -423,6 +460,16 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks def drop(self, observed_user, observer_user): + """Remove a local or remote user from a local user's presence list and + unsubscribe the local user from updates that user. + + Args: + observed_user(UserId): The local or remote user to remove from the + list. + observer_user(UserId): The local owner of the presence list. + Returns: + A Deferred. + """ if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") @@ -430,13 +477,22 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - with PreserveLoggingContext(): - self.stop_polling_presence( - observer_user, target_user=observed_user - ) + self.stop_polling_presence( + observer_user, target_user=observed_user + ) @defer.inlineCallbacks def get_presence_list(self, observer_user, accepted=None): + """Get the presence list for a local user. The retured list includes + the current presence state for each user listed. + + Args: + observer_user(UserID): The local user whose presence list to fetch. + accepted(bool or None): If not none then only include users who + have or have not accepted the presence invite request. + Returns: + A Deferred list of presence state events. + """ if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") @@ -458,6 +514,23 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks @log_function def start_polling_presence(self, user, target_user=None, state=None): + """Subscribe a local user to presence updates from a local or remote + user. If no target_user is supplied then subscribe to all users stored + in the presence list for the local user. + + Additonally this pushes the current presence state of this user to all + target_users. That state can be provided directly or will be read from + the stored state for the local user. + + Also this attempts to notify the local user of the current state of + any local target users. + + Args: + user(UserID): The local user that whishes for presence updates. + target_user(UserID): The local or remote user whose updates are + wanted. + state(dict): Optional presence state for the local user. + """ logger.debug("Start polling for presence from %s", user) if target_user: @@ -498,9 +571,7 @@ class PresenceHandler(BaseHandler): # We want to tell the person that just came online # presence state of people they are interested in? self.push_update_to_clients( - observed_user=target_user, users_to_push=[user], - statuscache=self._get_or_offline_usercache(target_user), ) deferreds = [] @@ -517,6 +588,12 @@ class PresenceHandler(BaseHandler): yield defer.DeferredList(deferreds, consumeErrors=True) def _start_polling_local(self, user, target_user): + """Subscribe a local user to presence updates for a local user + + Args: + user(UserId): The local user that wishes for updates. + target_user(UserId): The local users whose updates are wanted. + """ target_localpart = target_user.localpart if target_localpart not in self._local_pushmap: @@ -525,6 +602,17 @@ class PresenceHandler(BaseHandler): self._local_pushmap[target_localpart].add(user) def _start_polling_remote(self, user, domain, remoteusers): + """Subscribe a local user to presence updates for remote users on a + given remote domain. + + Args: + user(UserID): The local user that wishes for updates. + domain(str): The remote server the local user wants updates from. + remoteusers(UserID): The remote users that local user wants to be + told about. + Returns: + A Deferred. + """ to_poll = set() for u in remoteusers: @@ -545,6 +633,17 @@ class PresenceHandler(BaseHandler): @log_function def stop_polling_presence(self, user, target_user=None): + """Unsubscribe a local user from presence updates from a local or + remote user. If no target user is supplied then unsubscribe the user + from all presence updates that the user had subscribed to. + + Args: + user(UserID): The local user that no longer wishes for updates. + target_user(UserID or None): The user whose updates are no longer + wanted. + Returns: + A Deferred. + """ logger.debug("Stop polling for presence from %s", user) if not target_user or self.hs.is_mine(target_user): @@ -573,6 +672,13 @@ class PresenceHandler(BaseHandler): return defer.DeferredList(deferreds, consumeErrors=True) def _stop_polling_local(self, user, target_user): + """Unsubscribe a local user from presence updates from a local user on + this server. + + Args: + user(UserID): The local user that no longer wishes for updates. + target_user(UserID): The user whose updates are no longer wanted. + """ for localpart in self._local_pushmap.keys(): if target_user and localpart != target_user.localpart: continue @@ -585,6 +691,17 @@ class PresenceHandler(BaseHandler): @log_function def _stop_polling_remote(self, user, domain, remoteusers): + """Unsubscribe a local user from presence updates from remote users on + a given domain. + + Args: + user(UserID): The local user that no longer wishes for updates. + domain(str): The remote server to unsubscribe from. + remoteusers([UserID]): The users on that remote server that the + local user no longer wishes to be updated about. + Returns: + A Deferred. + """ to_unpoll = set() for u in remoteusers: @@ -606,6 +723,19 @@ class PresenceHandler(BaseHandler): @defer.inlineCallbacks @log_function def push_presence(self, user, statuscache): + """ + Notify local and remote users of a change in presence of a local user. + Pushes the update to local clients and remote domains that are directly + subscribed to the presence of the local user. + Also pushes that update to any local user or remote domain that shares + a room with the local user. + + Args: + user(UserID): The local user whose presence was updated. + statuscache(UserPresenceCache): Cache of the user's presence state + Returns: + A Deferred. + """ assert(self.hs.is_mine(user)) logger.debug("Pushing presence update from %s", user) @@ -633,44 +763,23 @@ class PresenceHandler(BaseHandler): yield self.distributor.fire("user_presence_changed", user, statuscache) @defer.inlineCallbacks - def _push_presence_remote(self, user, destination, state=None): - if state is None: - state = yield self.store.get_presence_state(user.localpart) - del state["mtime"] - state["presence"] = state.pop("state") - - if user in self._user_cachemap: - state["last_active"] = ( - self._user_cachemap[user].get_state()["last_active"] - ) - - yield self.distributor.fire( - "collect_presencelike_data", user, state - ) - - if "last_active" in state: - state = dict(state) - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) - - user_state = { - "user_id": user.to_string(), - } - user_state.update(**state) - - yield self.federation.send_edu( - destination=destination, - edu_type="m.presence", - content={ - "push": [ - user_state, - ], - } - ) - - @defer.inlineCallbacks def incoming_presence(self, origin, content): + """Handle an incoming m.presence EDU. + For each presence update in the "push" list update our local cache and + notify the appropriate local clients. Only clients that share a room + or are directly subscribed to the presence for a user should be + notified of the update. + For each subscription request in the "poll" list start pushing presence + updates to the remote server. + For unsubscribe request in the "unpoll" list stop pushing presence + updates to the remote server. + + Args: + orgin(str): The source of this m.presence EDU. + content(dict): The content of this m.presence EDU. + Returns: + A Deferred. + """ deferreds = [] for push in content.get("push", []): @@ -714,10 +823,7 @@ class PresenceHandler(BaseHandler): continue self.push_update_to_clients( - observed_user=user, - users_to_push=observers, - room_ids=room_ids, - statuscache=statuscache, + users_to_push=observers, room_ids=room_ids ) user_id = user.to_string() @@ -766,13 +872,29 @@ class PresenceHandler(BaseHandler): if not self._remote_sendmap[user]: del self._remote_sendmap[user] - with PreserveLoggingContext(): - yield defer.DeferredList(deferreds, consumeErrors=True) + yield defer.DeferredList(deferreds, consumeErrors=True) @defer.inlineCallbacks def push_update_to_local_and_remote(self, observed_user, statuscache, users_to_push=[], room_ids=[], remote_domains=[]): + """Notify local clients and remote servers of a change in the presence + of a user. + + Args: + observed_user(UserID): The user to push the presence state for. + statuscache(UserPresenceCache): The cache for the presence state to + push. + users_to_push([UserID]): A list of local and remote users to + notify. + room_ids([str]): Notify the local and remote occupants of these + rooms. + remote_domains([str]): A list of remote servers to notify in + addition to those implied by the users_to_push and the + room_ids. + Returns: + A Deferred. + """ localusers, remoteusers = partitionbool( users_to_push, @@ -782,10 +904,7 @@ class PresenceHandler(BaseHandler): localusers = set(localusers) self.push_update_to_clients( - observed_user=observed_user, - users_to_push=localusers, - room_ids=room_ids, - statuscache=statuscache, + users_to_push=localusers, room_ids=room_ids ) remote_domains = set(remote_domains) @@ -810,11 +929,65 @@ class PresenceHandler(BaseHandler): defer.returnValue((localusers, remote_domains)) - def push_update_to_clients(self, observed_user, users_to_push=[], - room_ids=[], statuscache=None): - self.notifier.on_new_user_event( - users_to_push, - room_ids, + def push_update_to_clients(self, users_to_push=[], room_ids=[]): + """Notify clients of a new presence event. + + Args: + users_to_push([UserID]): List of users to notify. + room_ids([str]): List of room_ids to notify. + """ + with PreserveLoggingContext(): + self.notifier.on_new_user_event( + "presence_key", + self._user_cachemap_latest_serial, + users_to_push, + room_ids, + ) + + @defer.inlineCallbacks + def _push_presence_remote(self, user, destination, state=None): + """Push a user's presence to a remote server. If a presence state event + that event is sent. Otherwise a new state event is constructed from the + stored presence state. + The last_active is replaced with last_active_ago in case the wallclock + time on the remote server is different to the time on this server. + Sends an EDU to the remote server with the current presence state. + + Args: + user(UserID): The user to push the presence state for. + destination(str): The remote server to send state to. + state(dict): The state to push, or None to use the current stored + state. + Returns: + A Deferred. + """ + if state is None: + state = yield self.store.get_presence_state(user.localpart) + del state["mtime"] + state["presence"] = state.pop("state") + + if user in self._user_cachemap: + state["last_active"] = ( + self._user_cachemap[user].get_state()["last_active"] + ) + + yield self.distributor.fire( + "collect_presencelike_data", user, state + ) + + if "last_active" in state: + state = dict(state) + state["last_active_ago"] = int( + self.clock.time_msec() - state.pop("last_active") + ) + + user_state = {"user_id": user.to_string(), } + user_state.update(state) + + yield self.federation.send_edu( + destination=destination, + edu_type="m.presence", + content={"push": [user_state, ], } ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index ee2732b848..799faffe53 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,8 +17,8 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.constants import EventTypes, Membership -from synapse.util.logcontext import PreserveLoggingContext from synapse.types import UserID +from synapse.util import unwrapFirstError from ._base import BaseHandler @@ -88,6 +88,9 @@ class ProfileHandler(BaseHandler): if target_user != auth_user: raise AuthError(400, "Cannot set another user's displayname") + if new_displayname == '': + new_displayname = None + yield self.store.set_profile_displayname( target_user.localpart, new_displayname ) @@ -154,14 +157,13 @@ class ProfileHandler(BaseHandler): if not self.hs.is_mine(user): defer.returnValue(None) - with PreserveLoggingContext(): - (displayname, avatar_url) = yield defer.gatherResults( - [ - self.store.get_profile_displayname(user.localpart), - self.store.get_profile_avatar_url(user.localpart), - ], - consumeErrors=True - ) + (displayname, avatar_url) = yield defer.gatherResults( + [ + self.store.get_profile_displayname(user.localpart), + self.store.get_profile_avatar_url(user.localpart), + ], + consumeErrors=True + ).addErrback(unwrapFirstError) state["displayname"] = displayname state["avatar_url"] = avatar_url diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3da08c147e..4bd027d9bb 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -21,11 +21,12 @@ from ._base import BaseHandler from synapse.types import UserID, RoomAlias, RoomID from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import StoreError, SynapseError -from synapse.util import stringutils +from synapse.util import stringutils, unwrapFirstError from synapse.util.async import run_on_reactor from synapse.events.utils import serialize_event import logging +import string logger = logging.getLogger(__name__) @@ -50,6 +51,10 @@ class RoomCreationHandler(BaseHandler): self.ratelimit(user_id) if "room_alias_name" in config: + for wchar in string.whitespace: + if wchar in config["room_alias_name"]: + raise SynapseError(400, "Invalid characters in room alias") + room_alias = RoomAlias.create( config["room_alias_name"], self.hs.hostname, @@ -535,7 +540,7 @@ class RoomListHandler(BaseHandler): for room in chunk ], consumeErrors=True, - ) + ).addErrback(unwrapFirstError) for i, room in enumerate(chunk): room["num_joined_members"] = len(results[i]) @@ -575,8 +580,8 @@ class RoomEventSource(object): defer.returnValue((events, end_key)) - def get_current_key(self): - return self.store.get_room_events_max_id() + def get_current_key(self, direction='f'): + return self.store.get_room_events_max_id(direction) @defer.inlineCallbacks def get_pagination_rows(self, user, config, key): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 35a62fda47..bd8c603681 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -92,7 +92,7 @@ class SyncHandler(BaseHandler): result = yield self.current_sync_for_user(sync_config, since_token) defer.returnValue(result) else: - def current_sync_callback(): + def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) rm_handler = self.hs.get_handlers().room_member_handler diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c0b2bd7db0..a9895292c2 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -18,6 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.errors import SynapseError, AuthError +from synapse.util.logcontext import PreserveLoggingContext from synapse.types import UserID import logging @@ -216,7 +217,10 @@ class TypingNotificationHandler(BaseHandler): self._latest_room_serial += 1 self._room_serials[room_id] = self._latest_room_serial - self.notifier.on_new_user_event(rooms=[room_id]) + with PreserveLoggingContext(): + self.notifier.on_new_user_event( + "typing_key", self._latest_room_serial, rooms=[room_id] + ) class TypingNotificationEventSource(object): diff --git a/synapse/http/client.py b/synapse/http/client.py index e8a5dedab4..5b3cefb2dc 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.api.errors import CodeMessageException +from synapse.util.logcontext import preserve_context_over_fn from syutil.jsonutil import encode_canonical_json import synapse.metrics @@ -61,7 +62,10 @@ class SimpleHttpClient(object): # A small wrapper around self.agent.request() so we can easily attach # counters to it outgoing_requests_counter.inc(method) - d = self.agent.request(method, *args, **kwargs) + d = preserve_context_over_fn( + self.agent.request, + method, *args, **kwargs + ) def _cb(response): incoming_responses_counter.inc(method, response.code) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 7fa295cad5..c99d237c73 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone from synapse.http.endpoint import matrix_federation_endpoint from synapse.util.async import sleep -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_context_over_fn import synapse.metrics from syutil.jsonutil import encode_canonical_json @@ -144,22 +144,22 @@ class MatrixFederationHttpClient(object): producer = body_callback(method, url_bytes, headers_dict) try: - with PreserveLoggingContext(): - request_deferred = self.agent.request( - destination, - endpoint, - method, - path_bytes, - param_bytes, - query_bytes, - Headers(headers_dict), - producer - ) + request_deferred = preserve_context_over_fn( + self.agent.request, + destination, + endpoint, + method, + path_bytes, + param_bytes, + query_bytes, + Headers(headers_dict), + producer + ) - response = yield self.clock.time_bound_deferred( - request_deferred, - time_out=60, - ) + response = yield self.clock.time_bound_deferred( + request_deferred, + time_out=60, + ) logger.debug("Got response to %s", method) break diff --git a/synapse/http/server.py b/synapse/http/server.py index 93ecbd7589..73efbff4f2 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -17,7 +17,7 @@ from synapse.api.errors import ( cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError ) -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext import synapse.metrics from syutil.jsonutil import ( @@ -85,7 +85,9 @@ def request_handler(request_handler): "Received request: %s %s", request.method, request.path ) - yield request_handler(self, request) + d = request_handler(self, request) + with PreserveLoggingContext(): + yield d code = request.code except CodeMessageException as e: code = e.code diff --git a/synapse/notifier.py b/synapse/notifier.py index 78eb28e4b2..4f47f88df8 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.async import run_on_reactor from synapse.types import StreamToken import synapse.metrics @@ -42,63 +42,78 @@ def count(func, l): class _NotificationListener(object): """ This represents a single client connection to the events stream. - The events stream handler will have yielded to the deferred, so to notify the handler it is sufficient to resolve the deferred. - - This listener will also keep track of which rooms it is listening in - so that it can remove itself from the indexes in the Notifier class. """ - def __init__(self, user, rooms, from_token, limit, timeout, deferred, - appservice=None): - self.user = user - self.appservice = appservice - self.from_token = from_token - self.limit = limit - self.timeout = timeout + def __init__(self, deferred): self.deferred = deferred - self.rooms = rooms - self.timer = None def notified(self): return self.deferred.called - def notify(self, notifier, events, start_token, end_token): - """ Inform whoever is listening about the new events. This will - also remove this listener from all the indexes in the Notifier - it knows about. + def notify(self, token): + """ Inform whoever is listening about the new events. """ - - result = (events, (start_token, end_token)) - try: - self.deferred.callback(result) - notified_events_counter.inc_by(len(events)) + self.deferred.callback(token) except defer.AlreadyCalledError: pass - # Should the following be done be using intrusively linked lists? - # -- erikj + +class _NotifierUserStream(object): + """This represents a user connected to the event stream. + It tracks the most recent stream token for that user. + At a given point a user may have a number of streams listening for + events. + + This listener will also keep track of which rooms it is listening in + so that it can remove itself from the indexes in the Notifier class. + """ + + def __init__(self, user, rooms, current_token, time_now_ms, + appservice=None): + self.user = str(user) + self.appservice = appservice + self.listeners = set() + self.rooms = set(rooms) + self.current_token = current_token + self.last_notified_ms = time_now_ms + + def notify(self, stream_key, stream_id, time_now_ms): + """Notify any listeners for this user of a new event from an + event source. + Args: + stream_key(str): The stream the event came from. + stream_id(str): The new id for the stream the event came from. + time_now_ms(int): The current time in milliseconds. + """ + self.current_token = self.current_token.copy_and_advance( + stream_key, stream_id + ) + if self.listeners: + self.last_notified_ms = time_now_ms + listeners = self.listeners + self.listeners = set() + for listener in listeners: + listener.notify(self.current_token) + + def remove(self, notifier): + """ Remove this listener from all the indexes in the Notifier + it knows about. + """ for room in self.rooms: - lst = notifier.room_to_listeners.get(room, set()) + lst = notifier.room_to_user_streams.get(room, set()) lst.discard(self) - notifier.user_to_listeners.get(self.user, set()).discard(self) + notifier.user_to_user_stream.pop(self.user) if self.appservice: - notifier.appservice_to_listeners.get( + notifier.appservice_to_user_streams.get( self.appservice, set() ).discard(self) - # Cancel the timeout for this notifer if one exists. - if self.timer is not None: - try: - notifier.clock.cancel_call_later(self.timer) - except: - logger.warn("Failed to cancel notifier timer") - class Notifier(object): """ This class is responsible for notifying any listeners when there are @@ -107,14 +122,18 @@ class Notifier(object): Primarily used from the /events stream. """ + UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 + def __init__(self, hs): self.hs = hs - self.room_to_listeners = {} - self.user_to_listeners = {} - self.appservice_to_listeners = {} + self.user_to_user_stream = {} + self.room_to_user_streams = {} + self.appservice_to_user_streams = {} self.event_sources = hs.get_event_sources() + self.store = hs.get_datastore() + self.pending_new_room_events = [] self.clock = hs.get_clock() @@ -122,45 +141,80 @@ class Notifier(object): "user_joined_room", self._user_joined_room ) + self.clock.looping_call( + self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS + ) + # This is not a very cheap test to perform, but it's only executed # when rendering the metrics page, which is likely once per minute at # most when scraping it. def count_listeners(): - all_listeners = set() + all_user_streams = set() - for x in self.room_to_listeners.values(): - all_listeners |= x - for x in self.user_to_listeners.values(): - all_listeners |= x - for x in self.appservice_to_listeners.values(): - all_listeners |= x + for x in self.room_to_user_streams.values(): + all_user_streams |= x + for x in self.user_to_user_stream: + all_user_streams.add(x) + for x in self.appservice_to_user_streams.values(): + all_user_streams |= x - return len(all_listeners) + return sum(len(stream.listeners) for stream in all_user_streams) metrics.register_callback("listeners", count_listeners) metrics.register_callback( "rooms", - lambda: count(bool, self.room_to_listeners.values()), + lambda: count(bool, self.room_to_user_streams.values()), ) metrics.register_callback( "users", - lambda: count(bool, self.user_to_listeners.values()), + lambda: len(self.user_to_user_stream), ) metrics.register_callback( "appservices", - lambda: count(bool, self.appservice_to_listeners.values()), + lambda: count(bool, self.appservice_to_user_streams.values()), ) @log_function @defer.inlineCallbacks - def on_new_room_event(self, event, extra_users=[]): + def on_new_room_event(self, event, room_stream_id, max_room_stream_id, + extra_users=[]): """ Used by handlers to inform the notifier something has happened in the room, room event wise. This triggers the notifier to wake up any listeners that are listening to the room, and any listeners for the users in the `extra_users` param. + + The events can be peristed out of order. The notifier will wait + until all previous events have been persisted before notifying + the client streams. + """ + yield run_on_reactor() + + self.pending_new_room_events.append(( + room_stream_id, event, extra_users + )) + self._notify_pending_new_room_events(max_room_stream_id) + + def _notify_pending_new_room_events(self, max_room_stream_id): + """Notify for the room events that were queued waiting for a previous + event to be persisted. + Args: + max_room_stream_id(int): The highest stream_id below which all + events have been persisted. """ + pending = self.pending_new_room_events + self.pending_new_room_events = [] + for room_stream_id, event, extra_users in pending: + if room_stream_id > max_room_stream_id: + self.pending_new_room_events.append(( + room_stream_id, event, extra_users + )) + else: + self._on_new_room_event(event, room_stream_id, extra_users) + + def _on_new_room_event(self, event, room_stream_id, extra_users=[]): + """Notify any user streams that are interested in this room event""" # poke any interested application service. self.hs.get_handlers().appservice_handler.notify_interested_services( event @@ -168,192 +222,129 @@ class Notifier(object): room_id = event.room_id - room_source = self.event_sources.sources["room"] - - room_listeners = self.room_to_listeners.get(room_id, set()) - - _discard_if_notified(room_listeners) + room_user_streams = self.room_to_user_streams.get(room_id, set()) - listeners = room_listeners.copy() + user_streams = room_user_streams.copy() for user in extra_users: - user_listeners = self.user_to_listeners.get(user, set()) + user_stream = self.user_to_user_stream.get(str(user)) + if user_stream is not None: + user_streams.add(user_stream) - _discard_if_notified(user_listeners) - - listeners |= user_listeners - - for appservice in self.appservice_to_listeners: + for appservice in self.appservice_to_user_streams: # TODO (kegan): Redundant appservice listener checks? - # App services will already be in the room_to_listeners set, but + # App services will already be in the room_to_user_streams set, but # that isn't enough. They need to be checked here in order to # receive *invites* for users they are interested in. Does this - # make the room_to_listeners check somewhat obselete? + # make the room_to_user_streams check somewhat obselete? if appservice.is_interested(event): - app_listeners = self.appservice_to_listeners.get( + app_user_streams = self.appservice_to_user_streams.get( appservice, set() ) + user_streams |= app_user_streams - _discard_if_notified(app_listeners) - - listeners |= app_listeners - - logger.debug("on_new_room_event listeners %s", listeners) - - # TODO (erikj): Can we make this more efficient by hitting the - # db once? - - @defer.inlineCallbacks - def notify(listener): - events, end_key = yield room_source.get_new_events_for_user( - listener.user, - listener.from_token.room_key, - listener.limit, - ) - - if events: - end_token = listener.from_token.copy_and_replace( - "room_key", end_key - ) + logger.debug("on_new_room_event listeners %s", user_streams) - listener.notify( - self, events, listener.from_token, end_token + time_now_ms = self.clock.time_msec() + for user_stream in user_streams: + try: + user_stream.notify( + "room_key", "s%d" % (room_stream_id,), time_now_ms ) - - def eb(failure): - logger.exception("Failed to notify listener", failure) - - with PreserveLoggingContext(): - yield defer.DeferredList( - [notify(l).addErrback(eb) for l in listeners], - consumeErrors=True, - ) + except: + logger.exception("Failed to notify listener") @defer.inlineCallbacks @log_function - def on_new_user_event(self, users=[], rooms=[]): + def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]): """ Used to inform listeners that something has happend presence/user event wise. Will wake up all listeners for the given users and rooms. """ - # TODO(paul): This is horrible, having to manually list every event - # source here individually - presence_source = self.event_sources.sources["presence"] - typing_source = self.event_sources.sources["typing"] - - listeners = set() + yield run_on_reactor() + user_streams = set() for user in users: - user_listeners = self.user_to_listeners.get(user, set()) - - _discard_if_notified(user_listeners) - - listeners |= user_listeners + user_stream = self.user_to_user_stream.get(str(user)) + if user_stream is not None: + user_streams.add(user_stream) for room in rooms: - room_listeners = self.room_to_listeners.get(room, set()) - - _discard_if_notified(room_listeners) - - listeners |= room_listeners - - @defer.inlineCallbacks - def notify(listener): - presence_events, presence_end_key = ( - yield presence_source.get_new_events_for_user( - listener.user, - listener.from_token.presence_key, - listener.limit, - ) - ) - typing_events, typing_end_key = ( - yield typing_source.get_new_events_for_user( - listener.user, - listener.from_token.typing_key, - listener.limit, - ) - ) - - if presence_events or typing_events: - end_token = listener.from_token.copy_and_replace( - "presence_key", presence_end_key - ).copy_and_replace( - "typing_key", typing_end_key - ) + user_streams |= self.room_to_user_streams.get(room, set()) - listener.notify( - self, - presence_events + typing_events, - listener.from_token, - end_token - ) - - def eb(failure): - logger.error( - "Failed to notify listener", - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject()) - ) - - with PreserveLoggingContext(): - yield defer.DeferredList( - [notify(l).addErrback(eb) for l in listeners], - consumeErrors=True, - ) + time_now_ms = self.clock.time_msec() + for user_stream in user_streams: + try: + user_stream.notify(stream_key, new_token, time_now_ms) + except: + logger.exception("Failed to notify listener") @defer.inlineCallbacks - def wait_for_events(self, user, rooms, filter, timeout, callback): + def wait_for_events(self, user, rooms, timeout, callback, + from_token=StreamToken("s0", "0", "0")): """Wait until the callback returns a non empty response or the timeout fires. """ deferred = defer.Deferred() + time_now_ms = self.clock.time_msec() + + user = str(user) + user_stream = self.user_to_user_stream.get(user) + if user_stream is None: + appservice = yield self.store.get_app_service_by_user_id(user) + current_token = yield self.event_sources.get_current_token() + rooms = yield self.store.get_rooms_for_user(user) + rooms = [room.room_id for room in rooms] + user_stream = _NotifierUserStream( + user=user, + rooms=rooms, + appservice=appservice, + current_token=current_token, + time_now_ms=time_now_ms, + ) + self._register_with_keys(user_stream) + else: + current_token = user_stream.current_token - from_token = StreamToken("s0", "0", "0") + listener = [_NotificationListener(deferred)] - listener = [_NotificationListener( - user=user, - rooms=rooms, - from_token=from_token, - limit=1, - timeout=timeout, - deferred=deferred, - )] + if timeout and not current_token.is_after(from_token): + user_stream.listeners.add(listener[0]) - if timeout: - self._register_with_keys(listener[0]) + if current_token.is_after(from_token): + result = yield callback(from_token, current_token) + else: + result = None - result = yield callback() timer = [None] + if result: + user_stream.listeners.discard(listener[0]) + defer.returnValue(result) + return + if timeout: timed_out = [False] def _timeout_listener(): timed_out[0] = True timer[0] = None - listener[0].notify(self, [], from_token, from_token) + user_stream.listeners.discard(listener[0]) + listener[0].notify(current_token) # We create multiple notification listeners so we have to manage # canceling the timeout ourselves. timer[0] = self.clock.call_later(timeout/1000., _timeout_listener) while not result and not timed_out[0]: - yield deferred + new_token = yield deferred deferred = defer.Deferred() - listener[0] = _NotificationListener( - user=user, - rooms=rooms, - from_token=from_token, - limit=1, - timeout=timeout, - deferred=deferred, - ) - self._register_with_keys(listener[0]) - result = yield callback() + listener[0] = _NotificationListener(deferred) + user_stream.listeners.add(listener[0]) + result = yield callback(current_token, new_token) + current_token = new_token if timer[0] is not None: try: @@ -363,125 +354,79 @@ class Notifier(object): defer.returnValue(result) + @defer.inlineCallbacks def get_events_for(self, user, rooms, pagination_config, timeout): """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. """ - deferred = defer.Deferred() - - self._get_events( - deferred, user, rooms, pagination_config.from_token, - pagination_config.limit, timeout - ).addErrback(deferred.errback) - - return deferred - - @defer.inlineCallbacks - def _get_events(self, deferred, user, rooms, from_token, limit, timeout): + from_token = pagination_config.from_token if not from_token: from_token = yield self.event_sources.get_current_token() - appservice = yield self.hs.get_datastore().get_app_service_by_user_id( - user.to_string() - ) + limit = pagination_config.limit - listener = _NotificationListener( - user, - rooms, - from_token, - limit, - timeout, - deferred, - appservice=appservice - ) + @defer.inlineCallbacks + def check_for_updates(before_token, after_token): + events = [] + end_token = from_token + for name, source in self.event_sources.sources.items(): + keyname = "%s_key" % name + before_id = getattr(before_token, keyname) + after_id = getattr(after_token, keyname) + if before_id == after_id: + continue + stuff, new_key = yield source.get_new_events_for_user( + user, getattr(from_token, keyname), limit, + ) + events.extend(stuff) + end_token = end_token.copy_and_replace(keyname, new_key) - def _timeout_listener(): - # TODO (erikj): We should probably set to_token to the current - # max rather than reusing from_token. - # Remove the timer from the listener so we don't try to cancel it. - listener.timer = None - listener.notify( - self, - [], - listener.from_token, - listener.from_token, - ) + if events: + defer.returnValue((events, (from_token, end_token))) + else: + defer.returnValue(None) - if timeout: - self._register_with_keys(listener) + result = yield self.wait_for_events( + user, rooms, timeout, check_for_updates, from_token=from_token + ) - yield self._check_for_updates(listener) + if result is None: + result = ([], (from_token, from_token)) - if not timeout: - _timeout_listener() - else: - # Only add the timer if the listener hasn't been notified - if not listener.notified(): - listener.timer = self.clock.call_later( - timeout/1000.0, _timeout_listener - ) - return + defer.returnValue(result) @log_function - def _register_with_keys(self, listener): - for room in listener.rooms: - s = self.room_to_listeners.setdefault(room, set()) - s.add(listener) + def remove_expired_streams(self): + time_now_ms = self.clock.time_msec() + expired_streams = [] + expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS + for stream in self.user_to_user_stream.values(): + if stream.listeners: + continue + if stream.last_notified_ms < expire_before_ts: + expired_streams.append(stream) + + for expired_stream in expired_streams: + expired_stream.remove(self) - self.user_to_listeners.setdefault(listener.user, set()).add(listener) - - if listener.appservice: - self.appservice_to_listeners.setdefault( - listener.appservice, set() - ).add(listener) - - @defer.inlineCallbacks @log_function - def _check_for_updates(self, listener): - # TODO (erikj): We need to think about limits across multiple sources - events = [] + def _register_with_keys(self, user_stream): + self.user_to_user_stream[user_stream.user] = user_stream - from_token = listener.from_token - limit = listener.limit + for room in user_stream.rooms: + s = self.room_to_user_streams.setdefault(room, set()) + s.add(user_stream) - # TODO (erikj): DeferredList? - for name, source in self.event_sources.sources.items(): - keyname = "%s_key" % name - - stuff, new_key = yield source.get_new_events_for_user( - listener.user, - getattr(from_token, keyname), - limit, - ) - - events.extend(stuff) - - from_token = from_token.copy_and_replace(keyname, new_key) - - end_token = from_token - - if events: - listener.notify(self, events, listener.from_token, end_token) - - defer.returnValue(listener) + if user_stream.appservice: + self.appservice_to_user_stream.setdefault( + user_stream.appservice, set() + ).add(user_stream) def _user_joined_room(self, user, room_id): - new_listeners = self.user_to_listeners.get(user, set()) - - listeners = self.room_to_listeners.setdefault(room_id, set()) - listeners |= new_listeners - - for l in new_listeners: - l.rooms.add(room_id) - - -def _discard_if_notified(listener_set): - """Remove any 'stale' listeners from the given set. - """ - to_discard = set() - for l in listener_set: - if l.notified(): - to_discard.add(l) - - listener_set -= to_discard + user = str(user) + new_user_stream = self.user_to_user_stream.get(user) + if new_user_stream is not None: + room_streams = self.room_to_user_streams.setdefault(room_id, set()) + room_streams.add(new_user_stream) + new_user_stream.rooms.add(room_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 3640fb4a29..72dfb876c5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -82,8 +82,10 @@ class RegisterRestServlet(RestServlet): [LoginType.EMAIL_IDENTITY] ] + result = None if service: is_application_server = True + params = body elif 'mac' in body: # Check registration-specific shared secret auth if 'username' not in body: @@ -92,6 +94,7 @@ class RegisterRestServlet(RestServlet): body['username'], body['mac'] ) is_using_shared_secret = True + params = body else: authed, result, params = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) @@ -118,7 +121,7 @@ class RegisterRestServlet(RestServlet): password=new_password ) - if LoginType.EMAIL_IDENTITY in result: + if result and LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] for reqd in ['medium', 'address', 'validated_at']: diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py index 08c8d75af4..4af5f73878 100644 --- a/synapse/rest/media/v1/base_resource.py +++ b/synapse/rest/media/v1/base_resource.py @@ -25,7 +25,7 @@ from twisted.internet import defer from twisted.web.resource import Resource from twisted.protocols.basic import FileSender -from synapse.util.async import create_observer +from synapse.util.async import ObservableDeferred import os @@ -83,13 +83,17 @@ class BaseMediaResource(Resource): download = self.downloads.get(key) if download is None: download = self._get_remote_media_impl(server_name, media_id) + download = ObservableDeferred( + download, + consumeErrors=True + ) self.downloads[key] = download @download.addBoth def callback(media_info): del self.downloads[key] return media_info - return create_observer(download) + return download.observe() @defer.inlineCallbacks def _get_remote_media_impl(self, server_name, media_id): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 7cb91a0be9..75af44d787 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 18 +SCHEMA_VERSION = 19 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index c8c76e58fe..39884c2afe 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,10 +15,8 @@ import logging from synapse.api.errors import StoreError -from synapse.events import FrozenEvent -from synapse.events.utils import prune_event from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext, LoggingContext +from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache import synapse.metrics @@ -27,8 +25,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer from collections import namedtuple, OrderedDict + import functools -import simplejson as json import sys import time import threading @@ -48,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time") sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) -sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"]) caches_by_name = {} cache_counter = metrics.register_cache( @@ -307,6 +304,12 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self._event_fetch_lock = threading.Condition() + self._event_fetch_list = [] + self._event_fetch_ongoing = 0 + + self._pending_ds = [] + self.database_engine = hs.database_engine self._stream_id_gen = StreamIdGenerator() @@ -315,6 +318,7 @@ class SQLBaseStore(object): self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self) + self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -345,6 +349,75 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) + def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs): + start = time.time() * 1000 + txn_id = self._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + + name = "%s-%x" % (desc, txn_id, ) + + transaction_logger.debug("[TXN START] {%s}", name) + + try: + i = 0 + N = 5 + while True: + try: + txn = conn.cursor() + txn = LoggingTransaction( + txn, name, self.database_engine, after_callbacks + ) + r = func(txn, *args, **kwargs) + conn.commit() + return r + except self.database_engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warn( + "[TXN OPERROR] {%s} %s %d/%d", + name, e, i, N + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warn( + "[TXN EROLL] {%s} %s", + name, e1, + ) + continue + raise + except self.database_engine.module.DatabaseError as e: + if self.database_engine.is_deadlock(e): + logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warn( + "[TXN EROLL] {%s} %s", + name, e1, + ) + continue + raise + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) + raise + finally: + end = time.time() * 1000 + duration = end - start + + transaction_logger.debug("[TXN END] {%s} %f", name, duration) + + self._current_txn_total_time += duration + self._txn_perf_counters.update(desc, start, end) + sql_txn_timer.inc_by(duration, desc) + @defer.inlineCallbacks def runInteraction(self, desc, func, *args, **kwargs): """Wraps the .runInteraction() method on the underlying db_pool.""" @@ -356,82 +429,50 @@ class SQLBaseStore(object): def inner_func(conn, *args, **kwargs): with LoggingContext("runInteraction") as context: + sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") conn.reconnect() current_context.copy_to(context) - start = time.time() * 1000 - txn_id = self._TXN_ID + return self._new_transaction( + conn, desc, after_callbacks, func, *args, **kwargs + ) - # We don't really need these to be unique, so lets stop it from - # growing really large. - self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + result = yield preserve_context_over_fn( + self._db_pool.runWithConnection, + inner_func, *args, **kwargs + ) - name = "%s-%x" % (desc, txn_id, ) + for after_callback, after_args in after_callbacks: + after_callback(*after_args) + defer.returnValue(result) + + @defer.inlineCallbacks + def runWithConnection(self, func, *args, **kwargs): + """Wraps the .runInteraction() method on the underlying db_pool.""" + current_context = LoggingContext.current_context() + + start_time = time.time() * 1000 + def inner_func(conn, *args, **kwargs): + with LoggingContext("runWithConnection") as context: sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) - transaction_logger.debug("[TXN START] {%s}", name) - try: - i = 0 - N = 5 - while True: - try: - txn = conn.cursor() - txn = LoggingTransaction( - txn, name, self.database_engine, after_callbacks - ) - return func(txn, *args, **kwargs) - except self.database_engine.module.OperationalError as e: - # This can happen if the database disappears mid - # transaction. - logger.warn( - "[TXN OPERROR] {%s} %s %d/%d", - name, e, i, N - ) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warn( - "[TXN EROLL] {%s} %s", - name, e1, - ) - continue - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warn( - "[TXN EROLL] {%s} %s", - name, e1, - ) - continue - raise - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) - raise - finally: - end = time.time() * 1000 - duration = end - start - transaction_logger.debug("[TXN END] {%s} %f", name, duration) + if self.database_engine.is_connection_closed(conn): + logger.debug("Reconnecting closed database connection") + conn.reconnect() - self._current_txn_total_time += duration - self._txn_perf_counters.update(desc, start, end) - sql_txn_timer.inc_by(duration, desc) + current_context.copy_to(context) + + return func(conn, *args, **kwargs) + + result = yield preserve_context_over_fn( + self._db_pool.runWithConnection, + inner_func, *args, **kwargs + ) - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) - for after_callback, after_args in after_callbacks: - after_callback(*after_args) defer.returnValue(result) def cursor_to_dict(self, cursor): @@ -871,158 +912,6 @@ class SQLBaseStore(object): return self.runInteraction("_simple_max_id", func) - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False): - return self.runInteraction( - "_get_events", self._get_events_txn, event_ids, - check_redacted=check_redacted, get_prev_content=get_prev_content, - ) - - def _get_events_txn(self, txn, event_ids, check_redacted=True, - get_prev_content=False): - if not event_ids: - return [] - - events = [ - self._get_event_txn( - txn, event_id, - check_redacted=check_redacted, - get_prev_content=get_prev_content - ) - for event_id in event_ids - ] - - return [e for e in events if e] - - def _invalidate_get_event_cache(self, event_id): - for check_redacted in (False, True): - for get_prev_content in (False, True): - self._get_event_cache.invalidate(event_id, check_redacted, - get_prev_content) - - def _get_event_txn(self, txn, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False): - - start_time = time.time() * 1000 - - def update_counter(desc, last_time): - curr_time = self._get_event_counters.update(desc, last_time) - sql_getevents_timer.inc_by(curr_time - last_time, desc) - return curr_time - - try: - ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content) - - if allow_rejected or not ret.rejected_reason: - return ret - else: - return None - except KeyError: - pass - finally: - start_time = update_counter("event_cache", start_time) - - sql = ( - "SELECT e.internal_metadata, e.json, r.event_id, rej.reason " - "FROM event_json as e " - "LEFT JOIN redactions as r ON e.event_id = r.redacts " - "LEFT JOIN rejections as rej on rej.event_id = e.event_id " - "WHERE e.event_id = ? " - "LIMIT 1 " - ) - - txn.execute(sql, (event_id,)) - - res = txn.fetchone() - - if not res: - return None - - internal_metadata, js, redacted, rejected_reason = res - - start_time = update_counter("select_event", start_time) - - result = self._get_event_from_row_txn( - txn, internal_metadata, js, redacted, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - rejected_reason=rejected_reason, - ) - self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result) - - if allow_rejected or not rejected_reason: - return result - else: - return None - - def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, - check_redacted=True, get_prev_content=False, - rejected_reason=None): - - start_time = time.time() * 1000 - - def update_counter(desc, last_time): - curr_time = self._get_event_counters.update(desc, last_time) - sql_getevents_timer.inc_by(curr_time - last_time, desc) - return curr_time - - d = json.loads(js) - start_time = update_counter("decode_json", start_time) - - internal_metadata = json.loads(internal_metadata) - start_time = update_counter("decode_internal", start_time) - - ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - start_time = update_counter("build_frozen_event", start_time) - - if check_redacted and redacted: - ev = prune_event(ev) - - ev.unsigned["redacted_by"] = redacted - # Get the redaction event. - - because = self._get_event_txn( - txn, - redacted, - check_redacted=False - ) - - if because: - ev.unsigned["redacted_because"] = because - start_time = update_counter("redact_event", start_time) - - if get_prev_content and "replaces_state" in ev.unsigned: - prev = self._get_event_txn( - txn, - ev.unsigned["replaces_state"], - get_prev_content=False, - ) - if prev: - ev.unsigned["prev_content"] = prev.get_dict()["content"] - start_time = update_counter("get_prev_content", start_time) - - return ev - - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) - - def _parse_events_txn(self, txn, rows): - event_ids = [r["event_id"] for r in rows] - - return self._get_events_txn(txn, event_ids) - - def _has_been_redacted_txn(self, txn, event): - sql = "SELECT event_id FROM redactions WHERE redacts = ?" - txn.execute(sql, (event.event_id,)) - result = txn.fetchone() - return result[0] if result else None - def get_next_stream_id(self): with self._next_stream_id_lock: i = self._next_stream_id diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a323028546..4a855ffd56 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup class PostgresEngine(object): + single_threaded = False + def __init__(self, database_module): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index ff13d8006a..d18e2808d1 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database class Sqlite3Engine(object): + single_threaded = True + def __init__(self, database_module): self.module = database_module diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 74b4e23590..5d4b7843f3 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from ._base import SQLBaseStore, cached from syutil.base64util import encode_base64 @@ -33,16 +35,7 @@ class EventFederationStore(SQLBaseStore): """ def get_auth_chain(self, event_ids): - return self.runInteraction( - "get_auth_chain", - self._get_auth_chain_txn, - event_ids - ) - - def _get_auth_chain_txn(self, txn, event_ids): - results = self._get_auth_chain_ids_txn(txn, event_ids) - - return self._get_events_txn(txn, results) + return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) def get_auth_chain_ids(self, event_ids): return self.runInteraction( @@ -79,6 +72,28 @@ class EventFederationStore(SQLBaseStore): room_id, ) + def get_oldest_events_with_depth_in_room(self, room_id): + return self.runInteraction( + "get_oldest_events_with_depth_in_room", + self.get_oldest_events_with_depth_in_room_txn, + room_id, + ) + + def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): + sql = ( + "SELECT b.event_id, MAX(e.depth) FROM events as e" + " INNER JOIN event_edges as g" + " ON g.event_id = e.event_id AND g.room_id = e.room_id" + " INNER JOIN event_backward_extremities as b" + " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id" + " WHERE b.room_id = ? AND g.is_state is ?" + " GROUP BY b.event_id" + ) + + txn.execute(sql, (room_id, False,)) + + return dict(txn.fetchall()) + def _get_oldest_events_in_room_txn(self, txn, room_id): return self._simple_select_onecol_txn( txn, @@ -247,11 +262,13 @@ class EventFederationStore(SQLBaseStore): do_insert = depth < min_depth if min_depth else True if do_insert: - self._simple_insert_txn( + self._simple_upsert_txn( txn, table="room_depth", - values={ + keyvalues={ "room_id": room_id, + }, + values={ "min_depth": depth, }, ) @@ -306,31 +323,27 @@ class EventFederationStore(SQLBaseStore): txn.execute(query, (event_id, room_id)) - # Insert all the prev_events as a backwards thing, they'll get - # deleted in a second if they're incorrect anyway. - self._simple_insert_many_txn( - txn, - table="event_backward_extremities", - values=[ - { - "event_id": e_id, - "room_id": room_id, - } - for e_id, _ in prev_events - ], + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ?" + " )" ) - # Also delete from the backwards extremities table all ones that - # reference events that we have already seen + txn.executemany(query, [ + (e_id, room_id, e_id, room_id, e_id, room_id, ) + for e_id, _ in prev_events + ]) + query = ( - "DELETE FROM event_backward_extremities WHERE EXISTS (" - "SELECT 1 FROM events " - "WHERE " - "event_backward_extremities.event_id = events.event_id " - "AND not events.outlier " - ")" + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" ) - txn.execute(query) + txn.execute(query, (event_id, room_id)) txn.call_after( self.get_latest_event_ids_in_room.invalidate, room_id @@ -349,7 +362,7 @@ class EventFederationStore(SQLBaseStore): return self.runInteraction( "get_backfill_events", self._get_backfill_events, room_id, event_list, limit - ) + ).addCallback(self._get_events) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug( @@ -395,16 +408,26 @@ class EventFederationStore(SQLBaseStore): front = new_front event_results += new_front - return self._get_events_txn(txn, event_results) + return event_results + @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, limit, min_depth): - return self.runInteraction( + ids = yield self.runInteraction( "get_missing_events", self._get_missing_events, room_id, earliest_events, latest_events, limit, min_depth ) + events = yield self._get_events(ids) + + events = sorted( + [ev for ev in events if ev.depth >= min_depth], + key=lambda e: e.depth, + ) + + defer.returnValue(events[:limit]) + def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit, min_depth): @@ -436,14 +459,7 @@ class EventFederationStore(SQLBaseStore): front = new_front event_results |= new_front - events = self._get_events_txn(txn, event_results) - - events = sorted( - [ev for ev in events if ev.depth >= min_depth], - key=lambda e: e.depth, - ) - - return events[:limit] + return event_results def clean_room_for_join(self, room_id): return self.runInteraction( diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 1304219e86..d2a010bd88 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -15,20 +15,36 @@ from _base import SQLBaseStore, _RollbackButIsFineException -from twisted.internet import defer +from twisted.internet import defer, reactor +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from synapse.crypto.event_signing import compute_event_reference_hash from syutil.base64util import decode_base64 from syutil.jsonutil import encode_canonical_json +from contextlib import contextmanager import logging +import simplejson as json logger = logging.getLogger(__name__) +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. +# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function @@ -41,20 +57,32 @@ class EventsStore(SQLBaseStore): self.min_token -= 1 stream_ordering = self.min_token + if stream_ordering is None: + stream_ordering_manager = yield self._stream_id_gen.get_next(self) + else: + @contextmanager + def stream_ordering_manager(): + yield stream_ordering + stream_ordering_manager = stream_ordering_manager() + try: - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - backfilled=backfilled, - stream_ordering=stream_ordering, - is_new_state=is_new_state, - current_state=current_state, - ) + with stream_ordering_manager as stream_ordering: + yield self.runInteraction( + "persist_event", + self._persist_event_txn, + event=event, + context=context, + backfilled=backfilled, + stream_ordering=stream_ordering, + is_new_state=is_new_state, + current_state=current_state, + ) except _RollbackButIsFineException: pass + max_persisted_id = yield self._stream_id_gen.get_max_token(self) + defer.returnValue((stream_ordering, max_persisted_id)) + @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, @@ -74,18 +102,17 @@ class EventsStore(SQLBaseStore): Returns: Deferred : A FrozenEvent. """ - event = yield self.runInteraction( - "get_event", self._get_event_txn, - event_id, + events = yield self._get_events( + [event_id], check_redacted=check_redacted, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) - if not event and not allow_none: + if not events and not allow_none: raise RuntimeError("Could not find event %s" % (event_id,)) - defer.returnValue(event) + defer.returnValue(events[0] if events else None) @log_function def _persist_event_txn(self, txn, event, context, backfilled, @@ -95,15 +122,6 @@ class EventsStore(SQLBaseStore): # Remove the any existing cache entries for the event_id txn.call_after(self._invalidate_get_event_cache, event.event_id) - if stream_ordering is None: - with self._stream_id_gen.get_next_txn(txn) as stream_ordering: - return self._persist_event_txn( - txn, event, context, backfilled, - stream_ordering=stream_ordering, - is_new_state=is_new_state, - current_state=current_state, - ) - # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: @@ -134,19 +152,17 @@ class EventsStore(SQLBaseStore): outlier = event.internal_metadata.is_outlier() if not outlier: - self._store_state_groups_txn(txn, event, context) - self._update_min_depth_for_room_txn( txn, event.room_id, event.depth ) - have_persisted = self._simple_select_one_onecol_txn( + have_persisted = self._simple_select_one_txn( txn, - table="event_json", + table="events", keyvalues={"event_id": event.event_id}, - retcol="event_id", + retcols=["event_id", "outlier"], allow_none=True, ) @@ -161,7 +177,9 @@ class EventsStore(SQLBaseStore): # if we are persisting an event that we had persisted as an outlier, # but is no longer one. if have_persisted: - if not outlier: + if not outlier and have_persisted["outlier"]: + self._store_state_groups_txn(txn, event, context) + sql = ( "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?" @@ -181,6 +199,9 @@ class EventsStore(SQLBaseStore): ) return + if not outlier: + self._store_state_groups_txn(txn, event, context) + self._handle_prev_events( txn, outlier=outlier, @@ -400,3 +421,407 @@ class EventsStore(SQLBaseStore): return self.runInteraction( "have_events", f, ) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + defer.returnValue([]) + + event_map = self._get_events_from_cache( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_map] + + if not missing_events_ids: + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event_map.update(missing_events) + + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + def _get_events_txn(self, txn, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + return [] + + event_map = self._get_events_from_cache( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_map] + + if not missing_events_ids: + return [ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ] + + missing_events = self._fetch_events_txn( + txn, + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event_map.update(missing_events) + + return [ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ] + + def _invalidate_get_event_cache(self, event_id): + for check_redacted in (False, True): + for get_prev_content in (False, True): + self._get_event_cache.invalidate(event_id, check_redacted, + get_prev_content) + + def _get_event_txn(self, txn, event_id, check_redacted=True, + get_prev_content=False, allow_rejected=False): + + events = self._get_events_txn( + txn, [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + return events[0] if events else None + + def _get_events_from_cache(self, events, check_redacted, get_prev_content, + allow_rejected): + event_map = {} + + for event_id in events: + try: + ret = self._get_event_cache.get( + event_id, check_redacted, get_prev_content + ) + + if allow_rejected or not ret.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + except KeyError: + pass + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. + """ + event_list = [] + i = 0 + while True: + try: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + event_id_lists = zip(*event_list)[0] + event_ids = [ + item for sublist in event_id_lists for item in sublist + ] + + rows = self._new_transaction( + conn, "do_fetch", [], self._fetch_event_rows, event_ids + ) + + row_dict = { + r["event_id"]: r + for r in rows + } + + # We only want to resolve deferreds from the main thread + def fire(lst, res): + for ids, d in lst: + if not d.called: + try: + d.callback([ + res[i] + for i in ids + if i in res + ]) + except: + logger.exception("Failed to callback") + reactor.callFromThread(fire, event_list, row_dict) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs): + for _, d in evs: + if not d.called: + d.errback(e) + + if event_list: + reactor.callFromThread(fire, event_list) + + @defer.inlineCallbacks + def _enqueue_events(self, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + """ + if not events: + defer.returnValue({}) + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, events_d) + ) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + self.runWithConnection( + self._do_fetch + ) + + rows = yield preserve_context_over_deferred(events_d) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = yield defer.gatherResults( + [ + self._get_event_from_row( + row["internal_metadata"], row["json"], row["redacts"], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row["rejects"], + ) + for row in rows + ], + consumeErrors=True + ) + + defer.returnValue({ + e.event_id: e + for e in res if e + }) + + def _fetch_event_rows(self, txn, events): + rows = [] + N = 200 + for i in range(1 + len(events) / N): + evs = events[i*N:(i + 1)*N] + if not evs: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " e.internal_metadata," + " e.json," + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"]*len(evs)),) + + txn.execute(sql, evs) + rows.extend(self.cursor_to_dict(txn)) + + return rows + + def _fetch_events_txn(self, txn, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not events: + return {} + + rows = self._fetch_event_rows( + txn, events, + ) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = [ + self._get_event_from_row_txn( + txn, + row["internal_metadata"], row["json"], row["redacts"], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row["rejects"], + ) + for row in rows + ] + + return { + r.event_id: r + for r in res + } + + @defer.inlineCallbacks + def _get_event_from_row(self, internal_metadata, js, redacted, + check_redacted=True, get_prev_content=False, + rejected_reason=None): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = yield self._simple_select_one_onecol( + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + desc="_get_event_from_row", + ) + + ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + if check_redacted and redacted: + ev = prune_event(ev) + + redaction_id = yield self._simple_select_one_onecol( + table="redactions", + keyvalues={"redacts": ev.event_id}, + retcol="event_id", + desc="_get_event_from_row", + ) + + ev.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = yield self.get_event( + redaction_id, + check_redacted=False + ) + + if because: + ev.unsigned["redacted_because"] = because + + if get_prev_content and "replaces_state" in ev.unsigned: + prev = yield self.get_event( + ev.unsigned["replaces_state"], + get_prev_content=False, + ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] + + self._get_event_cache.prefill( + ev.event_id, check_redacted, get_prev_content, ev + ) + + defer.returnValue(ev) + + def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, + check_redacted=True, get_prev_content=False, + rejected_reason=None): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = self._simple_select_one_onecol_txn( + txn, + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + ) + + ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + if check_redacted and redacted: + ev = prune_event(ev) + + redaction_id = self._simple_select_one_onecol_txn( + txn, + table="redactions", + keyvalues={"redacts": ev.event_id}, + retcol="event_id", + ) + + ev.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = self._get_event_txn( + txn, + redaction_id, + check_redacted=False + ) + + if because: + ev.unsigned["redacted_because"] = because + + if get_prev_content and "replaces_state" in ev.unsigned: + prev = self._get_event_txn( + txn, + ev.unsigned["replaces_state"], + get_prev_content=False, + ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] + + self._get_event_cache.prefill( + ev.event_id, check_redacted, get_prev_content, ev + ) + + return ev + + def _parse_events(self, rows): + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) + + def _parse_events_txn(self, txn, rows): + event_ids = [r["event_id"] for r in rows] + + return self._get_events_txn(txn, event_ids) + + def _has_been_redacted_txn(self, txn, event): + sql = "SELECT event_id FROM redactions WHERE redacts = ?" + txn.execute(sql, (event.event_id,)) + result = txn.fetchone() + return result[0] if result else None diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 88ee21b089..80d0ac4ea3 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -212,11 +212,21 @@ class PushRuleStore(SQLBaseStore): @defer.inlineCallbacks def set_push_rule_enabled(self, user_name, rule_id, enabled): - yield self._simple_upsert( + ret = yield self.runInteraction( + "_set_push_rule_enabled_txn", + self._set_push_rule_enabled_txn, + user_name, rule_id, enabled + ) + defer.returnValue(ret) + + def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled): + new_id = self._push_rules_enable_id_gen.get_next_txn(txn) + self._simple_upsert_txn( + txn, PushRuleEnableTable.table_name, {'user_name': user_name, 'rule_id': rule_id}, {'enabled': 1 if enabled else 0}, - desc="set_push_rule_enabled", + {'id': new_id}, ) self.get_push_rules_enabled_for_user.invalidate(user_name) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 3691eade05..d36a6c18a8 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -77,16 +77,16 @@ class RoomMemberStore(SQLBaseStore): Returns: Deferred: Results in a MembershipEvent or None. """ - def f(txn): - events = self._get_members_events_txn( - txn, - room_id, - user_id=user_id, - ) - - return events[0] if events else None - - return self.runInteraction("get_room_member", f) + return self.runInteraction( + "get_room_member", + self._get_members_events_txn, + room_id, + user_id=user_id, + ).addCallback( + self._get_events + ).addCallback( + lambda events: events[0] if events else None + ) @cached() def get_users_in_room(self, room_id): @@ -112,15 +112,12 @@ class RoomMemberStore(SQLBaseStore): Returns: list of namedtuples representing the members in this room. """ - - def f(txn): - return self._get_members_events_txn( - txn, - room_id, - membership=membership, - ) - - return self.runInteraction("get_room_members", f) + return self.runInteraction( + "get_room_members", + self._get_members_events_txn, + room_id, + membership=membership, + ).addCallback(self._get_events) def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user @@ -192,14 +189,14 @@ class RoomMemberStore(SQLBaseStore): return self.runInteraction( "get_members_query", self._get_members_events_txn, where_clause, where_values - ) + ).addCallbacks(self._get_events) def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): rows = self._get_members_rows_txn( txn, room_id, membership, user_id, ) - return self._get_events_txn(txn, [r["event_id"] for r in rows]) + return [r["event_id"] for r in rows] def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): where_clause = "c.room_id = ?" diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/schema/delta/19/event_index.sql new file mode 100644 index 0000000000..3881fc9897 --- /dev/null +++ b/synapse/storage/schema/delta/19/event_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE INDEX events_order_topo_stream_room ON events( + topological_ordering, stream_ordering, room_id +); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 6df7350552..b24de34f23 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -43,6 +43,7 @@ class StateStore(SQLBaseStore): * `state_groups_state`: Maps state group to state events. """ + @defer.inlineCallbacks def get_state_groups(self, event_ids): """ Get the state groups for the given list of event_ids @@ -71,17 +72,29 @@ class StateStore(SQLBaseStore): retcol="event_id", ) - state = self._get_events_txn(txn, state_ids) - - res[group] = state + res[group] = state_ids return res - return self.runInteraction( + states = yield self.runInteraction( "get_state_groups", f, ) + @defer.inlineCallbacks + def c(vals): + vals[:] = yield self._get_events(vals, get_prev_content=False) + + yield defer.gatherResults( + [ + c(vals) + for vals in states.values() + ], + consumeErrors=True, + ) + + defer.returnValue(states) + def _store_state_groups_txn(self, txn, event, context): if context.current_state is None: return @@ -152,11 +165,12 @@ class StateStore(SQLBaseStore): args = (room_id, ) txn.execute(sql, args) - results = self.cursor_to_dict(txn) + results = txn.fetchall() - return self._parse_events_txn(txn, results) + return [r[0] for r in results] - events = yield self.runInteraction("get_current_state", f) + event_ids = yield self.runInteraction("get_current_state", f) + events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) @cached(num_args=3) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 280d4ad605..af45fc5619 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -37,11 +37,9 @@ from twisted.internet import defer from ._base import SQLBaseStore from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError +from synapse.types import RoomStreamToken from synapse.util.logutils import log_function -from collections import namedtuple - import logging @@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" -class _StreamToken(namedtuple("_StreamToken", "topological stream")): - """Tokens are positions between events. The token "s1" comes after event 1. - - s0 s1 - | | - [0] V [1] V [2] - - Tokens can either be a point in the live event stream or a cursor going - through historic events. - - When traversing the live event stream events are ordered by when they - arrived at the homeserver. - - When traversing historic events the events are ordered by their depth in - the event graph "topological_ordering" and then by when they arrived at the - homeserver "stream_ordering". - - Live tokens start with an "s" followed by the "stream_ordering" id of the - event it comes after. Historic tokens start with a "t" followed by the - "topological_ordering" id of the event it comes after, follewed by "-", - followed by the "stream_ordering" id of the event it comes after. - """ - __slots__ = [] - - @classmethod - def parse(cls, string): - try: - if string[0] == 's': - return cls(topological=None, stream=int(string[1:])) - if string[0] == 't': - parts = string[1:].split('-', 1) - return cls(topological=int(parts[0]), stream=int(parts[1])) - except: - pass - raise SynapseError(400, "Invalid token %r" % (string,)) - - @classmethod - def parse_stream_token(cls, string): - try: - if string[0] == 's': - return cls(topological=None, stream=int(string[1:])) - except: - pass - raise SynapseError(400, "Invalid token %r" % (string,)) - - def __str__(self): - if self.topological is not None: - return "t%d-%d" % (self.topological, self.stream) - else: - return "s%d" % (self.stream,) +def lower_bound(token): + if token.topological is None: + return "(%d < %s)" % (token.stream, "stream_ordering") + else: + return "(%d < %s OR (%d = %s AND %d < %s))" % ( + token.topological, "topological_ordering", + token.topological, "topological_ordering", + token.stream, "stream_ordering", + ) - def lower_bound(self): - if self.topological is None: - return "(%d < %s)" % (self.stream, "stream_ordering") - else: - return "(%d < %s OR (%d = %s AND %d < %s))" % ( - self.topological, "topological_ordering", - self.topological, "topological_ordering", - self.stream, "stream_ordering", - ) - def upper_bound(self): - if self.topological is None: - return "(%d >= %s)" % (self.stream, "stream_ordering") - else: - return "(%d > %s OR (%d = %s AND %d >= %s))" % ( - self.topological, "topological_ordering", - self.topological, "topological_ordering", - self.stream, "stream_ordering", - ) +def upper_bound(token): + if token.topological is None: + return "(%d >= %s)" % (token.stream, "stream_ordering") + else: + return "(%d > %s OR (%d = %s AND %d >= %s))" % ( + token.topological, "topological_ordering", + token.topological, "topological_ordering", + token.stream, "stream_ordering", + ) class StreamStore(SQLBaseStore): @@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore): limit = MAX_STREAM_SIZE # From and to keys should be integers from ordering. - from_id = _StreamToken.parse_stream_token(from_key) - to_id = _StreamToken.parse_stream_token(to_key) + from_id = RoomStreamToken.parse_stream_token(from_key) + to_id = RoomStreamToken.parse_stream_token(to_key) if from_key == to_key: defer.returnValue(([], to_key)) @@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore): limit = MAX_STREAM_SIZE # From and to keys should be integers from ordering. - from_id = _StreamToken.parse_stream_token(from_key) - to_id = _StreamToken.parse_stream_token(to_key) + from_id = RoomStreamToken.parse_stream_token(from_key) + to_id = RoomStreamToken.parse_stream_token(to_key) if from_key == to_key: return defer.succeed(([], to_key)) @@ -276,7 +224,7 @@ class StreamStore(SQLBaseStore): return self.runInteraction("get_room_events_stream", f) - @log_function + @defer.inlineCallbacks def paginate_room_events(self, room_id, from_key, to_key=None, direction='b', limit=-1, with_feedback=False): @@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore): args = [False, room_id] if direction == 'b': order = "DESC" - bounds = _StreamToken.parse(from_key).upper_bound() + bounds = upper_bound(RoomStreamToken.parse(from_key)) if to_key: bounds = "%s AND %s" % ( - bounds, _StreamToken.parse(to_key).lower_bound() + bounds, lower_bound(RoomStreamToken.parse(to_key)) ) else: order = "ASC" - bounds = _StreamToken.parse(from_key).lower_bound() + bounds = lower_bound(RoomStreamToken.parse(from_key)) if to_key: bounds = "%s AND %s" % ( - bounds, _StreamToken.parse(to_key).upper_bound() + bounds, upper_bound(RoomStreamToken.parse(to_key)) ) if int(limit) > 0: @@ -333,28 +281,30 @@ class StreamStore(SQLBaseStore): # when we are going backwards so we subtract one from the # stream part. toke -= 1 - next_token = str(_StreamToken(topo, toke)) + next_token = str(RoomStreamToken(topo, toke)) else: # TODO (erikj): We should work out what to do here instead. next_token = to_key if to_key else from_key - events = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) + return rows, next_token, - self._set_before_and_after(events, rows) + rows, token = yield self.runInteraction("paginate_room_events", f) + + events = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) - return events, next_token, + self._set_before_and_after(events, rows) - return self.runInteraction("paginate_room_events", f) + defer.returnValue((events, token)) + @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token, with_feedback=False, from_token=None): # TODO (erikj): Handle compressed feedback - end_token = _StreamToken.parse_stream_token(end_token) + end_token = RoomStreamToken.parse_stream_token(end_token) if from_token is None: sql = ( @@ -365,7 +315,7 @@ class StreamStore(SQLBaseStore): " LIMIT ?" ) else: - from_token = _StreamToken.parse_stream_token(from_token) + from_token = RoomStreamToken.parse_stream_token(from_token) sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" @@ -395,30 +345,49 @@ class StreamStore(SQLBaseStore): # stream part. topo = rows[0]["topological_ordering"] toke = rows[0]["stream_ordering"] - 1 - start_token = str(_StreamToken(topo, toke)) + start_token = str(RoomStreamToken(topo, toke)) token = (start_token, str(end_token)) else: token = (str(end_token), str(end_token)) - events = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(events, rows) + return rows, token - return events, token - - return self.runInteraction( + rows, token = yield self.runInteraction( "get_recent_events_for_room", get_recent_events_for_room_txn ) + logger.debug("stream before") + events = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) + logger.debug("stream after") + + self._set_before_and_after(events, rows) + + defer.returnValue((events, token)) + @defer.inlineCallbacks - def get_room_events_max_id(self): + def get_room_events_max_id(self, direction='f'): token = yield self._stream_id_gen.get_max_token(self) - defer.returnValue("s%d" % (token,)) + if direction != 'b': + defer.returnValue("s%d" % (token,)) + else: + topo = yield self.runInteraction( + "_get_max_topological_txn", self._get_max_topological_txn + ) + defer.returnValue("t%d-%d" % (topo, token)) + + def _get_max_topological_txn(self, txn): + txn.execute( + "SELECT MAX(topological_ordering) FROM events" + " WHERE outlier = ?", + (False,) + ) + + rows = txn.fetchall() + return rows[0][0] if rows else 0 @defer.inlineCallbacks def _get_min_token(self): @@ -439,5 +408,5 @@ class StreamStore(SQLBaseStore): stream = row["stream_ordering"] topo = event.depth internal = event.internal_metadata - internal.before = str(_StreamToken(topo, stream - 1)) - internal.after = str(_StreamToken(topo, stream)) + internal.before = str(RoomStreamToken(topo, stream - 1)) + internal.after = str(RoomStreamToken(topo, stream)) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index e40eb8a8c4..89d1643f10 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -78,14 +78,18 @@ class StreamIdGenerator(object): self._current_max = None self._unfinished_ids = deque() - def get_next_txn(self, txn): + @defer.inlineCallbacks + def get_next(self, store): """ Usage: - with stream_id_gen.get_next_txn(txn) as stream_id: + with yield stream_id_gen.get_next as stream_id: # ... persist event ... """ if not self._current_max: - self._get_or_compute_current_max(txn) + yield store.runInteraction( + "_compute_current_max", + self._get_or_compute_current_max, + ) with self._lock: self._current_max += 1 @@ -101,7 +105,7 @@ class StreamIdGenerator(object): with self._lock: self._unfinished_ids.remove(next_id) - return manager() + defer.returnValue(manager()) @defer.inlineCallbacks def get_max_token(self, store): diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 5c8e54b78b..dff7970bea 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -31,7 +31,7 @@ class NullSource(object): def get_new_events_for_user(self, user, from_key, limit): return defer.succeed(([], from_key)) - def get_current_key(self): + def get_current_key(self, direction='f'): return defer.succeed(0) def get_pagination_rows(self, user, pagination_config, key): @@ -52,10 +52,10 @@ class EventSources(object): } @defer.inlineCallbacks - def get_current_token(self): + def get_current_token(self, direction='f'): token = StreamToken( room_key=( - yield self.sources["room"].get_current_key() + yield self.sources["room"].get_current_key(direction) ), presence_key=( yield self.sources["presence"].get_current_key() diff --git a/synapse/types.py b/synapse/types.py index f6a1b0bbcf..1b21160c57 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -70,6 +70,8 @@ class DomainSpecificString( """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) + __str__ = to_string + @classmethod def create(cls, localpart, domain,): return cls(localpart=localpart, domain=domain) @@ -107,7 +109,6 @@ class StreamToken( def from_string(cls, string): try: keys = string.split(cls._SEPARATOR) - return cls(*keys) except: raise SynapseError(400, "Invalid Token") @@ -115,10 +116,95 @@ class StreamToken( def to_string(self): return self._SEPARATOR.join([str(k) for k in self]) + @property + def room_stream_id(self): + # TODO(markjh): Awful hack to work around hacks in the presence tests + # which assume that the keys are integers. + if type(self.room_key) is int: + return self.room_key + else: + return int(self.room_key[1:].split("-")[-1]) + + def is_after(self, other_token): + """Does this token contain events that the other doesn't?""" + return ( + (other_token.room_stream_id < self.room_stream_id) + or (int(other_token.presence_key) < int(self.presence_key)) + or (int(other_token.typing_key) < int(self.typing_key)) + ) + + def copy_and_advance(self, key, new_value): + """Advance the given key in the token to a new value if and only if the + new value is after the old value. + """ + new_token = self.copy_and_replace(key, new_value) + if key == "room_key": + new_id = new_token.room_stream_id + old_id = self.room_stream_id + else: + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + if old_id < new_id: + return new_token + else: + return self + def copy_and_replace(self, key, new_value): d = self._asdict() d[key] = new_value return StreamToken(**d) +class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): + """Tokens are positions between events. The token "s1" comes after event 1. + + s0 s1 + | | + [0] V [1] V [2] + + Tokens can either be a point in the live event stream or a cursor going + through historic events. + + When traversing the live event stream events are ordered by when they + arrived at the homeserver. + + When traversing historic events the events are ordered by their depth in + the event graph "topological_ordering" and then by when they arrived at the + homeserver "stream_ordering". + + Live tokens start with an "s" followed by the "stream_ordering" id of the + event it comes after. Historic tokens start with a "t" followed by the + "topological_ordering" id of the event it comes after, follewed by "-", + followed by the "stream_ordering" id of the event it comes after. + """ + __slots__ = [] + + @classmethod + def parse(cls, string): + try: + if string[0] == 's': + return cls(topological=None, stream=int(string[1:])) + if string[0] == 't': + parts = string[1:].split('-', 1) + return cls(topological=int(parts[0]), stream=int(parts[1])) + except: + pass + raise SynapseError(400, "Invalid token %r" % (string,)) + + @classmethod + def parse_stream_token(cls, string): + try: + if string[0] == 's': + return cls(topological=None, stream=int(string[1:])) + except: + pass + raise SynapseError(400, "Invalid token %r" % (string,)) + + def __str__(self): + if self.topological is not None: + return "t%d-%d" % (self.topological, self.stream) + else: + return "s%d" % (self.stream,) + + ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 79109d0b19..260714ccc2 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -23,6 +23,40 @@ import logging logger = logging.getLogger(__name__) +def unwrapFirstError(failure): + # defer.gatherResults and DeferredLists wrap failures. + failure.trap(defer.FirstError) + return failure.value.subFailure + + +def unwrap_deferred(d): + """Given a deferred that we know has completed, return its value or raise + the failure as an exception + """ + if not d.called: + raise RuntimeError("deferred has not finished") + + res = [] + + def f(r): + res.append(r) + return r + d.addCallback(f) + + if res: + return res[0] + + def f(r): + res.append(r) + return r + d.addErrback(f) + + if res: + res[0].raiseException() + else: + raise RuntimeError("deferred did not call callbacks") + + class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. @@ -46,13 +80,16 @@ class Clock(object): def stop_looping_call(self, loop): loop.stop() - def call_later(self, delay, callback): + def call_later(self, delay, callback, *args, **kwargs): current_context = LoggingContext.current_context() - def wrapped_callback(): - LoggingContext.thread_local.current_context = current_context - callback() - return reactor.callLater(delay, wrapped_callback) + def wrapped_callback(*args, **kwargs): + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = current_context + callback(*args, **kwargs) + + with PreserveLoggingContext(): + return reactor.callLater(delay, wrapped_callback, *args, **kwargs) def cancel_call_later(self, timer): timer.cancel() diff --git a/synapse/util/async.py b/synapse/util/async.py index d8febdb90c..1c2044e5b4 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,15 +16,13 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext +from .logcontext import preserve_context_over_deferred -@defer.inlineCallbacks def sleep(seconds): d = defer.Deferred() reactor.callLater(seconds, d.callback, seconds) - with PreserveLoggingContext(): - yield d + return preserve_context_over_deferred(d) def run_on_reactor(): @@ -34,20 +32,56 @@ def run_on_reactor(): return sleep(0) -def create_observer(deferred): - """Creates a deferred that observes the result or failure of the given - deferred *without* affecting the given deferred. +class ObservableDeferred(object): + """Wraps a deferred object so that we can add observer deferreds. These + observer deferreds do not affect the callback chain of the original + deferred. + + If consumeErrors is true errors will be captured from the origin deferred. """ - d = defer.Deferred() - def callback(r): - d.callback(r) - return r + __slots__ = ["_deferred", "_observers", "_result"] + + def __init__(self, deferred, consumeErrors=False): + object.__setattr__(self, "_deferred", deferred) + object.__setattr__(self, "_result", None) + object.__setattr__(self, "_observers", []) + + def callback(r): + self._result = (True, r) + while self._observers: + try: + self._observers.pop().callback(r) + except: + pass + return r + + def errback(f): + self._result = (False, f) + while self._observers: + try: + self._observers.pop().errback(f) + except: + pass + + if consumeErrors: + return None + else: + return f + + deferred.addCallbacks(callback, errback) - def errback(f): - d.errback(f) - return f + def observe(self): + if not self._result: + d = defer.Deferred() + self._observers.append(d) + return d + else: + success, res = self._result + return defer.succeed(res) if success else defer.fail(res) - deferred.addCallbacks(callback, errback) + def __getattr__(self, name): + return getattr(self._deferred, name) - return d + def __setattr__(self, name, value): + setattr(self._deferred, name, value) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 9d9c350397..064c4a7a1e 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -13,10 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.logcontext import PreserveLoggingContext - from twisted.internet import defer +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_deferred, +) + +from synapse.util import unwrapFirstError + import logging @@ -93,7 +97,6 @@ class Signal(object): Each observer callable may return a Deferred.""" self.observers.append(observer) - @defer.inlineCallbacks def fire(self, *args, **kwargs): """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. It is @@ -101,24 +104,28 @@ class Signal(object): Returns a Deferred that will complete when all the observers have completed.""" + + def do(observer): + def eb(failure): + logger.warning( + "%s signal observer %s failed: %r", + self.name, observer, failure, + exc_info=( + failure.type, + failure.value, + failure.getTracebackObject())) + if not self.suppress_failures: + return failure + return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) + with PreserveLoggingContext(): - deferreds = [] - for observer in self.observers: - d = defer.maybeDeferred(observer, *args, **kwargs) - - def eb(failure): - logger.warning( - "%s signal observer %s failed: %r", - self.name, observer, failure, - exc_info=( - failure.type, - failure.value, - failure.getTracebackObject())) - if not self.suppress_failures: - failure.raiseException() - deferreds.append(d.addErrback(eb)) - results = [] - for deferred in deferreds: - result = yield deferred - results.append(result) - defer.returnValue(results) + deferreds = [ + do(observer) + for observer in self.observers + ] + + d = defer.gatherResults(deferreds, consumeErrors=True) + + d.addErrback(unwrapFirstError) + + return preserve_context_over_deferred(d) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index da7872e95d..a92d518b43 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + import threading import logging @@ -129,3 +131,53 @@ class PreserveLoggingContext(object): def __exit__(self, type, value, traceback): """Restores the current logging context""" LoggingContext.thread_local.current_context = self.current_context + + if self.current_context is not LoggingContext.sentinel: + if self.current_context.parent_context is None: + logger.warn( + "Restoring dead context: %s", + self.current_context, + ) + + +def preserve_context_over_fn(fn, *args, **kwargs): + """Takes a function and invokes it with the given arguments, but removes + and restores the current logging context while doing so. + + If the result is a deferred, call preserve_context_over_deferred before + returning it. + """ + with PreserveLoggingContext(): + res = fn(*args, **kwargs) + + if isinstance(res, defer.Deferred): + return preserve_context_over_deferred(res) + else: + return res + + +def preserve_context_over_deferred(deferred): + """Given a deferred wrap it such that any callbacks added later to it will + be invoked with the current context. + """ + d = defer.Deferred() + + current_context = LoggingContext.current_context() + + def cb(res): + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = current_context + res = d.callback(res) + return res + + def eb(failure): + with PreserveLoggingContext(): + LoggingContext.thread_local.current_context = current_context + res = d.errback(failure) + return res + + if deferred.called: + return deferred + + deferred.addCallbacks(cb, eb) + return d |