From c0577ea87a19c169d68ed83760582fa1fabe36e5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 8 Sep 2014 18:34:18 +0100 Subject: Rollback if we try and insert duplicate events --- synapse/storage/__init__.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 81c3c94b2e..8ed80109a5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -179,6 +179,7 @@ class DataStore(RoomMemberStore, RoomStore, "Failed to persist, probably duplicate: %s", event.event_id ) + txn.rollback() return if not backfilled and hasattr(event, "state_key"): -- cgit 1.5.1 From 83ce57302dab6a825f3afde11926b5404ce1c9ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 8 Sep 2014 19:50:46 +0100 Subject: Fix bug in state handling where we incorrectly identified a missing pdu. Update tests to catch this case. --- synapse/state.py | 92 +++++++++---------- synapse/storage/pdu.py | 9 +- tests/test_state.py | 233 +++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 267 insertions(+), 67 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/state.py b/synapse/state.py index 5dcff27367..e69282860a 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -134,7 +134,9 @@ class StateHandler(object): @defer.inlineCallbacks @log_function def _handle_new_state(self, new_pdu): - tree = yield self.store.get_unresolved_state_tree(new_pdu) + tree, missing_branch = yield self.store.get_unresolved_state_tree( + new_pdu + ) new_branch, current_branch = tree logger.debug( @@ -142,6 +144,28 @@ class StateHandler(object): new_branch, current_branch ) + if missing_branch is not None: + # We're missing some PDUs. Fetch them. + # TODO (erikj): Limit this. + missing_prev = tree[missing_branch][-1] + + pdu_id = missing_prev.prev_state_id + origin = missing_prev.prev_state_origin + + is_missing = yield self.store.get_pdu(pdu_id, origin) is None + if not is_missing: + raise Exception("Conflict resolution failed") + + yield self._replication.get_pdu( + destination=missing_prev.origin, + pdu_origin=origin, + pdu_id=pdu_id, + outlier=True + ) + + updated_current = yield self._handle_new_state(new_pdu) + defer.returnValue(updated_current) + if not current_branch: # There is no current state defer.returnValue(True) @@ -151,65 +175,35 @@ class StateHandler(object): c = current_branch[-1] if n.pdu_id == c.pdu_id and n.origin == c.origin: - # We have all the PDUs we need, so we can just do the conflict - # resolution. + # We found a common ancestor! if len(current_branch) == 1: # This is a direct clobber so we can just... defer.returnValue(True) - conflict_res = [ - self._do_power_level_conflict_res, - self._do_chain_length_conflict_res, - self._do_hash_conflict_res, - ] - - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) - - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - - raise Exception("Conflict resolution failed.") - else: - # We need to ask for PDUs. - missing_prev = max( - new_branch[-1], current_branch[-1], - key=lambda x: x.depth - ) - - if not hasattr(missing_prev, "prev_state_id"): - # FIXME Hmm - # temporary fallback - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) - - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - return + # We didn't find a common ancestor. This is probably fine. + pass - pdu_id = missing_prev.prev_state_id - origin = missing_prev.prev_state_origin + result = self._do_conflict_res(new_branch, current_branch) + defer.returnValue(result) - is_missing = yield self.store.get_pdu(pdu_id, origin) is None + def _do_conflict_res(self, new_branch, current_branch): + conflict_res = [ + self._do_power_level_conflict_res, + self._do_chain_length_conflict_res, + self._do_hash_conflict_res, + ] - if not is_missing: - raise Exception("Conflict resolution failed.") + for algo in conflict_res: + new_res, curr_res = algo(new_branch, current_branch) - yield self._replication.get_pdu( - destination=missing_prev.origin, - pdu_origin=origin, - pdu_id=pdu_id, - outlier=True - ) + if new_res < curr_res: + defer.returnValue(False) + elif new_res > curr_res: + defer.returnValue(True) - updated_current = yield self._handle_new_state(new_pdu) - defer.returnValue(updated_current) + raise Exception("Conflict resolution failed.") def _do_power_level_conflict_res(self, new_branch, current_branch): max_power_new = max( diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index 0bf97e37ee..3cbce2d0a1 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -308,8 +308,8 @@ class PduStore(SQLBaseStore): @defer.inlineCallbacks def get_oldest_pdus_in_context(self, context): - """Get a list of Pdus that we haven't backfilled beyond yet (and haven't - seen). This list is used when we want to backfill backwards and is the + """Get a list of Pdus that we haven't backfilled beyond yet (and havent + seen). This list is used when we want to backfill backwards and is the list we send to the remote server. Args: @@ -524,13 +524,16 @@ class StatePduStore(SQLBaseStore): txn, new_pdu, current ) + missing_branch = None for branch, prev_state, state in enum_branches: if state: return_value[branch].append(state) else: + # We don't have prev_state :( + missing_branch = branch break - return return_value + return (return_value, missing_branch) def update_current_state(self, pdu_id, origin, context, pdu_type, state_key): diff --git a/tests/test_state.py b/tests/test_state.py index b01496c40f..4512475ebd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -24,6 +24,8 @@ from collections import namedtuple from mock import Mock +import mock + ReturnType = namedtuple( "StateReturnType", ["new_branch", "current_branch"] @@ -54,7 +56,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu], []) + (ReturnType([new_pdu], []), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -78,7 +80,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu], [old_pdu]) + (ReturnType([new_pdu, old_pdu], [old_pdu]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -103,7 +105,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) + (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -128,7 +130,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) + (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -153,7 +155,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) + (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -179,7 +181,13 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_3, old_pdu_1], [old_pdu_2, old_pdu_1]) + ( + ReturnType( + [new_pdu, old_pdu_3, old_pdu_1], + [old_pdu_2, old_pdu_1] + ), + None + ) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -200,22 +208,32 @@ class StateTestCase(unittest.TestCase): # triggering a get_pdu request # The pdu we haven't seen - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) + old_pdu_1 = new_fake_pdu_entry( + "A", "test", "mem", "x", None, 10, depth=0 + ) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20) + old_pdu_2 = new_fake_pdu_entry( + "B", "test", "mem", "x", "A", 10, depth=1 + ) + new_pdu = new_fake_pdu_entry( + "C", "test", "mem", "x", "A", 20, depth=2 + ) # The return_value of `get_unresolved_state_tree`, which changes after # the call to get_pdu - tree_to_return = [ReturnType([new_pdu], [old_pdu_2])] + tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)] def return_tree(p): return tree_to_return[0] - def set_return_tree(*args, **kwargs): - tree_to_return[0] = ReturnType( - [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1] + def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): + tree_to_return[0] = ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1] + ), + None ) + return defer.succeed(None) self.persistence.get_unresolved_state_tree.side_effect = return_tree @@ -227,6 +245,13 @@ class StateTestCase(unittest.TestCase): self.assertTrue(is_new) + self.replication.get_pdu.assert_called_with( + destination=new_pdu.origin, + pdu_origin=old_pdu_1.origin, + pdu_id=old_pdu_1.pdu_id, + outlier=True + ) + self.persistence.get_unresolved_state_tree.assert_called_with( new_pdu ) @@ -237,6 +262,184 @@ class StateTestCase(unittest.TestCase): self.assertEqual(1, self.persistence.update_current_state.call_count) + @defer.inlineCallbacks + def test_missing_pdu_depth_1(self): + # We try to update state against a PDU we haven't yet seen, + # triggering a get_pdu request + + # The pdu we haven't seen + old_pdu_1 = new_fake_pdu_entry( + "A", "test", "mem", "x", None, 10, depth=0 + ) + + old_pdu_2 = new_fake_pdu_entry( + "B", "test", "mem", "x", "A", 10, depth=2 + ) + old_pdu_3 = new_fake_pdu_entry( + "C", "test", "mem", "x", "B", 10, depth=3 + ) + new_pdu = new_fake_pdu_entry( + "D", "test", "mem", "x", "A", 20, depth=4 + ) + + # The return_value of `get_unresolved_state_tree`, which changes after + # the call to get_pdu + tree_to_return = [ + ( + ReturnType([new_pdu], [old_pdu_3]), + 0 + ), + ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_3] + ), + 1 + ), + ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1] + ), + None + ), + ] + + to_return = [0] + + def return_tree(p): + return tree_to_return[to_return[0]] + + def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): + to_return[0] += 1 + return defer.succeed(None) + + self.persistence.get_unresolved_state_tree.side_effect = return_tree + + self.replication.get_pdu.side_effect = set_return_tree + + self.persistence.get_pdu.return_value = None + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.assertEqual(2, self.replication.get_pdu.call_count) + + self.replication.get_pdu.assert_has_calls( + [ + mock.call( + destination=new_pdu.origin, + pdu_origin=old_pdu_1.origin, + pdu_id=old_pdu_1.pdu_id, + outlier=True + ), + mock.call( + destination=old_pdu_3.origin, + pdu_origin=old_pdu_2.origin, + pdu_id=old_pdu_2.pdu_id, + outlier=True + ), + ] + ) + + self.persistence.get_unresolved_state_tree.assert_called_with( + new_pdu + ) + + self.assertEquals( + 3, self.persistence.get_unresolved_state_tree.call_count + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + + @defer.inlineCallbacks + def test_missing_pdu_depth_2(self): + # We try to update state against a PDU we haven't yet seen, + # triggering a get_pdu request + + # The pdu we haven't seen + old_pdu_1 = new_fake_pdu_entry( + "A", "test", "mem", "x", None, 10, depth=0 + ) + + old_pdu_2 = new_fake_pdu_entry( + "B", "test", "mem", "x", "A", 10, depth=2 + ) + old_pdu_3 = new_fake_pdu_entry( + "C", "test", "mem", "x", "B", 10, depth=3 + ) + new_pdu = new_fake_pdu_entry( + "D", "test", "mem", "x", "A", 20, depth=1 + ) + + # The return_value of `get_unresolved_state_tree`, which changes after + # the call to get_pdu + tree_to_return = [ + ( + ReturnType([new_pdu], [old_pdu_3]), + 1, + ), + ( + ReturnType( + [new_pdu], [old_pdu_3, old_pdu_2] + ), + 0, + ), + ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1] + ), + None + ), + ] + + to_return = [0] + + def return_tree(p): + return tree_to_return[to_return[0]] + + def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): + to_return[0] += 1 + return defer.succeed(None) + + self.persistence.get_unresolved_state_tree.side_effect = return_tree + + self.replication.get_pdu.side_effect = set_return_tree + + self.persistence.get_pdu.return_value = None + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.assertEqual(2, self.replication.get_pdu.call_count) + + self.replication.get_pdu.assert_has_calls( + [ + mock.call( + destination=old_pdu_3.origin, + pdu_origin=old_pdu_2.origin, + pdu_id=old_pdu_2.pdu_id, + outlier=True + ), + mock.call( + destination=new_pdu.origin, + pdu_origin=old_pdu_1.origin, + pdu_id=old_pdu_1.pdu_id, + outlier=True + ), + ] + ) + + self.persistence.get_unresolved_state_tree.assert_called_with( + new_pdu + ) + + self.assertEquals( + 3, self.persistence.get_unresolved_state_tree.call_count + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + @defer.inlineCallbacks def test_new_event(self): event = Mock() @@ -270,7 +473,7 @@ class StateTestCase(unittest.TestCase): def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, - power_level): + power_level, depth=0): new_pdu = PduEntry( pdu_id=pdu_id, pdu_type=pdu_type, @@ -280,7 +483,7 @@ def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, origin="example.com", context="context", ts=1405353060021, - depth=0, + depth=depth, content_json="{}", unrecognized_keys="{}", outlier=True, -- cgit 1.5.1 From e062f2dfa89eca20d409642b61bb240accb51bf1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 8 Sep 2014 22:36:51 +0100 Subject: Apparently we can't do txn.rollback(), so raise and catch an exception instead. --- synapse/storage/__init__.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 8ed80109a5..a2eec3b209 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -47,6 +47,11 @@ import os logger = logging.getLogger(__name__) +class _RollbackButIsFineException(Exception): + """ This exception is used to rollback a transaction without implying + something went wrong. + """ + pass class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore, @@ -71,13 +76,16 @@ class DataStore(RoomMemberStore, RoomStore, self.min_token -= 1 stream_ordering = self.min_token - latest = yield self._db_pool.runInteraction( - self._persist_pdu_event_txn, - pdu=pdu, - event=event, - backfilled=backfilled, - stream_ordering=stream_ordering, - ) + try: + latest = yield self._db_pool.runInteraction( + self._persist_pdu_event_txn, + pdu=pdu, + event=event, + backfilled=backfilled, + stream_ordering=stream_ordering, + ) + except _RollbackButIsFineException as e: + pass defer.returnValue(latest) @defer.inlineCallbacks @@ -175,12 +183,12 @@ class DataStore(RoomMemberStore, RoomStore, try: self._simple_insert_txn(txn, "events", vals) except: - logger.exception( + logger.warn( "Failed to persist, probably duplicate: %s", - event.event_id + event.event_id, + exc_info=True, ) - txn.rollback() - return + raise _RollbackButIsFineException("_persist_event") if not backfilled and hasattr(event, "state_key"): vals = { -- cgit 1.5.1 From a75f8686ba4c536db1a9e341786ac34bab3d25c7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 9 Sep 2014 16:27:59 +0100 Subject: Fix bug where we used an unbound local variable if we ended up rolling back the persist_event transaction --- synapse/storage/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a2eec3b209..ad2a484c16 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -77,7 +77,7 @@ class DataStore(RoomMemberStore, RoomStore, stream_ordering = self.min_token try: - latest = yield self._db_pool.runInteraction( + yield self._db_pool.runInteraction( self._persist_pdu_event_txn, pdu=pdu, event=event, @@ -86,7 +86,6 @@ class DataStore(RoomMemberStore, RoomStore, ) except _RollbackButIsFineException as e: pass - defer.returnValue(latest) @defer.inlineCallbacks def get_event(self, event_id, allow_none=False): @@ -214,8 +213,6 @@ class DataStore(RoomMemberStore, RoomStore, } ) - return self._get_room_events_max_id_txn(txn) - @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): sql = ( -- cgit 1.5.1 From ca1ae7cf9b4c75f677f453be1fb7d9a06c17194d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Sep 2014 13:54:13 +0100 Subject: Fix bug where we didn't return a tuple when expected. --- synapse/storage/pdu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index 3cbce2d0a1..f780111b3b 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -516,7 +516,7 @@ class StatePduStore(SQLBaseStore): if not current: logger.debug("get_unresolved_state_tree No current state.") - return return_value + return (return_value, None) return_value.current_branch.append(current) -- cgit 1.5.1 From b42fe05c516ffc8e049ab9b56451cceb813bdf64 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Sep 2014 17:09:55 +0100 Subject: Fix bug where we incorrectly removed a remote host from the list of hosts in a room when any user from that host left that room even if they weren't the last user from that host in that room --- synapse/storage/roommember.py | 57 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 12 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9a393e2568..20f22057a2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -18,6 +18,7 @@ from twisted.internet import defer from ._base import SQLBaseStore from synapse.api.constants import Membership +from synapse.util.logutils import log_function import logging @@ -29,8 +30,18 @@ class RoomMemberStore(SQLBaseStore): def _store_room_member_txn(self, txn, event): """Store a room member in the database. """ - target_user_id = event.state_key - domain = self.hs.parse_userid(target_user_id).domain + try: + target_user_id = event.state_key + domain = self.hs.parse_userid(target_user_id).domain + except: + logger.exception("Failed to parse target_user_id=%s", target_user_id) + raise + + logger.debug( + "_store_room_member_txn: target_user_id=%s, membership=%s", + target_user_id, + event.membership, + ) self._simple_insert_txn( txn, @@ -51,12 +62,30 @@ class RoomMemberStore(SQLBaseStore): "VALUES (?, ?)" ) txn.execute(sql, (event.room_id, domain)) - else: - sql = ( - "DELETE FROM room_hosts WHERE room_id = ? AND host = ?" + elif event.membership != Membership.INVITE: + # Check if this was the last person to have left. + member_events = self._get_members_query_txn( + txn, + where_clause="c.room_id = ? AND m.membership = ?", + where_values=(event.room_id, Membership.JOIN,) ) - txn.execute(sql, (event.room_id, domain)) + joined_domains = set() + for e in member_events: + try: + joined_domains.add( + self.hs.parse_userid(e.state_key).domain + ) + except: + # FIXME: How do we deal with invalid user ids in the db? + logger.exception("Invalid user_id: %s", event.state_key) + + if domain not in joined_domains: + sql = ( + "DELETE FROM room_hosts WHERE room_id = ? AND host = ?" + ) + + txn.execute(sql, (event.room_id, domain)) @defer.inlineCallbacks def get_room_member(self, user_id, room_id): @@ -146,8 +175,13 @@ class RoomMemberStore(SQLBaseStore): vals = where_dict.values() return self._get_members_query(clause, vals) - @defer.inlineCallbacks def _get_members_query(self, where_clause, where_values): + return self._db_pool.runInteraction( + self._get_members_query_txn, + where_clause, where_values + ) + + def _get_members_query_txn(self, txn, where_clause, where_values): sql = ( "SELECT e.* FROM events as e " "INNER JOIN room_memberships as m " @@ -157,12 +191,11 @@ class RoomMemberStore(SQLBaseStore): "WHERE %s " ) % (where_clause,) - rows = yield self._execute_and_decode(sql, *where_values) - - # logger.debug("_get_members_query Got rows %s", rows) + txn.execute(sql, where_values) + rows = self.cursor_to_dict(txn) - results = yield self._parse_events(rows) - defer.returnValue(results) + results = self._parse_events_txn(txn, rows) + return results @defer.inlineCallbacks def user_rooms_intersect(self, user_list): -- cgit 1.5.1 From 39e3fc69e5a190371aa6936bfea57e9f8bd5255b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Sep 2014 17:11:00 +0100 Subject: Make the state resolution use actual power levels rather than taking them from a Pdu key. --- synapse/federation/units.py | 1 + synapse/state.py | 46 ++++++++--- synapse/storage/_base.py | 8 ++ synapse/storage/pdu.py | 81 +++---------------- tests/test_state.py | 185 +++++++++++++++++++++++++++++++++----------- 5 files changed, 194 insertions(+), 127 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 9740431279..622fe66a8f 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -69,6 +69,7 @@ class Pdu(JsonEncodedObject): "prev_state_id", "prev_state_origin", "required_power_level", + "user_id", ] internal_keys = [ diff --git a/synapse/state.py b/synapse/state.py index 0cc1344d51..9db84c9b5c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -115,6 +115,8 @@ class StateHandler(object): is_new = yield self._handle_new_state(new_pdu) + logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin) + if is_new: yield self.store.update_current_state( pdu_id=new_pdu.pdu_id, @@ -187,11 +189,12 @@ class StateHandler(object): # We didn't find a common ancestor. This is probably fine. pass - result = self._do_conflict_res( + result = yield self._do_conflict_res( new_branch, current_branch, common_ancestor ) defer.returnValue(result) + @defer.inlineCallbacks def _do_conflict_res(self, new_branch, current_branch, common_ancestor): conflict_res = [ self._do_power_level_conflict_res, @@ -200,7 +203,8 @@ class StateHandler(object): ] for algo in conflict_res: - new_res, curr_res = algo( + new_res, curr_res = yield defer.maybeDeferred( + algo, new_branch, current_branch, common_ancestor ) @@ -211,19 +215,39 @@ class StateHandler(object): raise Exception("Conflict resolution failed.") + @defer.inlineCallbacks def _do_power_level_conflict_res(self, new_branch, current_branch, common_ancestor): - max_power_new = max( - new_branch[:-1] if common_ancestor else new_branch, - key=lambda t: t.power_level - ).power_level + new_powers_deferreds = [] + for e in new_branch[:-1] if common_ancestor else new_branch: + if hasattr(e, "user_id"): + new_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + current_powers_deferreds = [] + for e in current_branch[:-1] if common_ancestor else current_branch: + if hasattr(e, "user_id"): + current_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + new_powers = yield defer.gatherResults( + new_powers_deferreds, + consumeErrors=True + ) - max_power_current = max( - current_branch[:-1] if common_ancestor else current_branch, - key=lambda t: t.power_level - ).power_level + current_powers = yield defer.gatherResults( + current_powers_deferreds, + consumeErrors=True + ) + + max_power_new = max(new_powers) + max_power_current = max(current_powers) - return (max_power_new, max_power_current) + defer.returnValue( + (max_power_new, max_power_current) + ) def _do_chain_length_conflict_res(self, new_branch, current_branch, common_ancestor): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8037225079..8deaaf93bd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.util.logutils import log_function import collections import copy @@ -91,6 +92,7 @@ class SQLBaseStore(object): self._simple_insert_txn, table, values, or_replace=or_replace ) + @log_function def _simple_insert_txn(self, txn, table, values, or_replace=False): sql = "%s INTO %s (%s) VALUES(%s)" % ( ("INSERT OR REPLACE" if or_replace else "INSERT"), @@ -98,6 +100,12 @@ class SQLBaseStore(object): ", ".join(k for k in values), ", ".join("?" for k in values) ) + + logger.debug( + "[SQL] %s Args=%s Func=%s", + sql, values.values(), + ) + txn.execute(sql, values.values()) return txn.lastrowid diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index f780111b3b..3c859fdeac 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -17,6 +17,7 @@ from twisted.internet import defer from ._base import SQLBaseStore, Table, JoinHelper +from synapse.federation.units import Pdu from synapse.util.logutils import log_function from collections import namedtuple @@ -625,53 +626,6 @@ class StatePduStore(SQLBaseStore): return result - def get_next_missing_pdu(self, new_pdu): - """When we get a new state pdu we need to check whether we need to do - any conflict resolution, if we do then we need to check if we need - to go back and request some more state pdus that we haven't seen yet. - - Args: - txn - new_pdu - - Returns: - PduIdTuple: A pdu that we are missing, or None if we have all the - pdus required to do the conflict resolution. - """ - return self._db_pool.runInteraction( - self._get_next_missing_pdu, new_pdu - ) - - def _get_next_missing_pdu(self, txn, new_pdu): - logger.debug( - "get_next_missing_pdu %s %s", - new_pdu.pdu_id, new_pdu.origin - ) - - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - if (not current or not current.prev_state_id - or not current.prev_state_origin): - return None - - # Oh look, it's a straight clobber, so wooooo almost no-op. - if (new_pdu.prev_state_id == current.pdu_id - and new_pdu.prev_state_origin == current.origin): - return None - - enum_branches = self._enumerate_state_branches(txn, new_pdu, current) - for branch, prev_state, state in enum_branches: - if not state: - return PduIdTuple( - prev_state.prev_state_id, - prev_state.prev_state_origin - ) - - return None - def handle_new_state(self, new_pdu): """Actually perform conflict resolution on the new_pdu on the assumption we have all the pdus required to perform it. @@ -755,24 +709,11 @@ class StatePduStore(SQLBaseStore): return is_current - @classmethod @log_function - def _enumerate_state_branches(cls, txn, pdu_a, pdu_b): + def _enumerate_state_branches(self, txn, pdu_a, pdu_b): branch_a = pdu_a branch_b = pdu_b - get_query = ( - "SELECT %(fields)s FROM %(pdus)s as p " - "LEFT JOIN %(state)s as s " - "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " - "WHERE p.pdu_id = ? AND p.origin = ? " - ) % { - "fields": _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s"), - "pdus": PdusTable.table_name, - "state": StatePdusTable.table_name, - } - while True: if (branch_a.pdu_id == branch_b.pdu_id and branch_a.origin == branch_b.origin): @@ -804,13 +745,12 @@ class StatePduStore(SQLBaseStore): branch_a.prev_state_origin ) - logger.debug("getting branch_a prev %s", pdu_tuple) - txn.execute(get_query, pdu_tuple) - prev_branch = branch_a - res = txn.fetchone() - branch_a = PduEntry(*res) if res else None + logger.debug("getting branch_a prev %s", pdu_tuple) + branch_a = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_a: + branch_a = Pdu.from_pdu_tuple(branch_a) logger.debug("branch_a=%s", branch_a) @@ -823,14 +763,13 @@ class StatePduStore(SQLBaseStore): branch_b.prev_state_id, branch_b.prev_state_origin ) - txn.execute(get_query, pdu_tuple) - - logger.debug("getting branch_b prev %s", pdu_tuple) prev_branch = branch_b - res = txn.fetchone() - branch_b = PduEntry(*res) if res else None + logger.debug("getting branch_b prev %s", pdu_tuple) + branch_b = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_b: + branch_b = Pdu.from_pdu_tuple(branch_b) logger.debug("branch_b=%s", branch_b) diff --git a/tests/test_state.py b/tests/test_state.py index a9fc3fb85c..16af95b7bc 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -15,15 +15,18 @@ from twisted.internet import defer from twisted.trial import unittest +from twisted.python.log import PythonLoggingObserver from synapse.state import StateHandler from synapse.storage.pdu import PduEntry from synapse.federation.pdu_codec import encode_event_id +from synapse.federation.units import Pdu from collections import namedtuple from mock import Mock +import logging import mock @@ -32,6 +35,11 @@ ReturnType = namedtuple( ) +def _gen_get_power_level(power_level_list): + def get_power_level(room_id, user_id): + return defer.succeed(power_level_list.get(user_id, None)) + return get_power_level + class StateTestCase(unittest.TestCase): def setUp(self): self.persistence = Mock(spec=[ @@ -40,6 +48,7 @@ class StateTestCase(unittest.TestCase): "get_latest_pdus_in_context", "get_current_state_pdu", "get_pdu", + "get_power_level", ]) self.replication = Mock(spec=["get_pdu"]) @@ -53,7 +62,9 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_new_state_key(self): # We've never seen anything for this state before - new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) + new_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({}) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu], []), None) @@ -76,8 +87,13 @@ class StateTestCase(unittest.TestCase): # We do a direct overwriting of the old state, i.e., the new state # points to the old state. - old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5) + old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1") + new_pdu = new_fake_pdu("B", "test", "mem", "x", "A", "u2") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 5, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu, old_pdu], [old_pdu]), None) @@ -95,14 +111,48 @@ class StateTestCase(unittest.TestCase): self.assertFalse(self.replication.get_pdu.called) + @defer.inlineCallbacks + def test_overwrite(self): + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2") + new_pdu = new_fake_pdu("C", "test", "mem", "x", "B", "u3") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 5, + "u3": 0, + }) + + self.persistence.get_unresolved_state_tree.return_value = ( + (ReturnType([new_pdu, old_pdu_2, old_pdu_1], [old_pdu_1]), None) + ) + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.persistence.get_unresolved_state_tree.assert_called_once_with( + new_pdu + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + + self.assertFalse(self.replication.get_pdu.called) + @defer.inlineCallbacks def test_power_level_fail(self): # We try to update the state based on an outdated state, and have a # too low power level. - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5) + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") + new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 5, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) @@ -125,9 +175,15 @@ class StateTestCase(unittest.TestCase): # We try to update the state based on an outdated state, but have # sufficient power level to force the update. - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15) + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") + new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 15, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) @@ -150,9 +206,15 @@ class StateTestCase(unittest.TestCase): # We try to update the state based on an outdated state, the power # levels are the same and so are the branch lengths - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10) + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") + new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 10, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) @@ -175,10 +237,17 @@ class StateTestCase(unittest.TestCase): # We try to update the state based on an outdated state, the power # levels are the same but the branch length of the new one is longer. - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - old_pdu_3 = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10) - new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10) + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") + old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "A", "u3") + new_pdu = new_fake_pdu("D", "test", "mem", "x", "C", "u4") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 10, + "u4": 10, + }) self.persistence.get_unresolved_state_tree.return_value = ( ( @@ -208,17 +277,23 @@ class StateTestCase(unittest.TestCase): # triggering a get_pdu request # The pdu we haven't seen - old_pdu_1 = new_fake_pdu_entry( - "A", "test", "mem", "x", None, 10, depth=0 + old_pdu_1 = new_fake_pdu( + "A", "test", "mem", "x", None, "u1", depth=0 ) - old_pdu_2 = new_fake_pdu_entry( - "B", "test", "mem", "x", "A", 10, depth=1 + old_pdu_2 = new_fake_pdu( + "B", "test", "mem", "x", "A", "u2", depth=1 ) - new_pdu = new_fake_pdu_entry( - "C", "test", "mem", "x", "A", 20, depth=2 + new_pdu = new_fake_pdu( + "C", "test", "mem", "x", "A", "u3", depth=2 ) + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 20, + }) + # The return_value of `get_unresolved_state_tree`, which changes after # the call to get_pdu tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)] @@ -268,20 +343,27 @@ class StateTestCase(unittest.TestCase): # triggering a get_pdu request # The pdu we haven't seen - old_pdu_1 = new_fake_pdu_entry( - "A", "test", "mem", "x", None, 10, depth=0 + old_pdu_1 = new_fake_pdu( + "A", "test", "mem", "x", None, "u1", depth=0 ) - old_pdu_2 = new_fake_pdu_entry( - "B", "test", "mem", "x", "A", 10, depth=2 + old_pdu_2 = new_fake_pdu( + "B", "test", "mem", "x", "A", "u2", depth=2 ) - old_pdu_3 = new_fake_pdu_entry( - "C", "test", "mem", "x", "B", 10, depth=3 + old_pdu_3 = new_fake_pdu( + "C", "test", "mem", "x", "B", "u3", depth=3 ) - new_pdu = new_fake_pdu_entry( - "D", "test", "mem", "x", "A", 20, depth=4 + new_pdu = new_fake_pdu( + "D", "test", "mem", "x", "A", "u4", depth=4 ) + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 10, + "u4": 20, + }) + # The return_value of `get_unresolved_state_tree`, which changes after # the call to get_pdu tree_to_return = [ @@ -357,20 +439,27 @@ class StateTestCase(unittest.TestCase): # triggering a get_pdu request # The pdu we haven't seen - old_pdu_1 = new_fake_pdu_entry( - "A", "test", "mem", "x", None, 10, depth=0 + old_pdu_1 = new_fake_pdu( + "A", "test", "mem", "x", None, "u1", depth=0 ) - old_pdu_2 = new_fake_pdu_entry( - "B", "test", "mem", "x", "A", 10, depth=2 + old_pdu_2 = new_fake_pdu( + "B", "test", "mem", "x", "A", "u2", depth=2 ) - old_pdu_3 = new_fake_pdu_entry( - "C", "test", "mem", "x", "B", 10, depth=3 + old_pdu_3 = new_fake_pdu( + "C", "test", "mem", "x", "B", "u3", depth=3 ) - new_pdu = new_fake_pdu_entry( - "D", "test", "mem", "x", "A", 20, depth=1 + new_pdu = new_fake_pdu( + "D", "test", "mem", "x", "A", "u4", depth=1 ) + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 10, + "u4": 20, + }) + # The return_value of `get_unresolved_state_tree`, which changes after # the call to get_pdu tree_to_return = [ @@ -445,8 +534,13 @@ class StateTestCase(unittest.TestCase): # We do a direct overwriting of the old state, i.e., the new state # points to the old state. - old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 5) - new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) + old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1") + new_pdu = new_fake_pdu("B", "test", "mem", "x", None, "u2") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 5, + "u2": 10, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu], [old_pdu]), None) @@ -469,7 +563,7 @@ class StateTestCase(unittest.TestCase): event = Mock() event.event_id = "12123123@test" - state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20) + state_pdu = new_fake_pdu("C", "test", "mem", "x", "A", 20) snapshot = Mock() snapshot.prev_state_pdu = state_pdu @@ -496,13 +590,13 @@ class StateTestCase(unittest.TestCase): ) -def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, - power_level, depth=0): - new_pdu = PduEntry( +def new_fake_pdu(pdu_id, context, pdu_type, state_key, prev_state_id, + user_id, depth=0): + new_pdu = Pdu( pdu_id=pdu_id, pdu_type=pdu_type, state_key=state_key, - power_level=power_level, + user_id=user_id, prev_state_id=prev_state_id, origin="example.com", context="context", @@ -514,6 +608,7 @@ def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, is_state=True, prev_state_origin="example.com", have_processed=True, + content={}, ) return new_pdu -- cgit 1.5.1 From 667e747ed11a418da317a03fc3c59a205c5c4af0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Sep 2014 17:56:21 +0100 Subject: Fix bug where we no longer stored user_id on Pdus --- synapse/storage/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index ad2a484c16..9201a377b6 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -36,7 +36,7 @@ from .registration import RegistrationStore from .room import RoomStore from .roommember import RoomMemberStore from .stream import StreamStore -from .pdu import StatePduStore, PduStore +from .pdu import StatePduStore, PduStore, PdusTable from .transactions import TransactionStore from .keys import KeyStore @@ -123,6 +123,12 @@ class DataStore(RoomMemberStore, RoomStore, del cols["content"] del cols["prev_pdus"] cols["content_json"] = json.dumps(pdu.content) + + unrec_keys.update({ + k: v for k, v in cols.items() + if k not in PdusTable.fields + }) + cols["unrecognized_keys"] = json.dumps(unrec_keys) logger.debug("Persisting: %s", repr(cols)) -- cgit 1.5.1 From 14975ce5bcf4dac2720cc4be290100a580334393 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Sep 2014 17:57:02 +0100 Subject: Fix bug where we relied on the current_state_events being updated when we are handling type specific persistence --- synapse/storage/roommember.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 20f22057a2..676b2f2653 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -66,8 +66,8 @@ class RoomMemberStore(SQLBaseStore): # Check if this was the last person to have left. member_events = self._get_members_query_txn( txn, - where_clause="c.room_id = ? AND m.membership = ?", - where_values=(event.room_id, Membership.JOIN,) + where_clause="c.room_id = ? AND m.membership = ? AND m.user_id != ?", + where_values=(event.room_id, Membership.JOIN, target_user_id,) ) joined_domains = set() -- cgit 1.5.1