summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--README.rst4
-rw-r--r--docs/postgres.rst6
-rw-r--r--docs/turn-howto.rst38
-rwxr-xr-xscripts-dev/nuke-room-from-db.sh43
-rwxr-xr-xscripts/synapse_port_db4
-rwxr-xr-xsynapse/app/synctl.py3
-rw-r--r--synapse/appservice/__init__.py38
-rw-r--r--synapse/config/voip.py8
-rw-r--r--synapse/push/push_tools.py7
-rw-r--r--synapse/rest/client/v1/voip.py5
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py8
-rw-r--r--synapse/types.py4
-rw-r--r--synapse/util/async.py7
-rw-r--r--synapse/util/caches/descriptors.py42
-rw-r--r--synapse/visibility.py3
-rw-r--r--tests/appservice/test_appservice.py4
-rw-r--r--tests/storage/test__base.py2
-rw-r--r--tests/util/caches/test_descriptors.py38
-rw-r--r--tests/util/test_snapshot_cache.py4
19 files changed, 202 insertions, 66 deletions
diff --git a/README.rst b/README.rst
index b9c854ad48..759197f5ff 100644
--- a/README.rst
+++ b/README.rst
@@ -108,10 +108,10 @@ Installing prerequisites on ArchLinux::
     sudo pacman -S base-devel python2 python-pip \
                    python-setuptools python-virtualenv sqlite3
 
-Installing prerequisites on CentOS 7::
+Installing prerequisites on CentOS 7 or Fedora 25::
 
     sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
-                     lcms2-devel libwebp-devel tcl-devel tk-devel \
+                     lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \
                      python-virtualenv libffi-devel openssl-devel
     sudo yum groupinstall "Development Tools"
 
diff --git a/docs/postgres.rst b/docs/postgres.rst
index 402ff9a4de..b592801e93 100644
--- a/docs/postgres.rst
+++ b/docs/postgres.rst
@@ -112,9 +112,9 @@ script one last time, e.g. if the SQLite database is at  ``homeserver.db``
 run::
 
     synapse_port_db --sqlite-database homeserver.db \
-        --postgres-config database_config.yaml
+        --postgres-config homeserver-postgres.yaml
 
 Once that has completed, change the synapse config to point at the PostgreSQL
-database configuration file using the ``database_config`` parameter (see
-`Synapse Config`_) and restart synapse. Synapse should now be running against
+database configuration file ``homeserver-postgres.yaml`` (i.e. rename it to 
+``homeserver.yaml``) and restart synapse. Synapse should now be running against
 PostgreSQL.
diff --git a/docs/turn-howto.rst b/docs/turn-howto.rst
index 04c0100715..e48628ce6e 100644
--- a/docs/turn-howto.rst
+++ b/docs/turn-howto.rst
@@ -50,14 +50,37 @@ You may be able to setup coturn via your package manager,  or set it up manually
 
        pwgen -s 64 1
 
- 5. Ensure youe firewall allows traffic into the TURN server on
+ 5. Consider your security settings.  TURN lets users request a relay
+    which will connect to arbitrary IP addresses and ports.  At the least
+    we recommend:
+
+       # VoIP traffic is all UDP. There is no reason to let users connect to arbitrary TCP endpoints via the relay.
+       no-tcp-relay
+
+       # don't let the relay ever try to connect to private IP address ranges within your network (if any)
+       # given the turn server is likely behind your firewall, remember to include any privileged public IPs too.
+       denied-peer-ip=10.0.0.0-10.255.255.255
+       denied-peer-ip=192.168.0.0-192.168.255.255
+       denied-peer-ip=172.16.0.0-172.31.255.255
+
+       # special case the turn server itself so that client->TURN->TURN->client flows work
+       allowed-peer-ip=10.0.0.1
+
+       # consider whether you want to limit the quota of relayed streams per user (or total) to avoid risk of DoS.
+       user-quota=12 # 4 streams per video call, so 12 streams = 3 simultaneous relayed calls per user.
+       total-quota=1200
+
+    Ideally coturn should refuse to relay traffic which isn't SRTP;
+    see https://github.com/matrix-org/synapse/issues/2009
+
+ 6. Ensure your firewall allows traffic into the TURN server on
     the ports you've configured it to listen on (remember to allow
-    both TCP and UDP if you've enabled both).
+    both TCP and UDP TURN traffic)
 
- 6. If you've configured coturn to support TLS/DTLS, generate or
+ 7. If you've configured coturn to support TLS/DTLS, generate or
     import your private key and certificate.
 
- 7. Start the turn server::
+ 8. Start the turn server::
  
        bin/turnserver -o
 
@@ -83,12 +106,19 @@ Your home server configuration file needs the following extra keys:
     to refresh credentials. The TURN REST API specification recommends
     one day (86400000).
 
+  4. "turn_allow_guests": Whether to allow guest users to use the TURN
+    server.  This is enabled by default, as otherwise VoIP will not
+    work reliably for guests.  However, it does introduce a security risk
+    as it lets guests connect to arbitrary endpoints without having gone
+    through a CAPTCHA or similar to register a real account.
+
 As an example, here is the relevant section of the config file for
 matrix.org::
 
     turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ]
     turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons
     turn_user_lifetime: 86400000
+    turn_allow_guests: True
 
 Now, restart synapse::
 
diff --git a/scripts-dev/nuke-room-from-db.sh b/scripts-dev/nuke-room-from-db.sh
index 58c036c896..1201d176c2 100755
--- a/scripts-dev/nuke-room-from-db.sh
+++ b/scripts-dev/nuke-room-from-db.sh
@@ -9,16 +9,39 @@
 ROOMID="$1"
 
 sqlite3 homeserver.db <<EOF
-DELETE FROM context_depth WHERE context = '$ROOMID';
-DELETE FROM current_state WHERE context = '$ROOMID';
-DELETE FROM feedback WHERE room_id = '$ROOMID';
-DELETE FROM messages WHERE room_id = '$ROOMID';
-DELETE FROM pdu_backward_extremities WHERE context = '$ROOMID';
-DELETE FROM pdu_edges WHERE context = '$ROOMID';
-DELETE FROM pdu_forward_extremities WHERE context = '$ROOMID';
-DELETE FROM pdus WHERE context = '$ROOMID';
-DELETE FROM room_data WHERE room_id = '$ROOMID';
+DELETE FROM event_forward_extremities WHERE room_id = '$ROOMID';
+DELETE FROM event_backward_extremities WHERE room_id = '$ROOMID';
+DELETE FROM event_edges WHERE room_id = '$ROOMID';
+DELETE FROM room_depth WHERE room_id = '$ROOMID';
+DELETE FROM state_forward_extremities WHERE room_id = '$ROOMID';
+DELETE FROM events WHERE room_id = '$ROOMID';
+DELETE FROM event_json WHERE room_id = '$ROOMID';
+DELETE FROM state_events WHERE room_id = '$ROOMID';
+DELETE FROM current_state_events WHERE room_id = '$ROOMID';
 DELETE FROM room_memberships WHERE room_id = '$ROOMID';
+DELETE FROM feedback WHERE room_id = '$ROOMID';
+DELETE FROM topics WHERE room_id = '$ROOMID';
+DELETE FROM room_names WHERE room_id = '$ROOMID';
 DELETE FROM rooms WHERE room_id = '$ROOMID';
-DELETE FROM state_pdus WHERE context = '$ROOMID';
+DELETE FROM room_hosts WHERE room_id = '$ROOMID';
+DELETE FROM room_aliases WHERE room_id = '$ROOMID';
+DELETE FROM state_groups WHERE room_id = '$ROOMID';
+DELETE FROM state_groups_state WHERE room_id = '$ROOMID';
+DELETE FROM receipts_graph WHERE room_id = '$ROOMID';
+DELETE FROM receipts_linearized WHERE room_id = '$ROOMID';
+DELETE FROM event_search_content WHERE c1room_id = '$ROOMID';
+DELETE FROM guest_access WHERE room_id = '$ROOMID';
+DELETE FROM history_visibility WHERE room_id = '$ROOMID';
+DELETE FROM room_tags WHERE room_id = '$ROOMID';
+DELETE FROM room_tags_revisions WHERE room_id = '$ROOMID';
+DELETE FROM room_account_data WHERE room_id = '$ROOMID';
+DELETE FROM event_push_actions WHERE room_id = '$ROOMID';
+DELETE FROM local_invites WHERE room_id = '$ROOMID';
+DELETE FROM pusher_throttle WHERE room_id = '$ROOMID';
+DELETE FROM event_reports WHERE room_id = '$ROOMID';
+DELETE FROM public_room_list_stream WHERE room_id = '$ROOMID';
+DELETE FROM stream_ordering_to_exterm WHERE room_id = '$ROOMID';
+DELETE FROM event_auth WHERE room_id = '$ROOMID';
+DELETE FROM appservice_room_list WHERE room_id = '$ROOMID';
+VACUUM;
 EOF
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index ea367a1281..2e5d666707 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -447,9 +447,7 @@ class Porter(object):
 
             postgres_tables = yield self.postgres_store._simple_select_onecol(
                 table="information_schema.tables",
-                keyvalues={
-                    "table_schema": "public",
-                },
+                keyvalues={},
                 retcol="distinct table_name",
             )
 
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index 23eb6a1ec4..e8218d01ad 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -202,7 +202,8 @@ def main():
         worker_app = worker_config["worker_app"]
         worker_pidfile = worker_config["worker_pid_file"]
         worker_daemonize = worker_config["worker_daemonize"]
-        assert worker_daemonize  # TODO print something more user friendly
+        assert worker_daemonize, "In config %r: expected '%s' to be True" % (
+            worker_configfile, "worker_daemonize")
         worker_cache_factor = worker_config.get("synctl_cache_factor")
         workers.append(Worker(
             worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b0106a3597..7346206bb1 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.api.constants import EventTypes
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -124,29 +125,23 @@ class ApplicationService(object):
                     raise ValueError(
                         "Expected bool for 'exclusive' in ns '%s'" % ns
                     )
-                if not isinstance(regex_obj.get("regex"), basestring):
+                regex = regex_obj.get("regex")
+                if isinstance(regex, basestring):
+                    regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
+                else:
                     raise ValueError(
                         "Expected string for 'regex' in ns '%s'" % ns
                     )
         return namespaces
 
-    def _matches_regex(self, test_string, namespace_key, return_obj=False):
-        if not isinstance(test_string, basestring):
-            logger.error(
-                "Expected a string to test regex against, but got %s",
-                test_string
-            )
-            return False
-
+    def _matches_regex(self, test_string, namespace_key):
         for regex_obj in self.namespaces[namespace_key]:
-            if re.match(regex_obj["regex"], test_string):
-                if return_obj:
-                    return regex_obj
-                return True
-        return False
+            if regex_obj["regex"].match(test_string):
+                return regex_obj
+        return None
 
     def _is_exclusive(self, ns_key, test_string):
-        regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
+        regex_obj = self._matches_regex(test_string, ns_key)
         if regex_obj:
             return regex_obj["exclusive"]
         return False
@@ -166,7 +161,14 @@ class ApplicationService(object):
         if not store:
             defer.returnValue(False)
 
-        member_list = yield store.get_users_in_room(event.room_id)
+        does_match = yield self._matches_user_in_member_list(event.room_id, store)
+        defer.returnValue(does_match)
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _matches_user_in_member_list(self, room_id, store, cache_context):
+        member_list = yield store.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
 
         # check joined member events
         for user_id in member_list:
@@ -219,10 +221,10 @@ class ApplicationService(object):
         )
 
     def is_interested_in_alias(self, alias):
-        return self._matches_regex(alias, ApplicationService.NS_ALIASES)
+        return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
 
     def is_interested_in_room(self, room_id):
-        return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
+        return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
 
     def is_exclusive_user(self, user_id):
         return (
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index eeb693027b..3a4e16fa96 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -23,6 +23,7 @@ class VoipConfig(Config):
         self.turn_username = config.get("turn_username")
         self.turn_password = config.get("turn_password")
         self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
+        self.turn_allow_guests = config.get("turn_allow_guests", True)
 
     def default_config(self, **kwargs):
         return """\
@@ -41,4 +42,11 @@ class VoipConfig(Config):
 
         # How long generated TURN credentials last
         turn_user_lifetime: "1h"
+
+        # Whether guests should be allowed to use the TURN server.
+        # This defaults to True, otherwise VoIP will be unreliable for guests.
+        # However, it does introduce a slight security risk as it allows users to
+        # connect to arbitrary endpoints without having first signed up for a
+        # valid account (e.g. by passing a CAPTCHA).
+        turn_allow_guests: True
         """
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 287df94b4f..6835f54e97 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,15 +17,12 @@ from twisted.internet import defer
 from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event
 )
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 
 @defer.inlineCallbacks
 def get_badge_count(store, user_id):
-    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.get_invited_rooms_for_user)(user_id),
-        preserve_fn(store.get_rooms_for_user)(user_id),
-    ], consumeErrors=True))
+    invites = yield store.get_invited_rooms_for_user(user_id)
+    joins = yield store.get_rooms_for_user(user_id)
 
     my_receipts_by_room = yield store.get_receipts_for_user(
         user_id, "m.read",
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 03141c623c..c43b30b73a 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -28,7 +28,10 @@ class VoipRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(
+            request,
+            self.hs.config.turn_allow_guests
+        )
 
         turnUris = self.hs.config.turn_uris
         turnSecret = self.hs.config.turn_shared_secret
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 31f94bc6e9..6fceb23e26 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -36,7 +36,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         protocols = yield self.appservice_handler.get_3pe_protocols()
         defer.returnValue((200, protocols))
@@ -54,7 +54,7 @@ class ThirdPartyProtocolServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, protocol):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         protocols = yield self.appservice_handler.get_3pe_protocols(
             only_protocol=protocol,
@@ -77,7 +77,7 @@ class ThirdPartyUserServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, protocol):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         fields = request.args
         fields.pop("access_token", None)
@@ -101,7 +101,7 @@ class ThirdPartyLocationServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, protocol):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         fields = request.args
         fields.pop("access_token", None)
diff --git a/synapse/types.py b/synapse/types.py
index 9666f9d73f..c87ed813b9 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -216,9 +216,7 @@ class StreamToken(
             return self
 
     def copy_and_replace(self, key, new_value):
-        d = self._asdict()
-        d[key] = new_value
-        return StreamToken(**d)
+        return self._replace(**{key: new_value})
 
 
 StreamToken.START = StreamToken(
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 35380bf8ed..1453faf0ef 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -89,6 +89,11 @@ class ObservableDeferred(object):
         deferred.addCallbacks(callback, errback)
 
     def observe(self):
+        """Observe the underlying deferred.
+
+        Can return either a deferred if the underlying deferred is still pending
+        (or has failed), or the actual value. Callers may need to use maybeDeferred.
+        """
         if not self._result:
             d = defer.Deferred()
 
@@ -101,7 +106,7 @@ class ObservableDeferred(object):
             return d
         else:
             success, res = self._result
-            return defer.succeed(res) if success else defer.fail(res)
+            return res if success else defer.fail(res)
 
     def observers(self):
         return self._observers
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5c30ed235d..9d0d0be1f9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -224,8 +224,20 @@ class _CacheDescriptorBase(object):
             )
 
         self.num_args = num_args
+
+        # list of the names of the args used as the cache key
         self.arg_names = all_args[1:num_args + 1]
 
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
         if "cache_context" in self.arg_names:
             raise Exception(
                 "cache_context arg cannot be included among the cache keys"
@@ -289,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )
 
+        def get_cache_key(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = tuple(get_cache_key(args, kwargs))
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -341,7 +366,10 @@ class CacheDescriptor(_CacheDescriptorBase):
                 cache.set(cache_key, result_d, callback=invalidate_callback)
                 observer = result_d.observe()
 
-            return logcontext.make_deferred_yieldable(observer)
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 31659156ae..c4dd9ae2c7 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
         events ([synapse.events.EventBase]): list of events to filter
     """
     forgotten = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.who_forgot_in_room)(
+        defer.maybeDeferred(
+            preserve_fn(store.who_forgot_in_room),
             room_id,
         )
         for room_id in frozenset(e.room_id for e in events)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index aa8cc50550..7586ea9053 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -19,10 +19,12 @@ from twisted.internet import defer
 from mock import Mock
 from tests import unittest
 
+import re
+
 
 def _regex(regex, exclusive=True):
     return {
-        "regex": regex,
+        "regex": re.compile(regex),
         "exclusive": exclusive
     }
 
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8361dd8cee..281eb16254 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals(a.func("foo").result, d.result)
+        self.assertEquals(a.func("foo"), d.result)
         self.assertEquals(callcount[0], 0)
 
     @defer.inlineCallbacks
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4414e86771..3f14ab503f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -175,3 +175,41 @@ class DescriptorTestCase(unittest.TestCase):
                          logcontext.LoggingContext.sentinel)
 
         return d1
+
+    @defer.inlineCallbacks
+    def test_cache_default_args(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2=2, arg3=3):
+                return self.mock(arg1, arg2, arg3)
+
+        obj = Cls()
+
+        obj.mock.return_value = 'fish'
+        r = yield obj.fn(1, 2, 3)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_called_once_with(1, 2, 3)
+        obj.mock.reset_mock()
+
+        # a call with same params shouldn't call the mock again
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        obj.mock.assert_not_called()
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = 'chips'
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_called_once_with(2, 3, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        r = yield obj.fn(1, 2)
+        self.assertEqual(r, 'fish')
+        r = yield obj.fn(2, 3)
+        self.assertEqual(r, 'chips')
+        obj.mock.assert_not_called()
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
index 7e289715ba..d3a8630c2f 100644
--- a/tests/util/test_snapshot_cache.py
+++ b/tests/util/test_snapshot_cache.py
@@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase):
         # before the cache expires returns a resolved deferred.
         get_result_at_11 = self.cache.get(11, "key")
         self.assertIsNotNone(get_result_at_11)
-        self.assertTrue(get_result_at_11.called)
+        if isinstance(get_result_at_11, Deferred):
+            # The cache may return the actual result rather than a deferred
+            self.assertTrue(get_result_at_11.called)
 
         # Check that getting the key after the deferred has resolved
         # after the cache expires returns None