From a4d62ba36afc54d4e60f1371fe9b31e8b8e6834c Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 28 Jul 2015 17:34:12 +0100 Subject: Fix v2_alpha registration. Add unit tests. V2 Registration forced everyone (including ASes) to create a password for a user, when ASes should be able to omit passwords. Also unbreak AS registration in general which checked too early if the given username was claimed by an AS; it was checked before knowing if the AS was the one doing the registration! Add unit tests for AS reg, user reg and disabled_registration flag. --- tests/rest/client/v2_alpha/test_register.py | 132 ++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 tests/rest/client/v2_alpha/test_register.py (limited to 'tests') diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py new file mode 100644 index 0000000000..3edc2ec2e9 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_register.py @@ -0,0 +1,132 @@ +from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.api.errors import SynapseError +from twisted.internet import defer +from mock import Mock, MagicMock +from tests import unittest +import json + + +class RegisterRestServletTestCase(unittest.TestCase): + + def setUp(self): + # do the dance to hook up request data to self.request_data + self.request_data = "" + self.request = Mock( + content=Mock(read=Mock(side_effect=lambda: self.request_data)), + ) + self.request.args = {} + + self.appservice = None + self.auth = Mock(get_appservice_by_req=Mock( + side_effect=lambda x: defer.succeed(self.appservice)) + ) + + self.auth_result = (False, None, None) + self.auth_handler = Mock( + check_auth=Mock(side_effect=lambda x,y,z: self.auth_result) + ) + self.registration_handler = Mock() + self.identity_handler = Mock() + self.login_handler = Mock() + + # do the dance to hook it up to the hs global + self.handlers = Mock( + auth_handler=self.auth_handler, + registration_handler=self.registration_handler, + identity_handler=self.identity_handler, + login_handler=self.login_handler + ) + self.hs = Mock() + self.hs.hostname = "superbig~testing~thing.com" + self.hs.get_auth = Mock(return_value=self.auth) + self.hs.get_handlers = Mock(return_value=self.handlers) + self.hs.config.disable_registration = False + + # init the thing we're testing + self.servlet = RegisterRestServlet(self.hs) + + @defer.inlineCallbacks + def test_POST_appservice_registration_valid(self): + user_id = "@kermit:muppet" + token = "kermits_access_token" + self.request.args = { + "access_token": "i_am_an_app_service" + } + self.request_data = json.dumps({ + "username": "kermit" + }) + self.appservice = { + "id": "1234" + } + self.registration_handler.register = Mock(return_value=(user_id, token)) + result = yield self.servlet.on_POST(self.request) + self.assertEquals(result, (200, { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname + })) + + @defer.inlineCallbacks + def test_POST_appservice_registration_invalid(self): + self.request.args = { + "access_token": "i_am_an_app_service" + } + self.request_data = json.dumps({ + "username": "kermit" + }) + self.appservice = None # no application service exists + result = yield self.servlet.on_POST(self.request) + self.assertEquals(result, (401, None)) + + def test_POST_bad_password(self): + self.request_data = json.dumps({ + "username": "kermit", + "password": 666 + }) + d = self.servlet.on_POST(self.request) + return self.assertFailure(d, SynapseError) + + def test_POST_bad_username(self): + self.request_data = json.dumps({ + "username": 777, + "password": "monkey" + }) + d = self.servlet.on_POST(self.request) + return self.assertFailure(d, SynapseError) + + @defer.inlineCallbacks + def test_POST_user_valid(self): + user_id = "@kermit:muppet" + token = "kermits_access_token" + self.request_data = json.dumps({ + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.check_username = Mock(return_value=True) + self.auth_result = (True, None, { + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.register = Mock(return_value=(user_id, token)) + + result = yield self.servlet.on_POST(self.request) + self.assertEquals(result, (200, { + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname + })) + + def test_POST_disabled_registration(self): + self.hs.config.disable_registration = True + self.request_data = json.dumps({ + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.check_username = Mock(return_value=True) + self.auth_result = (True, None, { + "username": "kermit", + "password": "monkey" + }) + self.registration_handler.register = Mock(return_value=("@user:id", "t")) + d = self.servlet.on_POST(self.request) + return self.assertFailure(d, SynapseError) \ No newline at end of file -- cgit 1.5.1 From 11b0a3407485e98082bf06d771e5ae2f68106ca7 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 29 Jul 2015 10:00:54 +0100 Subject: Use the same reg paths as register v1 for ASes. Namely this means using registration_handler.appservice_register. --- synapse/rest/client/v2_alpha/register.py | 10 ++++++---- tests/rest/client/v2_alpha/test_register.py | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e1c42dd51e..cf54e1dacf 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -82,7 +82,9 @@ class RegisterRestServlet(RestServlet): # == Application Service Registration == if appservice: - result = yield self._do_appservice_registration(desired_username) + result = yield self._do_appservice_registration( + desired_username, request.args["access_token"][0] + ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -166,9 +168,9 @@ class RegisterRestServlet(RestServlet): return 200, {} @defer.inlineCallbacks - def _do_appservice_registration(self, username): - (user_id, token) = yield self.registration_handler.register( - localpart=username + def _do_appservice_registration(self, username, as_token): + (user_id, token) = yield self.registration_handler.appservice_register( + username, as_token ) defer.returnValue(self._create_registration_details(user_id, token)) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 3edc2ec2e9..66fd25964d 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -58,7 +58,9 @@ class RegisterRestServletTestCase(unittest.TestCase): self.appservice = { "id": "1234" } - self.registration_handler.register = Mock(return_value=(user_id, token)) + self.registration_handler.appservice_register = Mock( + return_value=(user_id, token) + ) result = yield self.servlet.on_POST(self.request) self.assertEquals(result, (200, { "user_id": user_id, -- cgit 1.5.1 From 7eea3e356ff58168f3525879a8eb684f0681ee68 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 6 Aug 2015 13:33:34 +0100 Subject: Make @cached cache deferreds rather than the deferreds' values --- synapse/storage/_base.py | 21 ++++++++------------- synapse/util/async.py | 9 +++++++-- tests/storage/test__base.py | 11 +++++++---- 3 files changed, 22 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f1265541ba..8604d38c3e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,6 +15,7 @@ import logging from synapse.api.errors import StoreError +from synapse.util.async import ObservableDeferred from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.lrucache import LruCache @@ -173,33 +174,27 @@ class CacheDescriptor(object): ) @functools.wraps(self.orig) - @defer.inlineCallbacks def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] try: cached_result = cache.get(*keyargs) - if DEBUG_CACHES: - actual_result = yield self.function_to_call(obj, *args, **kwargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, keyargs, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) + return cached_result.observe() except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) sequence = cache.sequence - ret = yield self.function_to_call(obj, *args, **kwargs) + ret = defer.maybeDeferred( + self.function_to_call, + obj, *args, **kwargs + ) + ret = ObservableDeferred(ret, consumeErrors=False) cache.update(sequence, *(keyargs + [ret])) - defer.returnValue(ret) + return ret.observe() wrapped.invalidate = cache.invalidate wrapped.invalidate_all = cache.invalidate_all diff --git a/synapse/util/async.py b/synapse/util/async.py index 5a1d545c96..7bf2d38bb8 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -51,7 +51,7 @@ class ObservableDeferred(object): object.__setattr__(self, "_observers", set()) def callback(r): - self._result = (True, r) + object.__setattr__(self, "_result", (True, r)) while self._observers: try: self._observers.pop().callback(r) @@ -60,7 +60,7 @@ class ObservableDeferred(object): return r def errback(f): - self._result = (False, f) + object.__setattr__(self, "_result", (False, f)) while self._observers: try: self._observers.pop().errback(f) @@ -97,3 +97,8 @@ class ObservableDeferred(object): def __setattr__(self, name, value): setattr(self._deferred, name, value) + + def __repr__(self): + return "" % ( + id(self), self._result, self._deferred, + ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8c3d2952bd..8fa305d18a 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -17,6 +17,8 @@ from tests import unittest from twisted.internet import defer +from synapse.util.async import ObservableDeferred + from synapse.storage._base import Cache, cached @@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertTrue(callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])) - @defer.inlineCallbacks def test_prefill(self): callcount = [0] + d = defer.succeed(123) + class A(object): @cached() def func(self, key): callcount[0] += 1 - return key + return d a = A() - a.func.prefill("foo", 123) + a.func.prefill("foo", ObservableDeferred(d)) - self.assertEquals((yield a.func("foo")), 123) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) -- cgit 1.5.1 From 20addfa358e4f72b9f0c25d48c1e3ecfc08a68b2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 7 Aug 2015 11:52:21 +0100 Subject: Change Cache to not use *args in its interface --- synapse/storage/__init__.py | 4 +-- synapse/storage/_base.py | 61 +++++++++++++++++++------------------ synapse/storage/directory.py | 4 +-- synapse/storage/event_federation.py | 4 +-- synapse/storage/events.py | 21 +++++++------ synapse/storage/keys.py | 2 +- synapse/storage/presence.py | 4 +-- synapse/storage/push_rule.py | 16 +++++----- synapse/storage/registration.py | 2 +- synapse/storage/roommember.py | 6 ++-- tests/storage/test__base.py | 12 ++++---- 11 files changed, 69 insertions(+), 67 deletions(-) (limited to 'tests') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 71d5d92500..1a6a8a3762 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore, key = (user.to_string(), access_token, device_id, ip) try: - last_seen = self.client_ip_last_seen.get(*key) + last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None @@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore, if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: defer.returnValue(None) - self.client_ip_last_seen.prefill(*key + (now,)) + self.client_ip_last_seen.prefill(key, now) # It's safe not to lock here: a) no unique constraint, # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 0872a438f1..e07cf3b58a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,6 +57,9 @@ cache_counter = metrics.register_cache( ) +_CacheSentinel = object() + + class Cache(object): def __init__(self, name, max_entries=1000, keylen=1, lru=True): @@ -83,41 +86,38 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, *keyargs): - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) - - if keyargs in self.cache: + def get(self, keyargs, default=_CacheSentinel): + val = self.cache.get(keyargs, _CacheSentinel) + if val is not _CacheSentinel: cache_counter.inc_hits(self.name) - return self.cache[keyargs] + return val cache_counter.inc_misses(self.name) - raise KeyError() - def update(self, sequence, *args): + if default is _CacheSentinel: + raise KeyError() + else: + return default + + def update(self, sequence, keyargs, value): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(*args) - - def prefill(self, *args): # because I can't *keyargs, value - keyargs = args[:-1] - value = args[-1] - - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + self.prefill(keyargs, value) + def prefill(self, keyargs, value): if self.max_entries is not None: while len(self.cache) >= self.max_entries: self.cache.popitem(last=False) self.cache[keyargs] = value - def invalidate(self, *keyargs): + def invalidate(self, keyargs): self.check_thread() - if len(keyargs) != self.keylen: - raise ValueError("Expected a key to have %d items", self.keylen) + if not isinstance(keyargs, tuple): + raise ValueError("keyargs must be a tuple.") + # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 @@ -168,20 +168,21 @@ class CacheDescriptor(object): % (orig.__name__,) ) - def __get__(self, obj, objtype=None): - cache = Cache( + self.cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, lru=self.lru, ) + def __get__(self, obj, objtype=None): + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + keyargs = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = cache.get(*keyargs) + cached_result_d = self.cache.get(keyargs) if DEBUG_CACHES: @@ -202,7 +203,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = cache.sequence + sequence = self.cache.sequence ret = defer.maybeDeferred( self.function_to_call, @@ -210,19 +211,19 @@ class CacheDescriptor(object): ) def onErr(f): - cache.invalidate(*keyargs) + self.cache.invalidate(keyargs) return f ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=False) - cache.update(sequence, *(keyargs + [ret])) + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, keyargs, ret) return ret.observe() - wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all - wrapped.prefill = cache.prefill + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill obj.__dict__[self.orig.__name__] = wrapped diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 2b2bdf8615..f3947bbe89 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore): }, desc="create_room_alias_association", ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) @defer.inlineCallbacks def delete_room_alias(self, room_alias): @@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore): room_alias, ) - self.get_aliases_for_room.invalidate(room_id) + self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 45b86c94e8..910b6598a7 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore): for room_id in events_by_room: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, room_id + self.get_latest_event_ids_in_room.invalidate, (room_id,) ) def get_backfill_events(self, room_id, event_list, limit): @@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore): query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ed7ea38804..5b64918024 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore): if current_state: txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases, event.room_id) self._simple_delete_txn( @@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore): if not context.rejected: txn.call_after( self.get_current_state_for_key.invalidate, - event.room_id, event.type, event.state_key - ) + (event.room_id, event.type, event.state_key,) + ) if event.type in [EventTypes.Name, EventTypes.Aliases]: txn.call_after( self.get_room_name_and_aliases.invalidate, - event.room_id + (event.room_id,) ) self._simple_upsert_txn( @@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore): def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True): for get_prev_content in (False, True): - self._get_event_cache.invalidate(event_id, check_redacted, - get_prev_content) + self._get_event_cache.invalidate( + (event_id, check_redacted, get_prev_content) + ) def _get_event_txn(self, txn, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False): @@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore): for event_id in events: try: ret = self._get_event_cache.get( - event_id, check_redacted, get_prev_content + (event_id, check_redacted, get_prev_content,) ) if allow_rejected or not ret.rejected_reason: @@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) defer.returnValue(ev) @@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore): ev.unsigned["prev_content"] = prev.get_dict()["content"] self._get_event_cache.prefill( - ev.event_id, check_redacted, get_prev_content, ev + (ev.event_id, check_redacted, get_prev_content), ev ) return ev diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e3f98f0cde..49b8e37cfd 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) - self.get_all_server_verify_keys.invalidate(server_name) + self.get_all_server_verify_keys.invalidate((server_name,)) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index fefcf6bce0..576cf670cc 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore): updatevalues={"accepted": True}, desc="set_presence_list_accepted", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore): "observed_user_id": observed_userid}, desc="del_presence_list", ) - self.get_presence_list_accepted.invalidate(observer_localpart) + self.get_presence_list_accepted.invalidate((observer_localpart,)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index a220f3632e..9b88ca7b39 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -151,11 +151,11 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -187,10 +187,10 @@ class PushRuleStore(SQLBaseStore): new_rule['priority'] = new_prio txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) self._simple_insert_txn( @@ -216,8 +216,8 @@ class PushRuleStore(SQLBaseStore): desc="delete_push_rule", ) - self.get_push_rules_for_user.invalidate(user_name) - self.get_push_rules_enabled_for_user.invalidate(user_name) + self.get_push_rules_for_user.invalidate((user_name,)) + self.get_push_rules_enabled_for_user.invalidate((user_name,)) @defer.inlineCallbacks def set_push_rule_enabled(self, user_name, rule_id, enabled): @@ -238,10 +238,10 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) txn.call_after( - self.get_push_rules_for_user.invalidate, user_name + self.get_push_rules_for_user.invalidate, (user_name,) ) txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, user_name + self.get_push_rules_enabled_for_user.invalidate, (user_name,) ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 90e2606be2..4eaa088b36 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore): user_id ) for r in rows: - self.get_user_by_token.invalidate(r) + self.get_user_by_token.invalidate((r,)) @cached() def get_user_by_token(self, token): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 55dd3f6cfb..9f14f38f24 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore): ) for event in events: - txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) + txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) + txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) + txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8fa305d18a..abee2f631d 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -42,12 +42,12 @@ class CacheTestCase(unittest.TestCase): self.assertEquals(self.cache.get("foo"), 123) def test_invalidate(self): - self.cache.prefill("foo", 123) - self.cache.invalidate("foo") + self.cache.prefill(("foo",), 123) + self.cache.invalidate(("foo",)) failed = False try: - self.cache.get("foo") + self.cache.get(("foo",)) except KeyError: failed = True @@ -141,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase): self.assertEquals(callcount[0], 1) - a.func.invalidate("foo") + a.func.invalidate(("foo",)) yield a.func("foo") @@ -153,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase): def func(self, key): return key - A().func.invalidate("what") + A().func.invalidate(("what",)) @defer.inlineCallbacks def test_max_entries(self): @@ -193,7 +193,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a = A() - a.func.prefill("foo", ObservableDeferred(d)) + a.func.prefill(("foo",), ObservableDeferred(d)) self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) -- cgit 1.5.1 From 4ff0228c254e146e37ffd2b1c1df22d1b67353cb Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 11 Aug 2015 16:54:06 +0100 Subject: Remove call to recently removed function in mock --- tests/test_distributor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 6a0095d850..8ed48cfb70 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase): yield d self.assertTrue(d.called) - observers[0].assert_called_once("Go") - observers[1].assert_called_once("Go") + observers[0].assert_called_once_with("Go") + observers[1].assert_called_once_with("Go") self.assertEquals(mock_logger.warning.call_count, 1) self.assertIsInstance(mock_logger.warning.call_args[0][0], -- cgit 1.5.1