From 83ce57302dab6a825f3afde11926b5404ce1c9ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 8 Sep 2014 19:50:46 +0100 Subject: Fix bug in state handling where we incorrectly identified a missing pdu. Update tests to catch this case. --- tests/test_state.py | 233 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 218 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/tests/test_state.py b/tests/test_state.py index b01496c40f..4512475ebd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -24,6 +24,8 @@ from collections import namedtuple from mock import Mock +import mock + ReturnType = namedtuple( "StateReturnType", ["new_branch", "current_branch"] @@ -54,7 +56,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu], []) + (ReturnType([new_pdu], []), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -78,7 +80,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu], [old_pdu]) + (ReturnType([new_pdu, old_pdu], [old_pdu]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -103,7 +105,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) + (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -128,7 +130,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) + (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -153,7 +155,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) + (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -179,7 +181,13 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_3, old_pdu_1], [old_pdu_2, old_pdu_1]) + ( + ReturnType( + [new_pdu, old_pdu_3, old_pdu_1], + [old_pdu_2, old_pdu_1] + ), + None + ) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -200,22 +208,32 @@ class StateTestCase(unittest.TestCase): # triggering a get_pdu request # The pdu we haven't seen - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) + old_pdu_1 = new_fake_pdu_entry( + "A", "test", "mem", "x", None, 10, depth=0 + ) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20) + old_pdu_2 = new_fake_pdu_entry( + "B", "test", "mem", "x", "A", 10, depth=1 + ) + new_pdu = new_fake_pdu_entry( + "C", "test", "mem", "x", "A", 20, depth=2 + ) # The return_value of `get_unresolved_state_tree`, which changes after # the call to get_pdu - tree_to_return = [ReturnType([new_pdu], [old_pdu_2])] + tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)] def return_tree(p): return tree_to_return[0] - def set_return_tree(*args, **kwargs): - tree_to_return[0] = ReturnType( - [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1] + def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): + tree_to_return[0] = ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1] + ), + None ) + return defer.succeed(None) self.persistence.get_unresolved_state_tree.side_effect = return_tree @@ -227,6 +245,13 @@ class StateTestCase(unittest.TestCase): self.assertTrue(is_new) + self.replication.get_pdu.assert_called_with( + destination=new_pdu.origin, + pdu_origin=old_pdu_1.origin, + pdu_id=old_pdu_1.pdu_id, + outlier=True + ) + self.persistence.get_unresolved_state_tree.assert_called_with( new_pdu ) @@ -237,6 +262,184 @@ class StateTestCase(unittest.TestCase): self.assertEqual(1, self.persistence.update_current_state.call_count) + @defer.inlineCallbacks + def test_missing_pdu_depth_1(self): + # We try to update state against a PDU we haven't yet seen, + # triggering a get_pdu request + + # The pdu we haven't seen + old_pdu_1 = new_fake_pdu_entry( + "A", "test", "mem", "x", None, 10, depth=0 + ) + + old_pdu_2 = new_fake_pdu_entry( + "B", "test", "mem", "x", "A", 10, depth=2 + ) + old_pdu_3 = new_fake_pdu_entry( + "C", "test", "mem", "x", "B", 10, depth=3 + ) + new_pdu = new_fake_pdu_entry( + "D", "test", "mem", "x", "A", 20, depth=4 + ) + + # The return_value of `get_unresolved_state_tree`, which changes after + # the call to get_pdu + tree_to_return = [ + ( + ReturnType([new_pdu], [old_pdu_3]), + 0 + ), + ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_3] + ), + 1 + ), + ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1] + ), + None + ), + ] + + to_return = [0] + + def return_tree(p): + return tree_to_return[to_return[0]] + + def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): + to_return[0] += 1 + return defer.succeed(None) + + self.persistence.get_unresolved_state_tree.side_effect = return_tree + + self.replication.get_pdu.side_effect = set_return_tree + + self.persistence.get_pdu.return_value = None + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.assertEqual(2, self.replication.get_pdu.call_count) + + self.replication.get_pdu.assert_has_calls( + [ + mock.call( + destination=new_pdu.origin, + pdu_origin=old_pdu_1.origin, + pdu_id=old_pdu_1.pdu_id, + outlier=True + ), + mock.call( + destination=old_pdu_3.origin, + pdu_origin=old_pdu_2.origin, + pdu_id=old_pdu_2.pdu_id, + outlier=True + ), + ] + ) + + self.persistence.get_unresolved_state_tree.assert_called_with( + new_pdu + ) + + self.assertEquals( + 3, self.persistence.get_unresolved_state_tree.call_count + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + + @defer.inlineCallbacks + def test_missing_pdu_depth_2(self): + # We try to update state against a PDU we haven't yet seen, + # triggering a get_pdu request + + # The pdu we haven't seen + old_pdu_1 = new_fake_pdu_entry( + "A", "test", "mem", "x", None, 10, depth=0 + ) + + old_pdu_2 = new_fake_pdu_entry( + "B", "test", "mem", "x", "A", 10, depth=2 + ) + old_pdu_3 = new_fake_pdu_entry( + "C", "test", "mem", "x", "B", 10, depth=3 + ) + new_pdu = new_fake_pdu_entry( + "D", "test", "mem", "x", "A", 20, depth=1 + ) + + # The return_value of `get_unresolved_state_tree`, which changes after + # the call to get_pdu + tree_to_return = [ + ( + ReturnType([new_pdu], [old_pdu_3]), + 1, + ), + ( + ReturnType( + [new_pdu], [old_pdu_3, old_pdu_2] + ), + 0, + ), + ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1] + ), + None + ), + ] + + to_return = [0] + + def return_tree(p): + return tree_to_return[to_return[0]] + + def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): + to_return[0] += 1 + return defer.succeed(None) + + self.persistence.get_unresolved_state_tree.side_effect = return_tree + + self.replication.get_pdu.side_effect = set_return_tree + + self.persistence.get_pdu.return_value = None + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.assertEqual(2, self.replication.get_pdu.call_count) + + self.replication.get_pdu.assert_has_calls( + [ + mock.call( + destination=old_pdu_3.origin, + pdu_origin=old_pdu_2.origin, + pdu_id=old_pdu_2.pdu_id, + outlier=True + ), + mock.call( + destination=new_pdu.origin, + pdu_origin=old_pdu_1.origin, + pdu_id=old_pdu_1.pdu_id, + outlier=True + ), + ] + ) + + self.persistence.get_unresolved_state_tree.assert_called_with( + new_pdu + ) + + self.assertEquals( + 3, self.persistence.get_unresolved_state_tree.call_count + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + @defer.inlineCallbacks def test_new_event(self): event = Mock() @@ -270,7 +473,7 @@ class StateTestCase(unittest.TestCase): def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, - power_level): + power_level, depth=0): new_pdu = PduEntry( pdu_id=pdu_id, pdu_type=pdu_type, @@ -280,7 +483,7 @@ def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, origin="example.com", context="context", ts=1405353060021, - depth=0, + depth=depth, content_json="{}", unrecognized_keys="{}", outlier=True, -- cgit 1.5.1 From 942d8412c49a1d481f0bedd189eb1598629b103c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 8 Sep 2014 20:13:27 +0100 Subject: Handle the case where we don't have a common ancestor --- synapse/state.py | 27 ++++++++++++++++++--------- tests/test_state.py | 24 ++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/synapse/state.py b/synapse/state.py index e69282860a..0cc1344d51 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -174,7 +174,9 @@ class StateHandler(object): n = new_branch[-1] c = current_branch[-1] - if n.pdu_id == c.pdu_id and n.origin == c.origin: + common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin + + if common_ancestor: # We found a common ancestor! if len(current_branch) == 1: @@ -185,10 +187,12 @@ class StateHandler(object): # We didn't find a common ancestor. This is probably fine. pass - result = self._do_conflict_res(new_branch, current_branch) + result = self._do_conflict_res( + new_branch, current_branch, common_ancestor + ) defer.returnValue(result) - def _do_conflict_res(self, new_branch, current_branch): + def _do_conflict_res(self, new_branch, current_branch, common_ancestor): conflict_res = [ self._do_power_level_conflict_res, self._do_chain_length_conflict_res, @@ -196,7 +200,9 @@ class StateHandler(object): ] for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) + new_res, curr_res = algo( + new_branch, current_branch, common_ancestor + ) if new_res < curr_res: defer.returnValue(False) @@ -205,23 +211,26 @@ class StateHandler(object): raise Exception("Conflict resolution failed.") - def _do_power_level_conflict_res(self, new_branch, current_branch): + def _do_power_level_conflict_res(self, new_branch, current_branch, + common_ancestor): max_power_new = max( - new_branch[:-1], + new_branch[:-1] if common_ancestor else new_branch, key=lambda t: t.power_level ).power_level max_power_current = max( - current_branch[:-1], + current_branch[:-1] if common_ancestor else current_branch, key=lambda t: t.power_level ).power_level return (max_power_new, max_power_current) - def _do_chain_length_conflict_res(self, new_branch, current_branch): + def _do_chain_length_conflict_res(self, new_branch, current_branch, + common_ancestor): return (len(new_branch), len(current_branch)) - def _do_hash_conflict_res(self, new_branch, current_branch): + def _do_hash_conflict_res(self, new_branch, current_branch, + common_ancestor): new_str = "".join([p.pdu_id + p.origin for p in new_branch]) c_str = "".join([p.pdu_id + p.origin for p in current_branch]) diff --git a/tests/test_state.py b/tests/test_state.py index 4512475ebd..a9fc3fb85c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -440,6 +440,30 @@ class StateTestCase(unittest.TestCase): self.assertEqual(1, self.persistence.update_current_state.call_count) + @defer.inlineCallbacks + def test_no_common_ancestor(self): + # We do a direct overwriting of the old state, i.e., the new state + # points to the old state. + + old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 5) + new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) + + self.persistence.get_unresolved_state_tree.return_value = ( + (ReturnType([new_pdu], [old_pdu]), None) + ) + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.persistence.get_unresolved_state_tree.assert_called_once_with( + new_pdu + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + + self.assertFalse(self.replication.get_pdu.called) + @defer.inlineCallbacks def test_new_event(self): event = Mock() -- cgit 1.5.1 From 39e3fc69e5a190371aa6936bfea57e9f8bd5255b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 12 Sep 2014 17:11:00 +0100 Subject: Make the state resolution use actual power levels rather than taking them from a Pdu key. --- synapse/federation/units.py | 1 + synapse/state.py | 46 ++++++++--- synapse/storage/_base.py | 8 ++ synapse/storage/pdu.py | 81 +++---------------- tests/test_state.py | 185 +++++++++++++++++++++++++++++++++----------- 5 files changed, 194 insertions(+), 127 deletions(-) (limited to 'tests') diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 9740431279..622fe66a8f 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -69,6 +69,7 @@ class Pdu(JsonEncodedObject): "prev_state_id", "prev_state_origin", "required_power_level", + "user_id", ] internal_keys = [ diff --git a/synapse/state.py b/synapse/state.py index 0cc1344d51..9db84c9b5c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -115,6 +115,8 @@ class StateHandler(object): is_new = yield self._handle_new_state(new_pdu) + logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin) + if is_new: yield self.store.update_current_state( pdu_id=new_pdu.pdu_id, @@ -187,11 +189,12 @@ class StateHandler(object): # We didn't find a common ancestor. This is probably fine. pass - result = self._do_conflict_res( + result = yield self._do_conflict_res( new_branch, current_branch, common_ancestor ) defer.returnValue(result) + @defer.inlineCallbacks def _do_conflict_res(self, new_branch, current_branch, common_ancestor): conflict_res = [ self._do_power_level_conflict_res, @@ -200,7 +203,8 @@ class StateHandler(object): ] for algo in conflict_res: - new_res, curr_res = algo( + new_res, curr_res = yield defer.maybeDeferred( + algo, new_branch, current_branch, common_ancestor ) @@ -211,19 +215,39 @@ class StateHandler(object): raise Exception("Conflict resolution failed.") + @defer.inlineCallbacks def _do_power_level_conflict_res(self, new_branch, current_branch, common_ancestor): - max_power_new = max( - new_branch[:-1] if common_ancestor else new_branch, - key=lambda t: t.power_level - ).power_level + new_powers_deferreds = [] + for e in new_branch[:-1] if common_ancestor else new_branch: + if hasattr(e, "user_id"): + new_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + current_powers_deferreds = [] + for e in current_branch[:-1] if common_ancestor else current_branch: + if hasattr(e, "user_id"): + current_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + new_powers = yield defer.gatherResults( + new_powers_deferreds, + consumeErrors=True + ) - max_power_current = max( - current_branch[:-1] if common_ancestor else current_branch, - key=lambda t: t.power_level - ).power_level + current_powers = yield defer.gatherResults( + current_powers_deferreds, + consumeErrors=True + ) + + max_power_new = max(new_powers) + max_power_current = max(current_powers) - return (max_power_new, max_power_current) + defer.returnValue( + (max_power_new, max_power_current) + ) def _do_chain_length_conflict_res(self, new_branch, current_branch, common_ancestor): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8037225079..8deaaf93bd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.util.logutils import log_function import collections import copy @@ -91,6 +92,7 @@ class SQLBaseStore(object): self._simple_insert_txn, table, values, or_replace=or_replace ) + @log_function def _simple_insert_txn(self, txn, table, values, or_replace=False): sql = "%s INTO %s (%s) VALUES(%s)" % ( ("INSERT OR REPLACE" if or_replace else "INSERT"), @@ -98,6 +100,12 @@ class SQLBaseStore(object): ", ".join(k for k in values), ", ".join("?" for k in values) ) + + logger.debug( + "[SQL] %s Args=%s Func=%s", + sql, values.values(), + ) + txn.execute(sql, values.values()) return txn.lastrowid diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index f780111b3b..3c859fdeac 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -17,6 +17,7 @@ from twisted.internet import defer from ._base import SQLBaseStore, Table, JoinHelper +from synapse.federation.units import Pdu from synapse.util.logutils import log_function from collections import namedtuple @@ -625,53 +626,6 @@ class StatePduStore(SQLBaseStore): return result - def get_next_missing_pdu(self, new_pdu): - """When we get a new state pdu we need to check whether we need to do - any conflict resolution, if we do then we need to check if we need - to go back and request some more state pdus that we haven't seen yet. - - Args: - txn - new_pdu - - Returns: - PduIdTuple: A pdu that we are missing, or None if we have all the - pdus required to do the conflict resolution. - """ - return self._db_pool.runInteraction( - self._get_next_missing_pdu, new_pdu - ) - - def _get_next_missing_pdu(self, txn, new_pdu): - logger.debug( - "get_next_missing_pdu %s %s", - new_pdu.pdu_id, new_pdu.origin - ) - - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - if (not current or not current.prev_state_id - or not current.prev_state_origin): - return None - - # Oh look, it's a straight clobber, so wooooo almost no-op. - if (new_pdu.prev_state_id == current.pdu_id - and new_pdu.prev_state_origin == current.origin): - return None - - enum_branches = self._enumerate_state_branches(txn, new_pdu, current) - for branch, prev_state, state in enum_branches: - if not state: - return PduIdTuple( - prev_state.prev_state_id, - prev_state.prev_state_origin - ) - - return None - def handle_new_state(self, new_pdu): """Actually perform conflict resolution on the new_pdu on the assumption we have all the pdus required to perform it. @@ -755,24 +709,11 @@ class StatePduStore(SQLBaseStore): return is_current - @classmethod @log_function - def _enumerate_state_branches(cls, txn, pdu_a, pdu_b): + def _enumerate_state_branches(self, txn, pdu_a, pdu_b): branch_a = pdu_a branch_b = pdu_b - get_query = ( - "SELECT %(fields)s FROM %(pdus)s as p " - "LEFT JOIN %(state)s as s " - "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " - "WHERE p.pdu_id = ? AND p.origin = ? " - ) % { - "fields": _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s"), - "pdus": PdusTable.table_name, - "state": StatePdusTable.table_name, - } - while True: if (branch_a.pdu_id == branch_b.pdu_id and branch_a.origin == branch_b.origin): @@ -804,13 +745,12 @@ class StatePduStore(SQLBaseStore): branch_a.prev_state_origin ) - logger.debug("getting branch_a prev %s", pdu_tuple) - txn.execute(get_query, pdu_tuple) - prev_branch = branch_a - res = txn.fetchone() - branch_a = PduEntry(*res) if res else None + logger.debug("getting branch_a prev %s", pdu_tuple) + branch_a = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_a: + branch_a = Pdu.from_pdu_tuple(branch_a) logger.debug("branch_a=%s", branch_a) @@ -823,14 +763,13 @@ class StatePduStore(SQLBaseStore): branch_b.prev_state_id, branch_b.prev_state_origin ) - txn.execute(get_query, pdu_tuple) - - logger.debug("getting branch_b prev %s", pdu_tuple) prev_branch = branch_b - res = txn.fetchone() - branch_b = PduEntry(*res) if res else None + logger.debug("getting branch_b prev %s", pdu_tuple) + branch_b = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_b: + branch_b = Pdu.from_pdu_tuple(branch_b) logger.debug("branch_b=%s", branch_b) diff --git a/tests/test_state.py b/tests/test_state.py index a9fc3fb85c..16af95b7bc 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -15,15 +15,18 @@ from twisted.internet import defer from twisted.trial import unittest +from twisted.python.log import PythonLoggingObserver from synapse.state import StateHandler from synapse.storage.pdu import PduEntry from synapse.federation.pdu_codec import encode_event_id +from synapse.federation.units import Pdu from collections import namedtuple from mock import Mock +import logging import mock @@ -32,6 +35,11 @@ ReturnType = namedtuple( ) +def _gen_get_power_level(power_level_list): + def get_power_level(room_id, user_id): + return defer.succeed(power_level_list.get(user_id, None)) + return get_power_level + class StateTestCase(unittest.TestCase): def setUp(self): self.persistence = Mock(spec=[ @@ -40,6 +48,7 @@ class StateTestCase(unittest.TestCase): "get_latest_pdus_in_context", "get_current_state_pdu", "get_pdu", + "get_power_level", ]) self.replication = Mock(spec=["get_pdu"]) @@ -53,7 +62,9 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_new_state_key(self): # We've never seen anything for this state before - new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) + new_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({}) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu], []), None) @@ -76,8 +87,13 @@ class StateTestCase(unittest.TestCase): # We do a direct overwriting of the old state, i.e., the new state # points to the old state. - old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5) + old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1") + new_pdu = new_fake_pdu("B", "test", "mem", "x", "A", "u2") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 5, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu, old_pdu], [old_pdu]), None) @@ -95,14 +111,48 @@ class StateTestCase(unittest.TestCase): self.assertFalse(self.replication.get_pdu.called) + @defer.inlineCallbacks + def test_overwrite(self): + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2") + new_pdu = new_fake_pdu("C", "test", "mem", "x", "B", "u3") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 5, + "u3": 0, + }) + + self.persistence.get_unresolved_state_tree.return_value = ( + (ReturnType([new_pdu, old_pdu_2, old_pdu_1], [old_pdu_1]), None) + ) + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.persistence.get_unresolved_state_tree.assert_called_once_with( + new_pdu + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + + self.assertFalse(self.replication.get_pdu.called) + @defer.inlineCallbacks def test_power_level_fail(self): # We try to update the state based on an outdated state, and have a # too low power level. - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5) + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") + new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 5, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) @@ -125,9 +175,15 @@ class StateTestCase(unittest.TestCase): # We try to update the state based on an outdated state, but have # sufficient power level to force the update. - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15) + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") + new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 15, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) @@ -150,9 +206,15 @@ class StateTestCase(unittest.TestCase): # We try to update the state based on an outdated state, the power # levels are the same and so are the branch lengths - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10) + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") + new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 10, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) @@ -175,10 +237,17 @@ class StateTestCase(unittest.TestCase): # We try to update the state based on an outdated state, the power # levels are the same but the branch length of the new one is longer. - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - old_pdu_3 = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10) - new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10) + old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1") + old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2") + old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "A", "u3") + new_pdu = new_fake_pdu("D", "test", "mem", "x", "C", "u4") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 10, + "u4": 10, + }) self.persistence.get_unresolved_state_tree.return_value = ( ( @@ -208,17 +277,23 @@ class StateTestCase(unittest.TestCase): # triggering a get_pdu request # The pdu we haven't seen - old_pdu_1 = new_fake_pdu_entry( - "A", "test", "mem", "x", None, 10, depth=0 + old_pdu_1 = new_fake_pdu( + "A", "test", "mem", "x", None, "u1", depth=0 ) - old_pdu_2 = new_fake_pdu_entry( - "B", "test", "mem", "x", "A", 10, depth=1 + old_pdu_2 = new_fake_pdu( + "B", "test", "mem", "x", "A", "u2", depth=1 ) - new_pdu = new_fake_pdu_entry( - "C", "test", "mem", "x", "A", 20, depth=2 + new_pdu = new_fake_pdu( + "C", "test", "mem", "x", "A", "u3", depth=2 ) + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 20, + }) + # The return_value of `get_unresolved_state_tree`, which changes after # the call to get_pdu tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)] @@ -268,20 +343,27 @@ class StateTestCase(unittest.TestCase): # triggering a get_pdu request # The pdu we haven't seen - old_pdu_1 = new_fake_pdu_entry( - "A", "test", "mem", "x", None, 10, depth=0 + old_pdu_1 = new_fake_pdu( + "A", "test", "mem", "x", None, "u1", depth=0 ) - old_pdu_2 = new_fake_pdu_entry( - "B", "test", "mem", "x", "A", 10, depth=2 + old_pdu_2 = new_fake_pdu( + "B", "test", "mem", "x", "A", "u2", depth=2 ) - old_pdu_3 = new_fake_pdu_entry( - "C", "test", "mem", "x", "B", 10, depth=3 + old_pdu_3 = new_fake_pdu( + "C", "test", "mem", "x", "B", "u3", depth=3 ) - new_pdu = new_fake_pdu_entry( - "D", "test", "mem", "x", "A", 20, depth=4 + new_pdu = new_fake_pdu( + "D", "test", "mem", "x", "A", "u4", depth=4 ) + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 10, + "u4": 20, + }) + # The return_value of `get_unresolved_state_tree`, which changes after # the call to get_pdu tree_to_return = [ @@ -357,20 +439,27 @@ class StateTestCase(unittest.TestCase): # triggering a get_pdu request # The pdu we haven't seen - old_pdu_1 = new_fake_pdu_entry( - "A", "test", "mem", "x", None, 10, depth=0 + old_pdu_1 = new_fake_pdu( + "A", "test", "mem", "x", None, "u1", depth=0 ) - old_pdu_2 = new_fake_pdu_entry( - "B", "test", "mem", "x", "A", 10, depth=2 + old_pdu_2 = new_fake_pdu( + "B", "test", "mem", "x", "A", "u2", depth=2 ) - old_pdu_3 = new_fake_pdu_entry( - "C", "test", "mem", "x", "B", 10, depth=3 + old_pdu_3 = new_fake_pdu( + "C", "test", "mem", "x", "B", "u3", depth=3 ) - new_pdu = new_fake_pdu_entry( - "D", "test", "mem", "x", "A", 20, depth=1 + new_pdu = new_fake_pdu( + "D", "test", "mem", "x", "A", "u4", depth=1 ) + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 10, + "u2": 10, + "u3": 10, + "u4": 20, + }) + # The return_value of `get_unresolved_state_tree`, which changes after # the call to get_pdu tree_to_return = [ @@ -445,8 +534,13 @@ class StateTestCase(unittest.TestCase): # We do a direct overwriting of the old state, i.e., the new state # points to the old state. - old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 5) - new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) + old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1") + new_pdu = new_fake_pdu("B", "test", "mem", "x", None, "u2") + + self.persistence.get_power_level.side_effect = _gen_get_power_level({ + "u1": 5, + "u2": 10, + }) self.persistence.get_unresolved_state_tree.return_value = ( (ReturnType([new_pdu], [old_pdu]), None) @@ -469,7 +563,7 @@ class StateTestCase(unittest.TestCase): event = Mock() event.event_id = "12123123@test" - state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20) + state_pdu = new_fake_pdu("C", "test", "mem", "x", "A", 20) snapshot = Mock() snapshot.prev_state_pdu = state_pdu @@ -496,13 +590,13 @@ class StateTestCase(unittest.TestCase): ) -def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, - power_level, depth=0): - new_pdu = PduEntry( +def new_fake_pdu(pdu_id, context, pdu_type, state_key, prev_state_id, + user_id, depth=0): + new_pdu = Pdu( pdu_id=pdu_id, pdu_type=pdu_type, state_key=state_key, - power_level=power_level, + user_id=user_id, prev_state_id=prev_state_id, origin="example.com", context="context", @@ -514,6 +608,7 @@ def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, is_state=True, prev_state_origin="example.com", have_processed=True, + content={}, ) return new_pdu -- cgit 1.5.1 From cd62ee3f29456d96d336f4c67cbd37a0a95f7b4a Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Fri, 12 Sep 2014 18:24:53 +0100 Subject: Have all unit tests import from our own subclass of trial's unittest TestCase; set up logging in ONE PLACE ONLY --- tests/api/test_ratelimiting.py | 2 +- tests/events/test_events.py | 2 +- tests/federation/test_federation.py | 6 +----- tests/federation/test_pdu_codec.py | 2 +- tests/handlers/test_directory.py | 6 +----- tests/handlers/test_federation.py | 6 +----- tests/handlers/test_presence.py | 7 +------ tests/handlers/test_presencelike.py | 6 +----- tests/handlers/test_profile.py | 6 +----- tests/handlers/test_room.py | 6 +----- tests/handlers/test_typing.py | 6 +----- tests/rest/test_events.py | 4 +--- tests/rest/test_presence.py | 6 +----- tests/rest/test_profile.py | 3 ++- tests/rest/utils.py | 2 +- tests/storage/test_base.py | 2 +- tests/test_distributor.py | 2 +- tests/test_state.py | 3 +-- tests/test_types.py | 2 +- tests/unittest.py | 30 ++++++++++++++++++++++++++++++ tests/util/test_lock.py | 4 ++-- 21 files changed, 52 insertions(+), 61 deletions(-) create mode 100644 tests/unittest.py (limited to 'tests') diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dc2f83c7eb..dd0bc19ecf 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,6 +1,6 @@ from synapse.api.ratelimiting import Ratelimiter -import unittest +from tests import unittest class TestRatelimiter(unittest.TestCase): diff --git a/tests/events/test_events.py b/tests/events/test_events.py index 93d5c15c6f..a4b6cb3afd 100644 --- a/tests/events/test_events.py +++ b/tests/events/test_events.py @@ -15,7 +15,7 @@ from synapse.api.events import SynapseEvent -import unittest +from tests import unittest class SynapseTemplateCheckTestCase(unittest.TestCase): diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py index 0b105fe723..954ccac2a4 100644 --- a/tests/federation/test_federation.py +++ b/tests/federation/test_federation.py @@ -14,11 +14,10 @@ # trial imports from twisted.internet import defer -from twisted.trial import unittest +from tests import unittest # python imports from mock import Mock -import logging from ..utils import MockHttpResource, MockClock @@ -28,9 +27,6 @@ from synapse.federation.units import Pdu from synapse.storage.pdu import PduTuple, PduEntry -logging.getLogger().addHandler(logging.NullHandler()) - - def make_pdu(prev_pdus=[], **kwargs): """Provide some default fields for making a PduTuple.""" pdu_fields = { diff --git a/tests/federation/test_pdu_codec.py b/tests/federation/test_pdu_codec.py index 9f74ba119f..344e1baf60 100644 --- a/tests/federation/test_pdu_codec.py +++ b/tests/federation/test_pdu_codec.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.trial import unittest +from tests import unittest from synapse.federation.pdu_codec import ( PduCodec, encode_event_id, decode_event_id diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 72a2b1443a..54d6e51f97 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -14,11 +14,10 @@ # limitations under the License. -from twisted.trial import unittest +from tests import unittest from twisted.internet import defer from mock import Mock -import logging from synapse.server import HomeServer from synapse.http.client import HttpClient @@ -26,9 +25,6 @@ from synapse.handlers.directory import DirectoryHandler from synapse.storage.directory import RoomAliasMapping -logging.getLogger().addHandler(logging.NullHandler()) - - class DirectoryHandlers(object): def __init__(self, hs): self.directory_handler = DirectoryHandler(hs) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 6fc3d8f7fd..f0308a29d3 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -14,7 +14,7 @@ from twisted.internet import defer -from twisted.trial import unittest +from tests import unittest from synapse.api.events.room import ( InviteJoinEvent, MessageEvent, RoomMemberEvent @@ -26,12 +26,8 @@ from synapse.federation.units import Pdu from mock import NonCallableMock, ANY -import logging - from ..utils import get_mock_call_args -logging.getLogger().addHandler(logging.NullHandler()) - class FederationTestCase(unittest.TestCase): diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 9eb8b6909f..06f5f9c2ba 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -14,11 +14,10 @@ # limitations under the License. -from twisted.trial import unittest +from tests import unittest from twisted.internet import defer, reactor from mock import Mock, call, ANY -import logging import json from ..utils import MockHttpResource, MockClock, DeferredMockCallable @@ -34,9 +33,6 @@ UNAVAILABLE = PresenceState.UNAVAILABLE ONLINE = PresenceState.ONLINE -logging.getLogger().addHandler(logging.NullHandler()) - - def _expect_edu(destination, edu_type, content, origin="test"): return { "origin": origin, @@ -92,7 +88,6 @@ class PresenceStateTestCase(unittest.TestCase): # Mock the RoomMemberHandler room_member_handler = Mock(spec=[]) hs.handlers.room_member_handler = room_member_handler - logging.getLogger().debug("Mocking room_member_handler=%r", room_member_handler) # Some local users to test with self.u_apple = hs.parse_userid("@apple:test") diff --git a/tests/handlers/test_presencelike.py b/tests/handlers/test_presencelike.py index b35980d948..72c55b3667 100644 --- a/tests/handlers/test_presencelike.py +++ b/tests/handlers/test_presencelike.py @@ -16,11 +16,10 @@ """This file contains tests of the "presence-like" data that is shared between presence and profiles; namely, the displayname and avatar_url.""" -from twisted.trial import unittest +from tests import unittest from twisted.internet import defer from mock import Mock, call, ANY -import logging from ..utils import MockClock @@ -35,9 +34,6 @@ UNAVAILABLE = PresenceState.UNAVAILABLE ONLINE = PresenceState.ONLINE -logging.getLogger().addHandler(logging.NullHandler()) - - class MockReplication(object): def __init__(self): self.edu_handlers = {} diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 8e7a89b479..0a5cebb4cc 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,20 +14,16 @@ # limitations under the License. -from twisted.trial import unittest +from tests import unittest from twisted.internet import defer from mock import Mock -import logging from synapse.api.errors import AuthError from synapse.server import HomeServer from synapse.handlers.profile import ProfileHandler -logging.getLogger().addHandler(logging.NullHandler()) - - class ProfileHandlers(object): def __init__(self, hs): self.profile_handler = ProfileHandler(hs) diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index 5687bbea0b..a1a2e80492 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from twisted.trial import unittest +from tests import unittest from synapse.api.events.room import ( InviteJoinEvent, RoomMemberEvent, RoomConfigEvent @@ -27,10 +27,6 @@ from synapse.server import HomeServer from mock import Mock, NonCallableMock -import logging - -logging.getLogger().addHandler(logging.NullHandler()) - class RoomMemberHandlerTestCase(unittest.TestCase): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 6532ac94a3..ab908cdfc1 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -14,12 +14,11 @@ # limitations under the License. -from twisted.trial import unittest +from tests import unittest from twisted.internet import defer from mock import Mock, call, ANY import json -import logging from ..utils import MockHttpResource, MockClock, DeferredMockCallable @@ -27,9 +26,6 @@ from synapse.server import HomeServer from synapse.handlers.typing import TypingNotificationHandler -logging.getLogger().addHandler(logging.NullHandler()) - - def _expect_edu(destination, edu_type, content, origin="test"): return { "origin": origin, diff --git a/tests/rest/test_events.py b/tests/rest/test_events.py index fd2224f55f..79b371c04d 100644 --- a/tests/rest/test_events.py +++ b/tests/rest/test_events.py @@ -14,7 +14,7 @@ # limitations under the License. """ Tests REST events for /events paths.""" -from twisted.trial import unittest +from tests import unittest # twisted imports from twisted.internet import defer @@ -27,14 +27,12 @@ from synapse.server import HomeServer # python imports import json -import logging from ..utils import MockHttpResource, MemoryDataStore from .utils import RestTestCase from mock import Mock, NonCallableMock -logging.getLogger().addHandler(logging.NullHandler()) PATH_PREFIX = "/_matrix/client/api/v1" diff --git a/tests/rest/test_presence.py b/tests/rest/test_presence.py index a1db0fbcf3..ea3478ac5d 100644 --- a/tests/rest/test_presence.py +++ b/tests/rest/test_presence.py @@ -15,11 +15,10 @@ """Tests REST events for /presence paths.""" -from twisted.trial import unittest +from tests import unittest from twisted.internet import defer from mock import Mock -import logging from ..utils import MockHttpResource @@ -28,9 +27,6 @@ from synapse.handlers.presence import PresenceHandler from synapse.server import HomeServer -logging.getLogger().addHandler(logging.NullHandler()) - - OFFLINE = PresenceState.OFFLINE UNAVAILABLE = PresenceState.UNAVAILABLE ONLINE = PresenceState.ONLINE diff --git a/tests/rest/test_profile.py b/tests/rest/test_profile.py index f41810df1f..e6e51f6dd0 100644 --- a/tests/rest/test_profile.py +++ b/tests/rest/test_profile.py @@ -15,7 +15,7 @@ """Tests REST events for /profile paths.""" -from twisted.trial import unittest +from tests import unittest from twisted.internet import defer from mock import Mock @@ -28,6 +28,7 @@ from synapse.server import HomeServer myid = "@1234ABCD:test" PATH_PREFIX = "/_matrix/client/api/v1" + class ProfileTestCase(unittest.TestCase): """ Tests profile management. """ diff --git a/tests/rest/utils.py b/tests/rest/utils.py index 77f5ecf0df..ce2e8fd98a 100644 --- a/tests/rest/utils.py +++ b/tests/rest/utils.py @@ -17,7 +17,7 @@ from twisted.internet import defer # trial imports -from twisted.trial import unittest +from tests import unittest from synapse.api.constants import Membership diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 330311448d..3ad9a4b0c0 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -14,7 +14,7 @@ # limitations under the License. -from twisted.trial import unittest +from tests import unittest from twisted.internet import defer from mock import Mock, call diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 04933f0ecf..39c5b8dff2 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from tests import unittest from twisted.internet import defer -from twisted.trial import unittest from mock import Mock, patch diff --git a/tests/test_state.py b/tests/test_state.py index 16af95b7bc..b1624f0b25 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from tests import unittest from twisted.internet import defer -from twisted.trial import unittest from twisted.python.log import PythonLoggingObserver from synapse.state import StateHandler @@ -26,7 +26,6 @@ from collections import namedtuple from mock import Mock -import logging import mock diff --git a/tests/test_types.py b/tests/test_types.py index 571938356c..276ecc91fd 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from tests import unittest from synapse.server import BaseHomeServer from synapse.types import UserID, RoomAlias diff --git a/tests/unittest.py b/tests/unittest.py new file mode 100644 index 0000000000..00c3c532eb --- /dev/null +++ b/tests/unittest.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.trial import unittest + +import logging + + +# logging doesn't have a "don't log anything at all EVARRRR setting, +# but since the highest value is 50, 1000000 should do ;) +NEVER = 1000000 + +logging.getLogger().addHandler(logging.StreamHandler()) +logging.getLogger().setLevel(NEVER) + + +class TestCase(unittest.TestCase): + pass diff --git a/tests/util/test_lock.py b/tests/util/test_lock.py index 5623d78423..6a1e521b1e 100644 --- a/tests/util/test_lock.py +++ b/tests/util/test_lock.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from twisted.trial import unittest +from tests import unittest from synapse.util.lockutils import LockManager @@ -105,4 +105,4 @@ class LockManagerTestCase(unittest.TestCase): pass with (yield self.lock_manager.lock(key)): - pass \ No newline at end of file + pass -- cgit 1.5.1 From ca8349a897c233d72ea74128dabdd1311f00c13c Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Fri, 12 Sep 2014 18:29:07 +0100 Subject: Allow a TestCase to set a 'loglevel' attribute, which overrides the logging level while that testcase runs --- tests/unittest.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index 00c3c532eb..19be03b96a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -27,4 +27,25 @@ logging.getLogger().setLevel(NEVER) class TestCase(unittest.TestCase): - pass + def __init__(self, *args, **kwargs): + super(TestCase, self).__init__(*args, **kwargs) + + level = getattr(self, "loglevel", NEVER) + + orig_setUp = self.setUp + + def setUp(): + old_level = logging.getLogger().level + + if old_level != level: + orig_tearDown = self.tearDown + + def tearDown(): + ret = orig_tearDown() + logging.getLogger().setLevel(old_level) + return ret + self.tearDown = tearDown + + logging.getLogger().setLevel(level) + return orig_setUp() + self.setUp = setUp -- cgit 1.5.1 From 33c4dd4c2ddcd81854855e84a838db9603bbe338 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Fri, 12 Sep 2014 18:38:11 +0100 Subject: Define a (class) decorator for easily setting a DEBUG logging level on a TestCase --- tests/unittest.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index 19be03b96a..c66a3b8407 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -49,3 +49,8 @@ class TestCase(unittest.TestCase): logging.getLogger().setLevel(level) return orig_setUp() self.setUp = setUp + + +def DEBUG(target): + target.loglevel = logging.DEBUG + return target -- cgit 1.5.1 From d9f3f322c5f17a4c3d3ac000462a7bbd0a407711 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Fri, 12 Sep 2014 18:43:49 +0100 Subject: Additionally look first for a 'loglevel' attribute on the running test method, before the TestCase --- tests/unittest.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index c66a3b8407..8ae724c786 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -27,10 +27,14 @@ logging.getLogger().setLevel(NEVER) class TestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestCase, self).__init__(*args, **kwargs) + def __init__(self, methodName, *args, **kwargs): + super(TestCase, self).__init__(methodName, *args, **kwargs) - level = getattr(self, "loglevel", NEVER) + method = getattr(self, methodName) + + level = getattr(method, "loglevel", + getattr(self, "loglevel", + NEVER)) orig_setUp = self.setUp -- cgit 1.5.1 From aeb69c0f8cc6e723316aefbc6b71c82b4ed94aad Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Fri, 12 Sep 2014 18:45:48 +0100 Subject: Add some docstrings --- tests/unittest.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index 8ae724c786..e437d3541a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -27,6 +27,10 @@ logging.getLogger().setLevel(NEVER) class TestCase(unittest.TestCase): + """A subclass of twisted.trial's TestCase which looks for 'loglevel' + attributes on both itself and its individual test methods, to override the + root logger's logging level while that test (case|method) runs.""" + def __init__(self, methodName, *args, **kwargs): super(TestCase, self).__init__(methodName, *args, **kwargs) @@ -56,5 +60,7 @@ class TestCase(unittest.TestCase): def DEBUG(target): + """A decorator to set the .loglevel attribute to logging.DEBUG. + Can apply to either a TestCase or an individual test method.""" target.loglevel = logging.DEBUG return target -- cgit 1.5.1 From 7a77aabb4bbb997db9dadd46e49d855946c1ae2e Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Fri, 12 Sep 2014 19:07:29 +0100 Subject: Define a CLOS-like 'around' modifier as a decorator, to neaten up the 'orig_*' noise of wrapping the setUp()/tearDown() methods --- tests/unittest.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index e437d3541a..fb97fb1148 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -26,6 +26,23 @@ logging.getLogger().addHandler(logging.StreamHandler()) logging.getLogger().setLevel(NEVER) +def around(target): + """A CLOS-style 'around' modifier, which wraps the original method of the + given instance with another piece of code. + + @around(self) + def method_name(orig, *args, **kwargs): + return orig(*args, **kwargs) + """ + def _around(code): + name = code.__name__ + orig = getattr(target, name) + def new(*args, **kwargs): + return code(orig, *args, **kwargs) + setattr(target, name, new) + return _around + + class TestCase(unittest.TestCase): """A subclass of twisted.trial's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the @@ -40,23 +57,19 @@ class TestCase(unittest.TestCase): getattr(self, "loglevel", NEVER)) - orig_setUp = self.setUp - - def setUp(): + @around(self) + def setUp(orig): old_level = logging.getLogger().level if old_level != level: - orig_tearDown = self.tearDown - - def tearDown(): - ret = orig_tearDown() + @around(self) + def tearDown(orig): + ret = orig() logging.getLogger().setLevel(old_level) return ret - self.tearDown = tearDown logging.getLogger().setLevel(level) - return orig_setUp() - self.setUp = setUp + return orig() def DEBUG(target): -- cgit 1.5.1