diff options
-rwxr-xr-x | synapse/app/homeserver.py | 10 | ||||
-rw-r--r-- | synapse/config/server.py | 6 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 6 | ||||
-rw-r--r-- | synapse/rest/media/v1/base_resource.py | 3 | ||||
-rw-r--r-- | synapse/state.py | 18 | ||||
-rw-r--r-- | synapse/storage/_base.py | 45 | ||||
-rw-r--r-- | synapse/storage/state.py | 20 | ||||
-rw-r--r-- | tests/handlers/test_federation.py | 4 | ||||
-rw-r--r-- | tests/storage/test__base.py | 84 | ||||
-rw-r--r-- | tests/storage/test_registration.py | 2 |
10 files changed, 131 insertions, 67 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index d93afdc1c2..65a5dfa84e 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -87,10 +87,16 @@ class SynapseHomeServer(HomeServer): return MatrixFederationHttpClient(self) def build_resource_for_client(self): - return gz_wrap(ClientV1RestResource(self)) + res = ClientV1RestResource(self) + if self.config.gzip_responses: + res = gz_wrap(res) + return res def build_resource_for_client_v2_alpha(self): - return gz_wrap(ClientV2AlphaRestResource(self)) + res = ClientV2AlphaRestResource(self) + if self.config.gzip_responses: + res = gz_wrap(res) + return res def build_resource_for_federation(self): return JsonResource(self) diff --git a/synapse/config/server.py b/synapse/config/server.py index 48a26c65d9..d0c8fb8f3c 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -29,6 +29,7 @@ class ServerConfig(Config): self.soft_file_limit = config["soft_file_limit"] self.daemonize = config.get("daemonize") self.use_frozen_dicts = config.get("use_frozen_dicts", True) + self.gzip_responses = config["gzip_responses"] # Attempt to guess the content_addr for the v0 content repostitory content_addr = config.get("content_addr") @@ -86,6 +87,11 @@ class ServerConfig(Config): # Turn on the twisted telnet manhole service on localhost on the given # port. #manhole: 9000 + + # Should synapse compress HTTP responses to clients that support it? + # This should be disabled if running synapse behind a load balancer + # that can do automatic compression. + gzip_responses: True """ % locals() def read_arguments(self, args): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 46ce3699d7..5503d9ae86 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -900,8 +900,10 @@ class FederationHandler(BaseHandler): event.event_id, event.signatures, ) + outlier = event.internal_metadata.is_outlier() + context = yield self.state_handler.compute_event_context( - event, old_state=state + event, old_state=state, outlier=outlier, ) if not auth_events: @@ -912,7 +914,7 @@ class FederationHandler(BaseHandler): event.event_id, auth_events, ) - is_new_state = not event.internal_metadata.is_outlier() + is_new_state = not outlier # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py index 4af5f73878..9a60b777ca 100644 --- a/synapse/rest/media/v1/base_resource.py +++ b/synapse/rest/media/v1/base_resource.py @@ -15,6 +15,7 @@ from .thumbnailer import Thumbnailer +from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.server import respond_with_json from synapse.util.stringutils import random_string from synapse.api.errors import ( @@ -52,7 +53,7 @@ class BaseMediaResource(Resource): def __init__(self, hs, filepaths): Resource.__init__(self) self.auth = hs.get_auth() - self.client = hs.get_http_client() + self.client = MatrixFederationHttpClient(hs) self.clock = hs.get_clock() self.server_name = hs.hostname self.store = hs.get_datastore() diff --git a/synapse/state.py b/synapse/state.py index d4d8930001..80da90a72c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -106,7 +106,7 @@ class StateHandler(object): defer.returnValue(state) @defer.inlineCallbacks - def compute_event_context(self, event, old_state=None): + def compute_event_context(self, event, old_state=None, outlier=False): """ Fills out the context with the `current state` of the graph. The `current state` here is defined to be the state of the event graph just before the event - i.e. it never includes `event` @@ -119,9 +119,23 @@ class StateHandler(object): Returns: an EventContext """ + yield run_on_reactor() + context = EventContext() - yield run_on_reactor() + if outlier: + # If this is an outlier, then we know it shouldn't have any current + # state. Certainly store.get_current_state won't return any, and + # persisting the event won't store the state group. + if old_state: + context.current_state = { + (s.type, s.state_key): s for s in old_state + } + else: + context.current_state = {} + context.prev_state_events = [] + context.state_group = None + defer.returnValue(context) if old_state: context.current_state = { diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 39884c2afe..8d33def6c6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -127,7 +127,7 @@ class Cache(object): self.cache.clear() -def cached(max_entries=1000, num_args=1, lru=False): +class CacheDescriptor(object): """ A method decorator that applies a memoizing cache around the function. The function is presumed to take zero or more arguments, which are used in @@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False): which can be used to insert values into the cache specifically, without calling the calculation function. """ - def wrap(orig): + def __init__(self, orig, max_entries=1000, num_args=1, lru=False): + self.orig = orig + + self.max_entries = max_entries + self.num_args = num_args + self.lru = lru + + def __get__(self, obj, objtype=None): cache = Cache( - name=orig.__name__, - max_entries=max_entries, - keylen=num_args, - lru=lru, + name=self.orig.__name__, + max_entries=self.max_entries, + keylen=self.num_args, + lru=self.lru, ) - @functools.wraps(orig) + @functools.wraps(self.orig) @defer.inlineCallbacks - def wrapped(self, *keyargs): + def wrapped(*keyargs): try: - cached_result = cache.get(*keyargs) + cached_result = cache.get(*keyargs[:self.num_args]) if DEBUG_CACHES: - actual_result = yield orig(self, *keyargs) + actual_result = yield self.orig(obj, *keyargs) if actual_result != cached_result: logger.error( "Stale cache entry %s%r: cached: %r, actual %r", - orig.__name__, keyargs, + self.orig.__name__, keyargs, cached_result, actual_result, ) raise ValueError("Stale cache entry") @@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False): # while the SELECT is executing (SYN-369) sequence = cache.sequence - ret = yield orig(self, *keyargs) + ret = yield self.orig(obj, *keyargs) - cache.update(sequence, *keyargs + (ret,)) + cache.update(sequence, *keyargs[:self.num_args] + (ret,)) defer.returnValue(ret) wrapped.invalidate = cache.invalidate wrapped.invalidate_all = cache.invalidate_all wrapped.prefill = cache.prefill + + obj.__dict__[self.orig.__name__] = wrapped + return wrapped - return wrap + +def cached(max_entries=1000, num_args=1, lru=False): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru + ) class LoggingTransaction(object): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b24de34f23..f2b17f29ea 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -81,19 +81,23 @@ class StateStore(SQLBaseStore): f, ) - @defer.inlineCallbacks - def c(vals): - vals[:] = yield self._get_events(vals, get_prev_content=False) - - yield defer.gatherResults( + state_list = yield defer.gatherResults( [ - c(vals) - for vals in states.values() + self._fetch_events_for_group(group, vals) + for group, vals in states.items() ], consumeErrors=True, ) - defer.returnValue(states) + defer.returnValue(dict(state_list)) + + @cached(num_args=1) + def _fetch_events_for_group(self, state_group, events): + return self._get_events( + events, get_prev_content=False + ).addCallback( + lambda evs: (state_group, evs) + ) def _store_state_groups_txn(self, txn, event, context): if context.current_state is None: diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index f3821242bc..d392c23015 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -100,7 +100,7 @@ class FederationTestCase(unittest.TestCase): return defer.succeed({}) self.datastore.have_events.side_effect = have_events - def annotate(ev, old_state=None): + def annotate(ev, old_state=None, outlier=False): context = Mock() context.current_state = {} context.auth_events = {} @@ -120,7 +120,7 @@ class FederationTestCase(unittest.TestCase): ) self.state_handler.compute_event_context.assert_called_once_with( - ANY, old_state=None, + ANY, old_state=None, outlier=False ) self.auth.check.assert_called_once_with(ANY, auth_events={}) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 96caf8c4c1..8c3d2952bd 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -96,73 +96,84 @@ class CacheDecoratorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_passthrough(self): - @cached() - def func(self, key): - return key + class A(object): + @cached() + def func(self, key): + return key - self.assertEquals((yield func(self, "foo")), "foo") - self.assertEquals((yield func(self, "bar")), "bar") + a = A() + + self.assertEquals((yield a.func("foo")), "foo") + self.assertEquals((yield a.func("bar")), "bar") @defer.inlineCallbacks def test_hit(self): callcount = [0] - @cached() - def func(self, key): - callcount[0] += 1 - return key + class A(object): + @cached() + def func(self, key): + callcount[0] += 1 + return key - yield func(self, "foo") + a = A() + yield a.func("foo") self.assertEquals(callcount[0], 1) - self.assertEquals((yield func(self, "foo")), "foo") + self.assertEquals((yield a.func("foo")), "foo") self.assertEquals(callcount[0], 1) @defer.inlineCallbacks def test_invalidate(self): callcount = [0] - @cached() - def func(self, key): - callcount[0] += 1 - return key + class A(object): + @cached() + def func(self, key): + callcount[0] += 1 + return key - yield func(self, "foo") + a = A() + yield a.func("foo") self.assertEquals(callcount[0], 1) - func.invalidate("foo") + a.func.invalidate("foo") - yield func(self, "foo") + yield a.func("foo") self.assertEquals(callcount[0], 2) def test_invalidate_missing(self): - @cached() - def func(self, key): - return key + class A(object): + @cached() + def func(self, key): + return key - func.invalidate("what") + A().func.invalidate("what") @defer.inlineCallbacks def test_max_entries(self): callcount = [0] - @cached(max_entries=10) - def func(self, key): - callcount[0] += 1 - return key + class A(object): + @cached(max_entries=10) + def func(self, key): + callcount[0] += 1 + return key - for k in range(0,12): - yield func(self, k) + a = A() + + for k in range(0, 12): + yield a.func(k) self.assertEquals(callcount[0], 12) # There must have been at least 2 evictions, meaning if we calculate # all 12 values again, we must get called at least 2 more times for k in range(0,12): - yield func(self, k) + yield a.func(k) self.assertTrue(callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])) @@ -171,12 +182,15 @@ class CacheDecoratorTestCase(unittest.TestCase): def test_prefill(self): callcount = [0] - @cached() - def func(self, key): - callcount[0] += 1 - return key + class A(object): + @cached() + def func(self, key): + callcount[0] += 1 + return key + + a = A() - func.prefill("foo", 123) + a.func.prefill("foo", 123) - self.assertEquals((yield func(self, "foo")), 123) + self.assertEquals((yield a.func("foo")), 123) self.assertEquals(callcount[0], 0) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 78f6004204..2702291178 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase): (yield self.store.get_user_by_id(self.user_id)) ) - result = yield self.store.get_user_by_token(self.tokens[1]) + result = yield self.store.get_user_by_token(self.tokens[0]) self.assertDictContainsSubset( { |