diff --git a/CHANGES.rst b/CHANGES.rst
index 1ca2407a73..ac398728f8 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,26 @@
+Changes in synapse v0.9.2 (2015-06-12)
+======================================
+
+General:
+
+* Use ultrajson for json (de)serialisation when a canonical encoding is not
+ required. Ultrajson is significantly faster than simplejson in certain
+ circumstances.
+* Use connection pools for outgoing HTTP connections.
+* Process thumbnails on separate threads.
+
+Configuration:
+
+* Add option, ``gzip_responses``, to disable HTTP response compression.
+
+Federation:
+
+* Improve resilience of backfill by ensuring we fetch any missing auth events.
+* Improve performance of backfill and joining remote rooms by removing
+ unnecessary computations. This included handling events we'd previously
+ handled as well as attempting to compute the current state for outliers.
+
+
Changes in synapse v0.9.1 (2015-05-26)
======================================
diff --git a/contrib/systemd/log_config.yaml b/contrib/systemd/log_config.yaml
index e16fb5456a..d85bdd1208 100644
--- a/contrib/systemd/log_config.yaml
+++ b/contrib/systemd/log_config.yaml
@@ -21,3 +21,5 @@ handlers:
root:
level: INFO
handlers: [journal]
+
+disable_existing_loggers: False
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 4720d99848..e6088dc6cc 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.9.1"
+__version__ = "0.9.2"
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index f3513abb55..65a5dfa84e 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -54,6 +54,8 @@ from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse import events
+
from daemonize import Daemonize
import twisted.manhole.telnet
@@ -85,10 +87,16 @@ class SynapseHomeServer(HomeServer):
return MatrixFederationHttpClient(self)
def build_resource_for_client(self):
- return gz_wrap(ClientV1RestResource(self))
+ res = ClientV1RestResource(self)
+ if self.config.gzip_responses:
+ res = gz_wrap(res)
+ return res
def build_resource_for_client_v2_alpha(self):
- return gz_wrap(ClientV2AlphaRestResource(self))
+ res = ClientV2AlphaRestResource(self)
+ if self.config.gzip_responses:
+ res = gz_wrap(res)
+ return res
def build_resource_for_federation(self):
return JsonResource(self)
@@ -415,6 +423,8 @@ def setup(config_options):
logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string)
+ events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name
else:
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index d8fe577e34..ba221121cb 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -26,6 +26,7 @@ class CaptchaConfig(Config):
config["captcha_ip_origin_is_x_forwarded"]
)
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
+ self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
def default_config(self, config_dir_path, server_name):
return """\
@@ -48,4 +49,7 @@ class CaptchaConfig(Config):
# A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE"
+
+ # The API endpoint to use for verifying m.login.recaptcha responses.
+ recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
"""
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index b39989a87f..67e780864e 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -39,7 +39,7 @@ class RegistrationConfig(Config):
## Registration ##
# Enable registration for new users.
- enable_registration: True
+ enable_registration: False
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 78195b3a4f..d0c8fb8f3c 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -28,6 +28,8 @@ class ServerConfig(Config):
self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
+ self.use_frozen_dicts = config.get("use_frozen_dicts", True)
+ self.gzip_responses = config["gzip_responses"]
# Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr")
@@ -85,6 +87,11 @@ class ServerConfig(Config):
# Turn on the twisted telnet manhole service on localhost on the given
# port.
#manhole: 9000
+
+ # Should synapse compress HTTP responses to clients that support it?
+ # This should be disabled if running synapse behind a load balancer
+ # that can do automatic compression.
+ gzip_responses: True
""" % locals()
def read_arguments(self, args):
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index e4495ccf12..39ce4f7c42 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -16,6 +16,12 @@
from synapse.util.frozenutils import freeze
+# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
+# bugs where we accidentally share e.g. signature dicts. However, converting
+# a dict to frozen_dicts is expensive.
+USE_FROZEN_DICTS = True
+
+
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = dict(internal_metadata_dict)
@@ -122,7 +128,10 @@ class FrozenEvent(EventBase):
unsigned = dict(event_dict.pop("unsigned", {}))
- frozen_dict = freeze(event_dict)
+ if USE_FROZEN_DICTS:
+ frozen_dict = freeze(event_dict)
+ else:
+ frozen_dict = event_dict
super(FrozenEvent, self).__init__(
frozen_dict,
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index f0430b2cb1..299493af91 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -18,8 +18,6 @@ from twisted.internet import defer
from synapse.events.utils import prune_event
-from syutil.jsonutil import encode_canonical_json
-
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
@@ -120,16 +118,15 @@ class FederationBase(object):
)
except SynapseError:
logger.warn(
- "Signature check failed for %s redacted to %s",
- encode_canonical_json(pdu.get_pdu_json()),
- encode_canonical_json(redacted_pdu_json),
+ "Signature check failed for %s",
+ pdu.event_id,
)
raise
if not check_event_content_hash(pdu):
logger.warn(
- "Event content has been tampered, redacting %s, %s",
- pdu.event_id, encode_canonical_json(pdu.get_dict())
+ "Event content has been tampered, redacting.",
+ pdu.event_id,
)
defer.returnValue(redacted_event)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index af87805f34..31190e700a 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -93,6 +93,8 @@ class TransportLayerServer(object):
yield self.keyring.verify_json_for_server(origin, json_request)
+ logger.info("Request from %s", origin)
+
defer.returnValue((origin, content))
@log_function
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 833ff41377..d6c064b398 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -78,7 +78,9 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
- builder.prev_state = context.prev_state_events
+ builder.prev_state = yield self.store.add_event_hashes(
+ context.prev_state_events
+ )
yield self.auth.add_auth_events(builder, context)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4e2e50345e..63071653a3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -187,8 +187,8 @@ class AuthHandler(BaseHandler):
# each request
try:
client = SimpleHttpClient(self.hs)
- data = yield client.post_urlencoded_get_json(
- "https://www.google.com/recaptcha/api/siteverify",
+ resp_body = yield client.post_urlencoded_get_json(
+ self.hs.config.recaptcha_siteverify_api,
args={
'secret': self.hs.config.recaptcha_private_key,
'response': user_response,
@@ -198,7 +198,8 @@ class AuthHandler(BaseHandler):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
- resp_body = simplejson.loads(data)
+ resp_body = simplejson.loads(data)
+
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 46ce3699d7..b5d882fd65 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -247,9 +247,15 @@ class FederationHandler(BaseHandler):
if set(e_id for e_id, _ in ev.prev_events) - event_ids
]
+ logger.info(
+ "backfill: Got %d events with %d edges",
+ len(events), len(edges),
+ )
+
# For each edge get the current state.
auth_events = {}
+ state_events = {}
events_to_state = {}
for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room(
@@ -258,12 +264,46 @@ class FederationHandler(BaseHandler):
event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
+ auth_events.update({s.event_id: s for s in state})
+ state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
+ seen_events = yield self.store.have_events(
+ set(auth_events.keys()) | set(state_events.keys())
+ )
+
+ all_events = events + state_events.values() + auth_events.values()
+ required_auth = set(
+ a_id for event in all_events for a_id, _ in event.auth_events
+ )
+
+ missing_auth = required_auth - set(auth_events)
+ results = yield defer.gatherResults(
+ [
+ self.replication_layer.get_pdu(
+ [dest],
+ event_id,
+ outlier=True,
+ timeout=10000,
+ )
+ for event_id in missing_auth
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
+ auth_events.update({a.event_id: a for a in results})
+
yield defer.gatherResults(
[
- self._handle_new_event(dest, a)
+ self._handle_new_event(
+ dest, a,
+ auth_events={
+ (auth_events[a_id].type, auth_events[a_id].state_key):
+ auth_events[a_id]
+ for a_id, _ in a.auth_events
+ },
+ )
for a in auth_events.values()
+ if a.event_id not in seen_events
],
consumeErrors=True,
).addErrback(unwrapFirstError)
@@ -274,6 +314,11 @@ class FederationHandler(BaseHandler):
dest, event_map[e_id],
state=events_to_state[e_id],
backfilled=True,
+ auth_events={
+ (auth_events[a_id].type, auth_events[a_id].state_key):
+ auth_events[a_id]
+ for a_id, _ in event_map[e_id].auth_events
+ },
)
for e_id in events_to_state
],
@@ -900,8 +945,10 @@ class FederationHandler(BaseHandler):
event.event_id, event.signatures,
)
+ outlier = event.internal_metadata.is_outlier()
+
context = yield self.state_handler.compute_event_context(
- event, old_state=state
+ event, old_state=state, outlier=outlier,
)
if not auth_events:
@@ -912,7 +959,7 @@ class FederationHandler(BaseHandler):
event.event_id, auth_events,
)
- is_new_state = not event.internal_metadata.is_outlier()
+ is_new_state = not outlier
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 5b3cefb2dc..e746f2416e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -20,7 +20,8 @@ import synapse.metrics
from twisted.internet import defer, reactor
from twisted.web.client import (
- Agent, readBody, FileBodyProducer, PartialDownloadError
+ Agent, readBody, FileBodyProducer, PartialDownloadError,
+ HTTPConnectionPool,
)
from twisted.web.http_headers import Headers
@@ -55,7 +56,9 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
- self.agent = Agent(reactor)
+ pool = HTTPConnectionPool(reactor)
+ pool.maxPersistentPerHost = 10
+ self.agent = Agent(reactor, pool=pool)
self.version_string = hs.version_string
def request(self, method, *args, **kwargs):
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 6f976d5ce8..7f3d8fc884 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -16,7 +16,7 @@
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
-from twisted.web.client import readBody, _AgentBase, _URI
+from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
@@ -103,7 +103,9 @@ class MatrixFederationHttpClient(object):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
- self.agent = MatrixFederationHttpAgent(reactor)
+ pool = HTTPConnectionPool(reactor)
+ pool.maxPersistentPerHost = 10
+ self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
self.clock = hs.get_clock()
self.version_string = hs.version_string
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 73efbff4f2..ae8f3b3972 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -19,9 +19,10 @@ from synapse.api.errors import (
)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
import synapse.metrics
+import synapse.events
from syutil.jsonutil import (
- encode_canonical_json, encode_pretty_printed_json
+ encode_canonical_json, encode_pretty_printed_json, encode_json
)
from twisted.internet import defer
@@ -168,9 +169,10 @@ class JsonResource(HttpServer, resource.Resource):
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
- def __init__(self, hs):
+ def __init__(self, hs, canonical_json=True):
resource.Resource.__init__(self)
+ self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
self.version_string = hs.version_string
@@ -256,6 +258,7 @@ class JsonResource(HttpServer, resource.Resource):
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
+ canonical_json=self.canonical_json,
)
@@ -277,11 +280,16 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
- version_string=""):
+ version_string="", canonical_json=True):
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
- json_bytes = encode_canonical_json(json_object)
+ if canonical_json:
+ json_bytes = encode_canonical_json(json_object)
+ else:
+ json_bytes = encode_json(
+ json_object, using_frozen_dicts=synapse.events.USE_FROZEN_DICTS
+ )
return respond_with_json_bytes(
request, code, json_bytes,
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 8059fff1b2..36f450c31d 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -24,6 +24,7 @@ import baserules
import logging
import simplejson as json
import re
+import random
logger = logging.getLogger(__name__)
@@ -256,134 +257,154 @@ class Pusher(object):
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
+ wait = 0
while self.alive:
- from_tok = StreamToken.from_string(self.last_token)
- config = PaginationConfig(from_token=from_tok, limit='1')
- chunk = yield self.evStreamHandler.get_stream(
- self.user_name, config,
- timeout=100*365*24*60*60*1000, affect_presence=False
- )
+ try:
+ if wait > 0:
+ yield synapse.util.async.sleep(wait)
+ yield self.get_and_dispatch()
+ wait = 0
+ except:
+ if wait == 0:
+ wait = 1
+ else:
+ wait = min(wait * 2, 1800)
+ logger.exception(
+ "Exception in pusher loop for pushkey %s. Pausing for %ds",
+ self.pushkey, wait
+ )
- # limiting to 1 may get 1 event plus 1 presence event, so
- # pick out the actual event
- single_event = None
- for c in chunk['chunk']:
- if 'event_id' in c: # Hmmm...
- single_event = c
- break
- if not single_event:
- self.last_token = chunk['end']
- continue
+ @defer.inlineCallbacks
+ def get_and_dispatch(self):
+ from_tok = StreamToken.from_string(self.last_token)
+ config = PaginationConfig(from_token=from_tok, limit='1')
+ timeout = (300 + random.randint(-60, 60)) * 1000
+ chunk = yield self.evStreamHandler.get_stream(
+ self.user_name, config,
+ timeout=timeout, affect_presence=False
+ )
- if not self.alive:
- continue
+ # limiting to 1 may get 1 event plus 1 presence event, so
+ # pick out the actual event
+ single_event = None
+ for c in chunk['chunk']:
+ if 'event_id' in c: # Hmmm...
+ single_event = c
+ break
+ if not single_event:
+ self.last_token = chunk['end']
+ logger.debug("Event stream timeout for pushkey %s", self.pushkey)
+ return
- processed = False
- actions = yield self._actions_for_event(single_event)
- tweaks = _tweaks_for_actions(actions)
+ if not self.alive:
+ return
- if len(actions) == 0:
- logger.warn("Empty actions! Using default action.")
- actions = Pusher.DEFAULT_ACTIONS
+ processed = False
+ actions = yield self._actions_for_event(single_event)
+ tweaks = _tweaks_for_actions(actions)
- if 'notify' not in actions and 'dont_notify' not in actions:
- logger.warn("Neither notify nor dont_notify in actions: adding default")
- actions.extend(Pusher.DEFAULT_ACTIONS)
+ if len(actions) == 0:
+ logger.warn("Empty actions! Using default action.")
+ actions = Pusher.DEFAULT_ACTIONS
- if 'dont_notify' in actions:
- logger.debug(
- "%s for %s: dont_notify",
- single_event['event_id'], self.user_name
- )
+ if 'notify' not in actions and 'dont_notify' not in actions:
+ logger.warn("Neither notify nor dont_notify in actions: adding default")
+ actions.extend(Pusher.DEFAULT_ACTIONS)
+
+ if 'dont_notify' in actions:
+ logger.debug(
+ "%s for %s: dont_notify",
+ single_event['event_id'], self.user_name
+ )
+ processed = True
+ else:
+ rejected = yield self.dispatch_push(single_event, tweaks)
+ self.has_unread = True
+ if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
- else:
- rejected = yield self.dispatch_push(single_event, tweaks)
- self.has_unread = True
- if isinstance(rejected, list) or isinstance(rejected, tuple):
- processed = True
- for pk in rejected:
- if pk != self.pushkey:
- # for sanity, we only remove the pushkey if it
- # was the one we actually sent...
- logger.warn(
- ("Ignoring rejected pushkey %s because we"
- " didn't send it"), pk
- )
- else:
- logger.info(
- "Pushkey %s was rejected: removing",
- pk
- )
- yield self.hs.get_pusherpool().remove_pusher(
- self.app_id, pk, self.user_name
- )
-
- if not self.alive:
- continue
+ for pk in rejected:
+ if pk != self.pushkey:
+ # for sanity, we only remove the pushkey if it
+ # was the one we actually sent...
+ logger.warn(
+ ("Ignoring rejected pushkey %s because we"
+ " didn't send it"), pk
+ )
+ else:
+ logger.info(
+ "Pushkey %s was rejected: removing",
+ pk
+ )
+ yield self.hs.get_pusherpool().remove_pusher(
+ self.app_id, pk, self.user_name
+ )
+
+ if not self.alive:
+ return
+
+ if processed:
+ self.backoff_delay = Pusher.INITIAL_BACKOFF
+ self.last_token = chunk['end']
+ self.store.update_pusher_last_token_and_success(
+ self.app_id,
+ self.pushkey,
+ self.user_name,
+ self.last_token,
+ self.clock.time_msec()
+ )
+ if self.failing_since:
+ self.failing_since = None
+ self.store.update_pusher_failing_since(
+ self.app_id,
+ self.pushkey,
+ self.user_name,
+ self.failing_since)
+ else:
+ if not self.failing_since:
+ self.failing_since = self.clock.time_msec()
+ self.store.update_pusher_failing_since(
+ self.app_id,
+ self.pushkey,
+ self.user_name,
+ self.failing_since
+ )
- if processed:
+ if (self.failing_since and
+ self.failing_since <
+ self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
+ # we really only give up so that if the URL gets
+ # fixed, we don't suddenly deliver a load
+ # of old notifications.
+ logger.warn("Giving up on a notification to user %s, "
+ "pushkey %s",
+ self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
- self.store.update_pusher_last_token_and_success(
+ self.store.update_pusher_last_token(
+ self.app_id,
+ self.pushkey,
+ self.user_name,
+ self.last_token
+ )
+
+ self.failing_since = None
+ self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
- self.last_token,
- self.clock.time_msec()
+ self.failing_since
)
- if self.failing_since:
- self.failing_since = None
- self.store.update_pusher_failing_since(
- self.app_id,
- self.pushkey,
- self.user_name,
- self.failing_since)
else:
- if not self.failing_since:
- self.failing_since = self.clock.time_msec()
- self.store.update_pusher_failing_since(
- self.app_id,
- self.pushkey,
- self.user_name,
- self.failing_since
- )
-
- if (self.failing_since and
- self.failing_since <
- self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
- # we really only give up so that if the URL gets
- # fixed, we don't suddenly deliver a load
- # of old notifications.
- logger.warn("Giving up on a notification to user %s, "
- "pushkey %s",
- self.user_name, self.pushkey)
- self.backoff_delay = Pusher.INITIAL_BACKOFF
- self.last_token = chunk['end']
- self.store.update_pusher_last_token(
- self.app_id,
- self.pushkey,
- self.user_name,
- self.last_token
- )
-
- self.failing_since = None
- self.store.update_pusher_failing_since(
- self.app_id,
- self.pushkey,
- self.user_name,
- self.failing_since
- )
- else:
- logger.warn("Failed to dispatch push for user %s "
- "(failing for %dms)."
- "Trying again in %dms",
- self.user_name,
- self.clock.time_msec() - self.failing_since,
- self.backoff_delay)
- yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
- self.backoff_delay *= 2
- if self.backoff_delay > Pusher.MAX_BACKOFF:
- self.backoff_delay = Pusher.MAX_BACKOFF
+ logger.warn("Failed to dispatch push for user %s "
+ "(failing for %dms)."
+ "Trying again in %dms",
+ self.user_name,
+ self.clock.time_msec() - self.failing_since,
+ self.backoff_delay)
+ yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
+ self.backoff_delay *= 2
+ if self.backoff_delay > Pusher.MAX_BACKOFF:
+ self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index a45dd3c93d..f9e59dd917 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -18,7 +18,7 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
- "syutil>=0.0.6": ["syutil>=0.0.6"],
+ "syutil>=0.0.7": ["syutil>=0.0.7"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
@@ -30,6 +30,7 @@ REQUIREMENTS = {
"frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
+ "ujson": ["ujson"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
@@ -52,8 +53,8 @@ def github_link(project, version, egg):
DEPENDENCY_LINKS = [
github_link(
project="matrix-org/syutil",
- version="v0.0.6",
- egg="syutil-0.0.6",
+ version="v0.0.7",
+ egg="syutil-0.0.7",
),
github_link(
project="matrix-org/matrix-angular-sdk",
diff --git a/synapse/rest/client/v1/__init__.py b/synapse/rest/client/v1/__init__.py
index 21876b3487..cc9b49d539 100644
--- a/synapse/rest/client/v1/__init__.py
+++ b/synapse/rest/client/v1/__init__.py
@@ -25,7 +25,7 @@ class ClientV1RestResource(JsonResource):
"""A resource for version 1 of the matrix client API."""
def __init__(self, hs):
- JsonResource.__init__(self, hs)
+ JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index 28d95b2729..7d1aff4307 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -28,7 +28,7 @@ class ClientV2AlphaRestResource(JsonResource):
"""A resource for version 2 alpha of the matrix client API."""
def __init__(self, hs):
- JsonResource.__init__(self, hs)
+ JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
index 4af5f73878..6c83a9478c 100644
--- a/synapse/rest/media/v1/base_resource.py
+++ b/synapse/rest/media/v1/base_resource.py
@@ -15,13 +15,14 @@
from .thumbnailer import Thumbnailer
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_error, Codes, SynapseError
)
-from twisted.internet import defer
+from twisted.internet import defer, threads
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
@@ -52,7 +53,7 @@ class BaseMediaResource(Resource):
def __init__(self, hs, filepaths):
Resource.__init__(self)
self.auth = hs.get_auth()
- self.client = hs.get_http_client()
+ self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -273,57 +274,65 @@ class BaseMediaResource(Resource):
if not requirements:
return
+ remote_thumbnails = []
+
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
- if m_width * m_height >= self.max_image_pixels:
- logger.info(
- "Image too large to thumbnail %r x %r > %r",
- m_width, m_height, self.max_image_pixels
- )
- return
-
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
- ))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
+ def generate_thumbnails():
+ if m_width * m_height >= self.max_image_pixels:
+ logger.info(
+ "Image too large to thumbnail %r x %r > %r",
+ m_width, m_height, self.max_image_pixels
+ )
+ return
+
+ scales = set()
+ crops = set()
+ for r_width, r_height, r_method, r_type in requirements:
+ if r_method == "scale":
+ t_width, t_height = thumbnailer.aspect(r_width, r_height)
+ scales.add((
+ min(m_width, t_width), min(m_height, t_height), r_type,
+ ))
+ elif r_method == "crop":
+ crops.add((r_width, r_height, r_type))
+
+ for t_width, t_height, t_type in scales:
+ t_method = "scale"
+ t_path = self.filepaths.remote_media_thumbnail(
+ server_name, file_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+ t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+ remote_thumbnails.append([
+ server_name, media_id, file_id,
+ t_width, t_height, t_type, t_method, t_len
+ ])
+
+ for t_width, t_height, t_type in crops:
+ if (t_width, t_height, t_type) in scales:
+ # If the aspect ratio of the cropped thumbnail matches a purely
+ # scaled one then there is no point in calculating a separate
+ # thumbnail.
+ continue
+ t_method = "crop"
+ t_path = self.filepaths.remote_media_thumbnail(
+ server_name, file_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+ t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+ remote_thumbnails.append([
+ server_name, media_id, file_id,
+ t_width, t_height, t_type, t_method, t_len
+ ])
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
- yield self.store.store_remote_media_thumbnail(
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
- )
+ yield threads.deferToThread(generate_thumbnails)
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- yield self.store.store_remote_media_thumbnail(
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
- )
+ for r in remote_thumbnails:
+ yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({
"width": m_width,
diff --git a/synapse/state.py b/synapse/state.py
index 9dddb77d5b..80da90a72c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -106,7 +106,7 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
- def compute_event_context(self, event, old_state=None):
+ def compute_event_context(self, event, old_state=None, outlier=False):
""" Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event`
@@ -119,9 +119,23 @@ class StateHandler(object):
Returns:
an EventContext
"""
+ yield run_on_reactor()
+
context = EventContext()
- yield run_on_reactor()
+ if outlier:
+ # If this is an outlier, then we know it shouldn't have any current
+ # state. Certainly store.get_current_state won't return any, and
+ # persisting the event won't store the state group.
+ if old_state:
+ context.current_state = {
+ (s.type, s.state_key): s for s in old_state
+ }
+ else:
+ context.current_state = {}
+ context.prev_state_events = []
+ context.state_group = None
+ defer.returnValue(context)
if old_state:
context.current_state = {
@@ -155,10 +169,6 @@ class StateHandler(object):
context.current_state = curr_state
context.state_group = group if not event.is_state() else None
- prev_state = yield self.store.add_event_hashes(
- prev_state
- )
-
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 75af44d787..c137f47820 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 19
+SCHEMA_VERSION = 20
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -348,7 +348,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
- module.run_upgrade(cur)
+ module.run_upgrade(cur, database_engine)
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 39884c2afe..8d33def6c6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -127,7 +127,7 @@ class Cache(object):
self.cache.clear()
-def cached(max_entries=1000, num_args=1, lru=False):
+class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in
@@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
- def wrap(orig):
+ def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+ self.orig = orig
+
+ self.max_entries = max_entries
+ self.num_args = num_args
+ self.lru = lru
+
+ def __get__(self, obj, objtype=None):
cache = Cache(
- name=orig.__name__,
- max_entries=max_entries,
- keylen=num_args,
- lru=lru,
+ name=self.orig.__name__,
+ max_entries=self.max_entries,
+ keylen=self.num_args,
+ lru=self.lru,
)
- @functools.wraps(orig)
+ @functools.wraps(self.orig)
@defer.inlineCallbacks
- def wrapped(self, *keyargs):
+ def wrapped(*keyargs):
try:
- cached_result = cache.get(*keyargs)
+ cached_result = cache.get(*keyargs[:self.num_args])
if DEBUG_CACHES:
- actual_result = yield orig(self, *keyargs)
+ actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
- orig.__name__, keyargs,
+ self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
@@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False):
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
- ret = yield orig(self, *keyargs)
+ ret = yield self.orig(obj, *keyargs)
- cache.update(sequence, *keyargs + (ret,))
+ cache.update(sequence, *keyargs[:self.num_args] + (ret,))
defer.returnValue(ret)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
+
+ obj.__dict__[self.orig.__name__] = wrapped
+
return wrapped
- return wrap
+
+def cached(max_entries=1000, num_args=1, lru=False):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru
+ )
class LoggingTransaction(object):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index d2a010bd88..2caf0aae80 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -17,7 +17,7 @@ from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor
-from synapse.events import FrozenEvent
+from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred
@@ -26,11 +26,11 @@ from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
-from syutil.jsonutil import encode_canonical_json
+from syutil.jsonutil import encode_json
from contextlib import contextmanager
import logging
-import simplejson as json
+import ujson as json
logger = logging.getLogger(__name__)
@@ -166,8 +166,9 @@ class EventsStore(SQLBaseStore):
allow_none=True,
)
- metadata_json = encode_canonical_json(
- event.internal_metadata.get_dict()
+ metadata_json = encode_json(
+ event.internal_metadata.get_dict(),
+ using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
# If we have already persisted this event, we don't need to do any
@@ -235,12 +236,14 @@ class EventsStore(SQLBaseStore):
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json,
- "json": encode_canonical_json(event_dict).decode("UTF-8"),
+ "json": encode_json(
+ event_dict, using_frozen_dicts=USE_FROZEN_DICTS
+ ).decode("UTF-8"),
},
)
- content = encode_canonical_json(
- event.content
+ content = encode_json(
+ event.content, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
vals = {
@@ -266,8 +269,8 @@ class EventsStore(SQLBaseStore):
]
}
- vals["unrecognized_keys"] = encode_canonical_json(
- unrec
+ vals["unrecognized_keys"] = encode_json(
+ unrec, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
sql = (
@@ -733,7 +736,8 @@ class EventsStore(SQLBaseStore):
because = yield self.get_event(
redaction_id,
- check_redacted=False
+ check_redacted=False,
+ allow_none=True,
)
if because:
@@ -743,6 +747,7 @@ class EventsStore(SQLBaseStore):
prev = yield self.get_event(
ev.unsigned["replaces_state"],
get_prev_content=False,
+ allow_none=True,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 9f3a4dd4c5..61232f9757 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
-def run_upgrade(cur):
+def run_upgrade(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py
new file mode 100644
index 0000000000..543e57bbe2
--- /dev/null
+++ b/synapse/storage/schema/delta/20/pushers.py
@@ -0,0 +1,76 @@
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Main purpose of this upgrade is to change the unique key on the
+pushers table again (it was missed when the v16 full schema was
+made) but this also changes the pushkey and data columns to text.
+When selecting a bytea column into a text column, postgres inserts
+the hex encoded data, and there's no portable way of getting the
+UTF-8 bytes, so we have to do it in Python.
+"""
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ logger.info("Porting pushers table...")
+ cur.execute("""
+ CREATE TABLE IF NOT EXISTS pushers2 (
+ id BIGINT PRIMARY KEY,
+ user_name TEXT NOT NULL,
+ access_token BIGINT DEFAULT NULL,
+ profile_tag VARCHAR(32) NOT NULL,
+ kind VARCHAR(8) NOT NULL,
+ app_id VARCHAR(64) NOT NULL,
+ app_display_name VARCHAR(64) NOT NULL,
+ device_display_name VARCHAR(128) NOT NULL,
+ pushkey TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ lang VARCHAR(8),
+ data TEXT,
+ last_token TEXT,
+ last_success BIGINT,
+ failing_since BIGINT,
+ UNIQUE (app_id, pushkey, user_name)
+ )
+ """)
+ cur.execute("""SELECT
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ FROM pushers
+ """)
+ count = 0
+ for row in cur.fetchall():
+ row = list(row)
+ row[8] = bytes(row[8]).decode("utf-8")
+ row[11] = bytes(row[11]).decode("utf-8")
+ cur.execute(database_engine.convert_param_style("""
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
+ row
+ )
+ count += 1
+ cur.execute("DROP TABLE pushers")
+ cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
+ logger.info("Moved %d pushers to new table", count)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b24de34f23..f2b17f29ea 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -81,19 +81,23 @@ class StateStore(SQLBaseStore):
f,
)
- @defer.inlineCallbacks
- def c(vals):
- vals[:] = yield self._get_events(vals, get_prev_content=False)
-
- yield defer.gatherResults(
+ state_list = yield defer.gatherResults(
[
- c(vals)
- for vals in states.values()
+ self._fetch_events_for_group(group, vals)
+ for group, vals in states.items()
],
consumeErrors=True,
)
- defer.returnValue(states)
+ defer.returnValue(dict(state_list))
+
+ @cached(num_args=1)
+ def _fetch_events_for_group(self, state_group, events):
+ return self._get_events(
+ events, get_prev_content=False
+ ).addCallback(
+ lambda evs: (state_group, evs)
+ )
def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None:
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 0765f7d217..00f86ed220 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
-
class JsonEncodedObject(object):
""" A common base class for defining protocol units that are represented
@@ -76,15 +74,7 @@ class JsonEncodedObject(object):
if k in self.valid_keys and k not in self.internal_keys
}
d.update(self.unrecognized_keys)
- return copy.deepcopy(d)
-
- def get_full_dict(self):
- d = {
- k: _encode(v) for (k, v) in self.__dict__.items()
- if k in self.valid_keys or k in self.internal_keys
- }
- d.update(self.unrecognized_keys)
- return copy.deepcopy(d)
+ return d
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index f3821242bc..d392c23015 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -100,7 +100,7 @@ class FederationTestCase(unittest.TestCase):
return defer.succeed({})
self.datastore.have_events.side_effect = have_events
- def annotate(ev, old_state=None):
+ def annotate(ev, old_state=None, outlier=False):
context = Mock()
context.current_state = {}
context.auth_events = {}
@@ -120,7 +120,7 @@ class FederationTestCase(unittest.TestCase):
)
self.state_handler.compute_event_context.assert_called_once_with(
- ANY, old_state=None,
+ ANY, old_state=None, outlier=False
)
self.auth.check.assert_called_once_with(ANY, auth_events={})
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index a2d7635995..2a7553f982 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -42,6 +42,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room",
"store_room",
"get_latest_events_in_room",
+ "add_event_hashes",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
@@ -88,6 +89,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1)
+ self.datastore.add_event_hashes.return_value = []
@defer.inlineCallbacks
def test_invite(self):
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96caf8c4c1..8c3d2952bd 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -96,73 +96,84 @@ class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_passthrough(self):
- @cached()
- def func(self, key):
- return key
+ class A(object):
+ @cached()
+ def func(self, key):
+ return key
- self.assertEquals((yield func(self, "foo")), "foo")
- self.assertEquals((yield func(self, "bar")), "bar")
+ a = A()
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
- yield func(self, "foo")
+ a = A()
+ yield a.func("foo")
self.assertEquals(callcount[0], 1)
- self.assertEquals((yield func(self, "foo")), "foo")
+ self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
- yield func(self, "foo")
+ a = A()
+ yield a.func("foo")
self.assertEquals(callcount[0], 1)
- func.invalidate("foo")
+ a.func.invalidate("foo")
- yield func(self, "foo")
+ yield a.func("foo")
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
- @cached()
- def func(self, key):
- return key
+ class A(object):
+ @cached()
+ def func(self, key):
+ return key
- func.invalidate("what")
+ A().func.invalidate("what")
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
- @cached(max_entries=10)
- def func(self, key):
- callcount[0] += 1
- return key
+ class A(object):
+ @cached(max_entries=10)
+ def func(self, key):
+ callcount[0] += 1
+ return key
- for k in range(0,12):
- yield func(self, k)
+ a = A()
+
+ for k in range(0, 12):
+ yield a.func(k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0,12):
- yield func(self, k)
+ yield a.func(k)
self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0]))
@@ -171,12 +182,15 @@ class CacheDecoratorTestCase(unittest.TestCase):
def test_prefill(self):
callcount = [0]
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
- func.prefill("foo", 123)
+ a.func.prefill("foo", 123)
- self.assertEquals((yield func(self, "foo")), 123)
+ self.assertEquals((yield a.func("foo")), 123)
self.assertEquals(callcount[0], 0)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 78f6004204..2702291178 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id))
)
- result = yield self.store.get_user_by_token(self.tokens[1])
+ result = yield self.store.get_user_by_token(self.tokens[0])
self.assertDictContainsSubset(
{
|