-rw-r--r--  synapse/events/snapshot.py              |   1
-rw-r--r--  synapse/federation/replication.py       | 291
-rw-r--r--  synapse/federation/transaction_queue.py | 314
-rw-r--r--  synapse/state.py                        | 127
-rw-r--r--  synapse/storage/__init__.py             |  40
-rw-r--r--  synapse/storage/_base.py                |  21
-rw-r--r--  synapse/storage/rejections.py           |  33
-rw-r--r--  synapse/storage/schema/delta/v12.sql    |  21
-rw-r--r--  synapse/storage/schema/im.sql           |   7
-rw-r--r--  tests/test_state.py                     | 428
10 files changed, 866 insertions, 417 deletions
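
The core behavioural change in this commit is the rewritten state conflict resolution in synapse/state.py below: conflicting power-level events are resolved first, then memberships, then all other state keys, with the candidates for each key ordered by depth (deepest first) and then by the SHA-1 of the event ID. The following is a minimal standalone sketch of that ordering and pass structure, not the Synapse implementation: it uses a hypothetical Event tuple instead of FrozenEvent and omits the auth checks that the real _resolve_state_events applies to each candidate.

import hashlib
from collections import namedtuple

# Hypothetical stand-in for a state event; only the fields the ordering needs.
Event = namedtuple("Event", ["event_id", "type", "state_key", "depth"])


def ordered_events(events):
    # Deepest event first; ties broken by the SHA-1 hex digest of the event ID
    # (mirrors StateHandler._ordered_events in the diff below).
    def key_func(e):
        return -int(e.depth), hashlib.sha1(e.event_id.encode("ascii")).hexdigest()
    return sorted(events, key=key_func)


def resolve_conflicts(conflicted):
    """conflicted maps (type, state_key) -> list of candidate events.

    Resolve power levels first, then memberships, then everything else.
    The real _resolve_state_events also auth-checks each candidate against
    the accumulated auth events; this sketch simply takes the first
    candidate in the resolution order.
    """
    resolved = {}

    power_key = ("m.room.power_levels", "")
    if power_key in conflicted:
        resolved[power_key] = ordered_events(conflicted[power_key])[0]

    for key, candidates in conflicted.items():
        if key[0] == "m.room.member" and key not in resolved:
            resolved[key] = ordered_events(candidates)[0]

    for key, candidates in conflicted.items():
        if key not in resolved:
            resolved[key] = ordered_events(candidates)[0]

    return resolved
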
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6bbba8d6ba..7e98bdef28 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -20,3 +20,4 @@ class EventContext(object): self.current_state = current_state self.auth_events = auth_events self.state_group = None + self.rejected = False diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 6620532a60..accf95e406 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -20,8 +20,8 @@ a given transport. from twisted.internet import defer from .units import Transaction, Edu - from .persistence import TransactionActions +from .transaction_queue import TransactionQueue from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext @@ -62,7 +62,7 @@ class ReplicationLayer(object): # self.pdu_actions = PduActions(self.store) self.transaction_actions = TransactionActions(self.store) - self._transaction_queue = _TransactionQueue( + self._transaction_queue = TransactionQueue( hs, self.transaction_actions, transport_layer ) @@ -662,290 +662,3 @@ class ReplicationLayer(object): event.internal_metadata.outlier = outlier return event - - -class _TransactionQueue(object): - """This class makes sure we only have one transaction in flight at - a time for a given destination. - - It batches pending PDUs into single transactions. - """ - - def __init__(self, hs, transaction_actions, transport_layer): - self.server_name = hs.hostname - self.transaction_actions = transaction_actions - self.transport_layer = transport_layer - - self._clock = hs.get_clock() - self.store = hs.get_datastore() - - # Is a mapping from destinations -> deferreds. Used to keep track - # of which destinations have transactions in flight and when they are - # done - self.pending_transactions = {} - - # Is a mapping from destination -> list of - # tuple(pending pdus, deferred, order) - self.pending_pdus_by_dest = {} - # destination -> list of tuple(edu, deferred) - self.pending_edus_by_dest = {} - - # destination -> list of tuple(failure, deferred) - self.pending_failures_by_dest = {} - - # HACK to get unique tx id - self._next_txn_id = int(self._clock.time_msec()) - - @defer.inlineCallbacks - @log_function - def enqueue_pdu(self, pdu, destinations, order): - # We loop through all destinations to see whether we already have - # a transaction in progress. If we do, stick it in the pending_pdus - # table and we'll get back to it later. 
- - destinations = set(destinations) - destinations.discard(self.server_name) - destinations.discard("localhost") - - logger.debug("Sending to: %s", str(destinations)) - - if not destinations: - return - - deferreds = [] - - for destination in destinations: - deferred = defer.Deferred() - self.pending_pdus_by_dest.setdefault(destination, []).append( - (pdu, deferred, order) - ) - - def eb(failure): - if not deferred.called: - deferred.errback(failure) - else: - logger.warn("Failed to send pdu", failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) - - deferreds.append(deferred) - - yield defer.DeferredList(deferreds) - - # NO inlineCallbacks - def enqueue_edu(self, edu): - destination = edu.destination - - if destination == self.server_name: - return - - deferred = defer.Deferred() - self.pending_edus_by_dest.setdefault(destination, []).append( - (edu, deferred) - ) - - def eb(failure): - if not deferred.called: - deferred.errback(failure) - else: - logger.warn("Failed to send edu", failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) - - return deferred - - @defer.inlineCallbacks - def enqueue_failure(self, failure, destination): - deferred = defer.Deferred() - - self.pending_failures_by_dest.setdefault( - destination, [] - ).append( - (failure, deferred) - ) - - yield deferred - - @defer.inlineCallbacks - @log_function - def _attempt_new_transaction(self, destination): - - (retry_last_ts, retry_interval) = (0, 0) - retry_timings = yield self.store.get_destination_retry_timings( - destination - ) - if retry_timings: - (retry_last_ts, retry_interval) = ( - retry_timings.retry_last_ts, retry_timings.retry_interval - ) - if retry_last_ts + retry_interval > int(self._clock.time_msec()): - logger.info( - "TX [%s] not ready for retry yet - " - "dropping transaction for now", - destination, - ) - return - else: - logger.info("TX [%s] is ready for retry", destination) - - logger.info("TX [%s] _attempt_new_transaction", destination) - - if destination in self.pending_transactions: - # XXX: pending_transactions can get stuck on by a never-ending - # request at which point pending_pdus_by_dest just keeps growing. 
- # we need application-layer timeouts of some flavour of these - # requests - return - - # list of (pending_pdu, deferred, order) - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_failures = self.pending_failures_by_dest.pop(destination, []) - - if pending_pdus: - logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) - - if not pending_pdus and not pending_edus and not pending_failures: - return - - logger.debug( - "TX [%s] Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", - destination, - len(pending_pdus), - len(pending_edus), - len(pending_failures) - ) - - # Sort based on the order field - pending_pdus.sort(key=lambda t: t[2]) - - pdus = [x[0] for x in pending_pdus] - edus = [x[0] for x in pending_edus] - failures = [x[0].get_dict() for x in pending_failures] - deferreds = [ - x[1] - for x in pending_pdus + pending_edus + pending_failures - ] - - try: - self.pending_transactions[destination] = 1 - - logger.debug("TX [%s] Persisting transaction...", destination) - - transaction = Transaction.create_new( - origin_server_ts=int(self._clock.time_msec()), - transaction_id=str(self._next_txn_id), - origin=self.server_name, - destination=destination, - pdus=pdus, - edus=edus, - pdu_failures=failures, - ) - - self._next_txn_id += 1 - - yield self.transaction_actions.prepare_to_send(transaction) - - logger.debug("TX [%s] Persisted transaction", destination) - logger.info( - "TX [%s] Sending transaction [%s]", - destination, - transaction.transaction_id, - ) - - # Actually send the transaction - - # FIXME (erikj): This is a bit of a hack to make the Pdu age - # keys work - def json_data_cb(): - data = transaction.get_dict() - now = int(self._clock.time_msec()) - if "pdus" in data: - for p in data["pdus"]: - if "age_ts" in p: - unsigned = p.setdefault("unsigned", {}) - unsigned["age"] = now - int(p["age_ts"]) - del p["age_ts"] - return data - - code, response = yield self.transport_layer.send_transaction( - transaction, json_data_cb - ) - - logger.info("TX [%s] got %d response", destination, code) - - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) - - yield self.transaction_actions.delivered( - transaction, code, response - ) - - logger.debug("TX [%s] Marked as delivered", destination) - logger.debug("TX [%s] Yielding to callbacks...", destination) - - for deferred in deferreds: - if code == 200: - if retry_last_ts: - # this host is alive! reset retry schedule - yield self.store.set_destination_retry_timings( - destination, 0, 0 - ) - deferred.callback(None) - else: - self.set_retrying(destination, retry_interval) - deferred.errback(RuntimeError("Got status %d" % code)) - - # Ensures we don't continue until all callbacks on that - # deferred have fired - try: - yield deferred - except: - pass - - logger.debug("TX [%s] Yielded to callbacks", destination) - - except Exception as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) - - self.set_retrying(destination, retry_interval) - - for deferred in deferreds: - if not deferred.called: - deferred.errback(e) - - finally: - # We want to be *very* sure we delete this after we stop processing - self.pending_transactions.pop(destination, None) - - # Check to see if there is anything else to send. 
- self._attempt_new_transaction(destination) - - @defer.inlineCallbacks - def set_retrying(self, destination, retry_interval): - # track that this destination is having problems and we should - # give it a chance to recover before trying it again - - if retry_interval: - retry_interval *= 2 - # plateau at hourly retries for now - if retry_interval >= 60 * 60 * 1000: - retry_interval = 60 * 60 * 1000 - else: - retry_interval = 2000 # try again at first after 2 seconds - - yield self.store.set_destination_retry_timings( - destination, - int(self._clock.time_msec()), - retry_interval - ) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py new file mode 100644 index 0000000000..c2cb4a1c49 --- /dev/null +++ b/synapse/federation/transaction_queue.py @@ -0,0 +1,314 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .units import Transaction + +from synapse.util.logutils import log_function +from synapse.util.logcontext import PreserveLoggingContext + +import logging + + +logger = logging.getLogger(__name__) + + +class TransactionQueue(object): + """This class makes sure we only have one transaction in flight at + a time for a given destination. + + It batches pending PDUs into single transactions. + """ + + def __init__(self, hs, transaction_actions, transport_layer): + self.server_name = hs.hostname + self.transaction_actions = transaction_actions + self.transport_layer = transport_layer + + self._clock = hs.get_clock() + self.store = hs.get_datastore() + + # Is a mapping from destinations -> deferreds. Used to keep track + # of which destinations have transactions in flight and when they are + # done + self.pending_transactions = {} + + # Is a mapping from destination -> list of + # tuple(pending pdus, deferred, order) + self.pending_pdus_by_dest = {} + # destination -> list of tuple(edu, deferred) + self.pending_edus_by_dest = {} + + # destination -> list of tuple(failure, deferred) + self.pending_failures_by_dest = {} + + # HACK to get unique tx id + self._next_txn_id = int(self._clock.time_msec()) + + @defer.inlineCallbacks + @log_function + def enqueue_pdu(self, pdu, destinations, order): + # We loop through all destinations to see whether we already have + # a transaction in progress. If we do, stick it in the pending_pdus + # table and we'll get back to it later. 
+ + destinations = set(destinations) + destinations.discard(self.server_name) + destinations.discard("localhost") + + logger.debug("Sending to: %s", str(destinations)) + + if not destinations: + return + + deferreds = [] + + for destination in destinations: + deferred = defer.Deferred() + self.pending_pdus_by_dest.setdefault(destination, []).append( + (pdu, deferred, order) + ) + + def eb(failure): + if not deferred.called: + deferred.errback(failure) + else: + logger.warn("Failed to send pdu", failure) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(eb) + + deferreds.append(deferred) + + yield defer.DeferredList(deferreds) + + # NO inlineCallbacks + def enqueue_edu(self, edu): + destination = edu.destination + + if destination == self.server_name: + return + + deferred = defer.Deferred() + self.pending_edus_by_dest.setdefault(destination, []).append( + (edu, deferred) + ) + + def eb(failure): + if not deferred.called: + deferred.errback(failure) + else: + logger.warn("Failed to send edu", failure) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(eb) + + return deferred + + @defer.inlineCallbacks + def enqueue_failure(self, failure, destination): + deferred = defer.Deferred() + + self.pending_failures_by_dest.setdefault( + destination, [] + ).append( + (failure, deferred) + ) + + yield deferred + + @defer.inlineCallbacks + @log_function + def _attempt_new_transaction(self, destination): + + (retry_last_ts, retry_interval) = (0, 0) + retry_timings = yield self.store.get_destination_retry_timings( + destination + ) + if retry_timings: + (retry_last_ts, retry_interval) = ( + retry_timings.retry_last_ts, retry_timings.retry_interval + ) + if retry_last_ts + retry_interval > int(self._clock.time_msec()): + logger.info( + "TX [%s] not ready for retry yet - " + "dropping transaction for now", + destination, + ) + return + else: + logger.info("TX [%s] is ready for retry", destination) + + logger.info("TX [%s] _attempt_new_transaction", destination) + + if destination in self.pending_transactions: + # XXX: pending_transactions can get stuck on by a never-ending + # request at which point pending_pdus_by_dest just keeps growing. 
+ # we need application-layer timeouts of some flavour of these + # requests + return + + # list of (pending_pdu, deferred, order) + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_failures = self.pending_failures_by_dest.pop(destination, []) + + if pending_pdus: + logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) + + if not pending_pdus and not pending_edus and not pending_failures: + return + + logger.debug( + "TX [%s] Attempting new transaction" + " (pdus: %d, edus: %d, failures: %d)", + destination, + len(pending_pdus), + len(pending_edus), + len(pending_failures) + ) + + # Sort based on the order field + pending_pdus.sort(key=lambda t: t[2]) + + pdus = [x[0] for x in pending_pdus] + edus = [x[0] for x in pending_edus] + failures = [x[0].get_dict() for x in pending_failures] + deferreds = [ + x[1] + for x in pending_pdus + pending_edus + pending_failures + ] + + try: + self.pending_transactions[destination] = 1 + + logger.debug("TX [%s] Persisting transaction...", destination) + + transaction = Transaction.create_new( + origin_server_ts=int(self._clock.time_msec()), + transaction_id=str(self._next_txn_id), + origin=self.server_name, + destination=destination, + pdus=pdus, + edus=edus, + pdu_failures=failures, + ) + + self._next_txn_id += 1 + + yield self.transaction_actions.prepare_to_send(transaction) + + logger.debug("TX [%s] Persisted transaction", destination) + logger.info( + "TX [%s] Sending transaction [%s]", + destination, + transaction.transaction_id, + ) + + # Actually send the transaction + + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def json_data_cb(): + data = transaction.get_dict() + now = int(self._clock.time_msec()) + if "pdus" in data: + for p in data["pdus"]: + if "age_ts" in p: + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] + return data + + code, response = yield self.transport_layer.send_transaction( + transaction, json_data_cb + ) + + logger.info("TX [%s] got %d response", destination, code) + + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) + + yield self.transaction_actions.delivered( + transaction, code, response + ) + + logger.debug("TX [%s] Marked as delivered", destination) + logger.debug("TX [%s] Yielding to callbacks...", destination) + + for deferred in deferreds: + if code == 200: + if retry_last_ts: + # this host is alive! reset retry schedule + yield self.store.set_destination_retry_timings( + destination, 0, 0 + ) + deferred.callback(None) + else: + self.set_retrying(destination, retry_interval) + deferred.errback(RuntimeError("Got status %d" % code)) + + # Ensures we don't continue until all callbacks on that + # deferred have fired + try: + yield deferred + except: + pass + + logger.debug("TX [%s] Yielded to callbacks", destination) + + except Exception as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. + logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) + + self.set_retrying(destination, retry_interval) + + for deferred in deferreds: + if not deferred.called: + deferred.errback(e) + + finally: + # We want to be *very* sure we delete this after we stop processing + self.pending_transactions.pop(destination, None) + + # Check to see if there is anything else to send. 
+ self._attempt_new_transaction(destination) + + @defer.inlineCallbacks + def set_retrying(self, destination, retry_interval): + # track that this destination is having problems and we should + # give it a chance to recover before trying it again + + if retry_interval: + retry_interval *= 2 + # plateau at hourly retries for now + if retry_interval >= 60 * 60 * 1000: + retry_interval = 60 * 60 * 1000 + else: + retry_interval = 2000 # try again at first after 2 seconds + + yield self.store.set_destination_retry_timings( + destination, + int(self._clock.time_msec()), + retry_interval + ) diff --git a/synapse/state.py b/synapse/state.py index 8144fa02b4..d9fdfb34be 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext from collections import namedtuple @@ -42,6 +43,8 @@ class StateHandler(object): def __init__(self, hs): self.store = hs.get_datastore() + # self.auth = hs.get_auth() + self.hs = hs @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): @@ -210,64 +213,96 @@ class StateHandler(object): else: prev_states = [] + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) + } + try: - new_state = {} - new_state.update(unconflicted_state) - for key, events in conflicted_state.items(): - new_state[key] = self._resolve_state_events(events) + resolved_state = self._resolve_state_events( + conflicted_state, auth_events + ) except: logger.exception("Failed to resolve state") raise + new_state = unconflicted_state + new_state.update(resolved_state) + defer.returnValue((None, new_state, prev_states)) - def _get_power_level_from_event_state(self, event, user_id): - if hasattr(event, "old_state_events") and event.old_state_events: - key = (EventTypes.PowerLevels, "", ) - power_level_event = event.old_state_events.get(key) - level = None - if power_level_event: - level = power_level_event.content.get("users", {}).get( - user_id + @log_function + def _resolve_state_events(self, conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. memberships + 3. other events. 
+ + :param conflicted_state: + :param auth_events: + :return: + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state.items(): + power_levels = conflicted_state[power_key] + resolved_state[power_key] = self._resolve_auth_events(power_levels) + + auth_events.update(resolved_state) + + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + resolved_state[key] = self._resolve_auth_events( + events, + auth_events ) - if not level: - level = power_level_event.content.get("users_default", 0) - return level - else: - return 0 + auth_events.update(resolved_state) - @log_function - def _resolve_state_events(self, events): - curr_events = events + for key, events in conflicted_state.items(): + if key not in resolved_state: + resolved_state[key] = self._resolve_normal_events( + events, auth_events + ) - new_powers = [ - self._get_power_level_from_event_state(e, e.user_id) - for e in curr_events - ] + return resolved_state - new_powers = [ - int(p) if p else 0 for p in new_powers - ] + def _resolve_auth_events(self, events, auth_events): + reverse = [i for i in reversed(self._ordered_events(events))] - max_power = max(new_powers) + auth_events = dict(auth_events) - curr_events = [ - z[0] for z in zip(curr_events, new_powers) - if z[1] == max_power - ] + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + prev_event = event + except AuthError: + return prev_event - if not curr_events: - raise RuntimeError("Max didn't get a max?") - elif len(curr_events) == 1: - return curr_events[0] - - # TODO: For now, just choose the one with the largest event_id. - return ( - sorted( - curr_events, - key=lambda e: hashlib.sha1( - e.event_id + e.user_id + e.room_id + e.type - ).hexdigest() - )[0] - ) + return event + + def _resolve_normal_events(self, events, auth_events): + for event in self._ordered_events(events): + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + return event + except AuthError: + pass + + # Oh dear. + return event + + def _ordered_events(self, events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + + return sorted(events, key=key_func) \ No newline at end of file diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4beb951b9f..4f09909607 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -30,6 +30,7 @@ from .transactions import TransactionStore from .keys import KeyStore from .event_federation import EventFederationStore from .media_repository import MediaRepositoryStore +from .rejections import RejectionsStore from .state import StateStore from .signatures import SignatureStore @@ -66,7 +67,7 @@ SCHEMAS = [ # Remember to update this number every time an incompatible change is made to # database schema files, so the users will be informed on server restarts. 
-SCHEMA_VERSION = 11 +SCHEMA_VERSION = 12 class _RollbackButIsFineException(Exception): @@ -82,6 +83,7 @@ class DataStore(RoomMemberStore, RoomStore, DirectoryStore, KeyStore, StateStore, SignatureStore, EventFederationStore, MediaRepositoryStore, + RejectionsStore, ): def __init__(self, hs): @@ -224,6 +226,9 @@ class DataStore(RoomMemberStore, RoomStore, if not outlier: self._store_state_groups_txn(txn, event, context) + if context.rejected: + self._store_rejections_txn(txn, event.event_id, context.rejected) + if current_state: txn.execute( "DELETE FROM current_state_events WHERE room_id = ?", @@ -262,7 +267,7 @@ class DataStore(RoomMemberStore, RoomStore, or_replace=True, ) - if is_new_state: + if is_new_state and not context.rejected: self._simple_insert_txn( txn, "current_state_events", @@ -288,7 +293,7 @@ class DataStore(RoomMemberStore, RoomStore, or_ignore=True, ) - if not backfilled: + if not backfilled and not context.rejected: self._simple_insert_txn( txn, table="state_forward_extremities", @@ -417,6 +422,35 @@ class DataStore(RoomMemberStore, RoomStore, ], ) + def have_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Returns: + dict: Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps to + None. + """ + def f(txn): + sql = ( + "SELECT e.event_id, reason FROM events as e " + "LEFT JOIN rejections as r ON e.event_id = r.event_id " + "WHERE event_id = ?" + ) + + res = {} + for event_id in event_ids: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if row: + _, rejected = row + res[event_id] = rejected + + return res + + return self.runInteraction( + "have_events", f, + ) + def schema_path(schema): """ Get a filesystem path for the named database schema diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f660fc6eaf..2075a018b2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -458,10 +458,12 @@ class SQLBaseStore(object): return [e for e in events if e] def _get_event_txn(self, txn, event_id, check_redacted=True, - get_prev_content=False): + get_prev_content=False, allow_rejected=False): sql = ( - "SELECT internal_metadata, json, r.event_id FROM event_json as e " + "SELECT internal_metadata, json, r.event_id, reason " + "FROM event_json as e " "LEFT JOIN redactions as r ON e.event_id = r.redacts " + "LEFT JOIN rejections as rej on rej.event_id = e.event_id " "WHERE e.event_id = ? 
" "LIMIT 1 " ) @@ -473,13 +475,16 @@ class SQLBaseStore(object): if not res: return None - internal_metadata, js, redacted = res + internal_metadata, js, redacted, rejected_reason = res - return self._get_event_from_row_txn( - txn, internal_metadata, js, redacted, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - ) + if allow_rejected or not rejected_reason: + return self._get_event_from_row_txn( + txn, internal_metadata, js, redacted, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + ) + else: + return None def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, check_redacted=True, get_prev_content=False): diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py new file mode 100644 index 0000000000..7d38b31f44 --- /dev/null +++ b/synapse/storage/rejections.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore + +import logging + +logger = logging.getLogger(__name__) + + +class RejectionsStore(SQLBaseStore): + def _store_rejections_txn(self, txn, event_id, reason): + self._simple_insert_txn( + txn, + table="rejections", + values={ + "event_id": event_id, + "reason": reason, + "last_failure": self._clock.time_msec(), + } + ) diff --git a/synapse/storage/schema/delta/v12.sql b/synapse/storage/schema/delta/v12.sql new file mode 100644 index 0000000000..bd2a8b1bb5 --- /dev/null +++ b/synapse/storage/schema/delta/v12.sql @@ -0,0 +1,21 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE +); diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql index dd00c1cd2f..bc7c6b6ed5 100644 --- a/synapse/storage/schema/im.sql +++ b/synapse/storage/schema/im.sql @@ -123,3 +123,10 @@ CREATE TABLE IF NOT EXISTS room_hosts( ); CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id); + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE +); diff --git a/tests/test_state.py b/tests/test_state.py index 98ad9e54cd..019e794aa2 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -16,11 +16,120 @@ from tests import unittest from twisted.internet import defer +from synapse.events import FrozenEvent +from synapse.api.auth import Auth +from synapse.api.constants import EventTypes, Membership from synapse.state import StateHandler from mock import Mock +_next_event_id = 1000 + + +def create_event(name=None, type=None, state_key=None, depth=2, event_id=None, + prev_events=[], **kwargs): + global _next_event_id + + if not event_id: + _next_event_id += 1 + event_id = str(_next_event_id) + + if not name: + if state_key is not None: + name = "<%s-%s, %s>" % (type, state_key, event_id,) + else: + name = "<%s, %s>" % (type, event_id,) + + d = { + "event_id": event_id, + "type": type, + "sender": "@user_id:example.com", + "room_id": "!room_id:example.com", + "depth": depth, + "prev_events": prev_events, + } + + if state_key is not None: + d["state_key"] = state_key + + d.update(kwargs) + + event = FrozenEvent(d) + + return event + + +class StateGroupStore(object): + def __init__(self): + self._event_to_state_group = {} + self._group_to_state = {} + + self._next_group = 1 + + def get_state_groups(self, event_ids): + groups = {} + for event_id in event_ids: + group = self._event_to_state_group.get(event_id) + if group: + groups[group] = self._group_to_state[group] + + return defer.succeed(groups) + + def store_state_groups(self, event, context): + if context.current_state is None: + return + + state_events = context.current_state + + if event.is_state(): + state_events[(event.type, event.state_key)] = event + + state_group = context.state_group + if not state_group: + state_group = self._next_group + self._next_group += 1 + + self._group_to_state[state_group] = state_events.values() + + self._event_to_state_group[event.event_id] = state_group + + +class DictObj(dict): + def __init__(self, **kwargs): + super(DictObj, self).__init__(kwargs) + self.__dict__ = self + + +class Graph(object): + def __init__(self, nodes, edges): + events = {} + clobbered = set(events.keys()) + + for event_id, fields in nodes.items(): + refs = edges.get(event_id) + if refs: + clobbered.difference_update(refs) + prev_events = [(r, {}) for r in refs] + else: + prev_events = [] + + events[event_id] = create_event( + event_id=event_id, + prev_events=prev_events, + **fields + ) + + self._leaves = clobbered + self._events = sorted(events.values(), key=lambda e: e.depth) + + def walk(self): + return iter(self._events) + + def get_leaves(self): + return (self._events[i] for i in self._leaves) + + class StateTestCase(unittest.TestCase): def setUp(self): self.store = Mock( @@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase): "add_event_hashes", ] ) - hs = Mock(spec=["get_datastore"]) + hs = 
Mock(spec=["get_datastore", "get_auth", "get_state_handler"]) hs.get_datastore.return_value = self.store + hs.get_state_handler.return_value = None + hs.get_auth.return_value = Auth(hs) self.state = StateHandler(hs) self.event_id = 0 @defer.inlineCallbacks + def test_branch_no_conflict(self): + graph = Graph( + nodes={ + "START": DictObj( + type=EventTypes.Create, + state_key="", + depth=1, + ), + "A": DictObj( + type=EventTypes.Message, + depth=2, + ), + "B": DictObj( + type=EventTypes.Message, + depth=3, + ), + "C": DictObj( + type=EventTypes.Name, + state_key="", + depth=3, + ), + "D": DictObj( + type=EventTypes.Message, + depth=4, + ), + }, + edges={ + "A": ["START"], + "B": ["A"], + "C": ["A"], + "D": ["B", "C"] + } + ) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertEqual(2, len(context_store["D"].current_state)) + + @defer.inlineCallbacks + def test_branch_basic_conflict(self): + graph = Graph( + nodes={ + "START": DictObj( + type=EventTypes.Create, + state_key="creator", + content={"membership": "@user_id:example.com"}, + depth=1, + ), + "A": DictObj( + type=EventTypes.Member, + state_key="@user_id:example.com", + content={"membership": Membership.JOIN}, + membership=Membership.JOIN, + depth=2, + ), + "B": DictObj( + type=EventTypes.Name, + state_key="", + depth=3, + ), + "C": DictObj( + type=EventTypes.Name, + state_key="", + depth=4, + ), + "D": DictObj( + type=EventTypes.Message, + depth=5, + ), + }, + edges={ + "A": ["START"], + "B": ["A"], + "C": ["A"], + "D": ["B", "C"] + } + ) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertSetEqual( + {"START", "A", "C"}, + {e.event_id for e in context_store["D"].current_state.values()} + ) + + @defer.inlineCallbacks + def test_branch_have_banned_conflict(self): + graph = Graph( + nodes={ + "START": DictObj( + type=EventTypes.Create, + state_key="creator", + content={"membership": "@user_id:example.com"}, + depth=1, + ), + "A": DictObj( + type=EventTypes.Member, + state_key="@user_id:example.com", + content={"membership": Membership.JOIN}, + membership=Membership.JOIN, + depth=2, + ), + "B": DictObj( + type=EventTypes.Name, + state_key="", + depth=3, + ), + "C": DictObj( + type=EventTypes.Member, + state_key="@user_id_2:example.com", + content={"membership": Membership.BAN}, + membership=Membership.BAN, + depth=4, + ), + "D": DictObj( + type=EventTypes.Name, + state_key="", + depth=4, + sender="@user_id_2:example.com", + ), + "E": DictObj( + type=EventTypes.Message, + depth=5, + ), + }, + edges={ + "A": ["START"], + "B": ["A"], + "C": ["B"], + "D": ["B"], + "E": ["C", "D"] + } + ) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertSetEqual( + {"START", "A", "B", "C"}, + {e.event_id for e in context_store["E"].current_state.values()} + ) + + @defer.inlineCallbacks def 
test_annotate_with_old_message(self): - event = self.create_event(type="test_message", name="event") + event = create_event(type="test_message", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] context = yield self.state.compute_event_context( @@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_annotate_with_old_state(self): - event = self.create_event(type="state", state_key="", name="event") + event = create_event(type="state", state_key="", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] context = yield self.state.compute_event_context( @@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_trivial_annotate_message(self): - event = self.create_event(type="test_message", name="event") - event.prev_events = [] + event = create_event(type="test_message", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] group_name = "group_name_1" @@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_trivial_annotate_state(self): - event = self.create_event(type="state", state_key="", name="event") - event.prev_events = [] + event = create_event(type="state", state_key="", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] group_name = "group_name_1" @@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve_message_conflict(self): - event = self.create_event(type="test_message", name="event") - event.prev_events = [] + event = create_event(type="test_message", name="event") old_state_1 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] old_state_2 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test3", state_key="2"), - self.create_event(type="test4", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test3", state_key="2"), + create_event(type="test4", state_key=""), ] - group_name_1 = "group_name_1" - group_name_2 = "group_name_2" - - self.store.get_state_groups.return_value = { - group_name_1: old_state_1, - group_name_2: old_state_2, - } - - context = yield self.state.compute_event_context(event) + context = yield self._get_context(event, old_state_1, old_state_2) self.assertEqual(len(context.current_state), 5) @@ 
-181,56 +447,76 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve_state_conflict(self): - event = self.create_event(type="test4", state_key="", name="event") - event.prev_events = [] + event = create_event(type="test4", state_key="", name="event") old_state_1 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] old_state_2 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test3", state_key="2"), - self.create_event(type="test4", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test3", state_key="2"), + create_event(type="test4", state_key=""), ] - group_name_1 = "group_name_1" - group_name_2 = "group_name_2" - - self.store.get_state_groups.return_value = { - group_name_1: old_state_1, - group_name_2: old_state_2, - } - - context = yield self.state.compute_event_context(event) + context = yield self._get_context(event, old_state_1, old_state_2) self.assertEqual(len(context.current_state), 5) self.assertIsNone(context.state_group) - def create_event(self, name=None, type=None, state_key=None): - self.event_id += 1 - event_id = str(self.event_id) + @defer.inlineCallbacks + def test_standard_depth_conflict(self): + event = create_event(type="test4", name="event") + + member_event = create_event( + type=EventTypes.Member, + state_key="@user_id:example.com", + content={ + "membership": Membership.JOIN, + } + ) - if not name: - if state_key is not None: - name = "<%s-%s>" % (type, state_key) - else: - name = "<%s>" % (type, ) + old_state_1 = [ + member_event, + create_event(type="test1", state_key="1", depth=1), + ] + + old_state_2 = [ + member_event, + create_event(type="test1", state_key="1", depth=2), + ] - event = Mock(name=name, spec=[]) - event.type = type + context = yield self._get_context(event, old_state_1, old_state_2) - if state_key is not None: - event.state_key = state_key - event.event_id = event_id + self.assertEqual(old_state_2[1], context.current_state[("test1", "1")]) + + # Reverse the depth to make sure we are actually using the depths + # during state resolution. + + old_state_1 = [ + member_event, + create_event(type="test1", state_key="1", depth=2), + ] + + old_state_2 = [ + member_event, + create_event(type="test1", state_key="1", depth=1), + ] + + context = yield self._get_context(event, old_state_1, old_state_2) + + self.assertEqual(old_state_1[1], context.current_state[("test1", "1")]) - event.is_state = lambda: (state_key is not None) - event.unsigned = {} + def _get_context(self, event, old_state_1, old_state_2): + group_name_1 = "group_name_1" + group_name_2 = "group_name_2" - event.user_id = "@user_id:example.com" - event.room_id = "!room_id:example.com" + self.store.get_state_groups.return_value = { + group_name_1: old_state_1, + group_name_2: old_state_2, + } - return event + return self.state.compute_event_context(event) |
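
For reference, the per-destination retry schedule implemented by TransactionQueue.set_retrying and checked at the top of _attempt_new_transaction can be summarised in isolation. This is a sketch of the arithmetic only, not the Synapse API: the first failure schedules a retry after 2 seconds, each further failure doubles the interval, and the interval plateaus at one hour.

MIN_RETRY_INTERVAL_MS = 2 * 1000        # first retry 2 seconds after a failure
MAX_RETRY_INTERVAL_MS = 60 * 60 * 1000  # plateau at hourly retries


def next_retry_interval(current_interval_ms):
    """Next back-off interval for a destination that just failed."""
    if not current_interval_ms:
        return MIN_RETRY_INTERVAL_MS
    return min(current_interval_ms * 2, MAX_RETRY_INTERVAL_MS)


def ready_for_retry(retry_last_ts_ms, retry_interval_ms, now_ms):
    """True once the back-off window has elapsed, i.e. the inverse of the
    "not ready for retry yet - dropping transaction" branch above."""
    return retry_last_ts_ms + retry_interval_ms <= now_ms


# Example: 0 -> 2s -> 4s -> 8s -> ... -> capped at one hour.
interval = 0
for _ in range(15):
    interval = next_retry_interval(interval)
assert interval == MAX_RETRY_INTERVAL_MS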
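
The new rejections storage is small enough to exercise directly against SQLite. The sketch below is hypothetical usage, not Synapse code: it creates the table from the v12.sql delta and shows the upsert behaviour that the ON CONFLICT REPLACE constraint gives _store_rejections_txn. Note that the delta names the timestamp column last_check while RejectionsStore inserts a last_failure key; the sketch follows the schema.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE IF NOT EXISTS rejections("
    " event_id TEXT NOT NULL,"
    " reason TEXT NOT NULL,"
    " last_check TEXT NOT NULL,"
    " CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE"
    ")"
)


def store_rejection(event_id, reason, now_ms):
    # The UNIQUE(event_id) ON CONFLICT REPLACE constraint turns repeated
    # inserts for the same event into an upsert.
    conn.execute(
        "INSERT INTO rejections (event_id, reason, last_check)"
        " VALUES (?, ?, ?)",
        (event_id, reason, str(now_ms)),
    )


def rejection_reason(event_id):
    # Returns the stored reason, or None if the event was never rejected;
    # DataStore.have_events does the equivalent via a LEFT JOIN on events.
    row = conn.execute(
        "SELECT reason FROM rejections WHERE event_id = ?", (event_id,)
    ).fetchone()
    return row[0] if row else None


store_rejection("$some_event:example.com", "auth_error", 1422000000000)
store_rejection("$some_event:example.com", "auth_error", 1422000060000)
assert rejection_reason("$some_event:example.com") == "auth_error"
assert rejection_reason("$other_event:example.com") is None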