diff --git a/tests/__init__.py b/tests/__init__.py
index bfebb0f644..aab20e8e02 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -12,3 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from twisted.trial import util
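+# Trial's default per-test timeout is 120s; cap it at 10s so hung tests fail fast.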
+util.DEFAULT_TIMEOUT_DURATION = 10
diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tests/metrics/__init__.py
+++ /dev/null
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
deleted file mode 100644
index 069c0be762..0000000000
--- a/tests/metrics/test_metric.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from tests import unittest
-
-from synapse.metrics.metric import (
- CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
- _escape_label_value,
-)
-
-
-class CounterMetricTestCase(unittest.TestCase):
-
- def test_scalar(self):
- counter = CounterMetric("scalar")
-
- self.assertEquals(counter.render(), [
- 'scalar 0',
- ])
-
- counter.inc()
-
- self.assertEquals(counter.render(), [
- 'scalar 1',
- ])
-
- counter.inc_by(2)
-
- self.assertEquals(counter.render(), [
- 'scalar 3'
- ])
-
- def test_vector(self):
- counter = CounterMetric("vector", labels=["method"])
-
- # Empty counter doesn't yet know what values it has
- self.assertEquals(counter.render(), [])
-
- counter.inc("GET")
-
- self.assertEquals(counter.render(), [
- 'vector{method="GET"} 1',
- ])
-
- counter.inc("GET")
- counter.inc("PUT")
-
- self.assertEquals(counter.render(), [
- 'vector{method="GET"} 2',
- 'vector{method="PUT"} 1',
- ])
-
-
-class CallbackMetricTestCase(unittest.TestCase):
-
- def test_scalar(self):
- d = dict()
-
- metric = CallbackMetric("size", lambda: len(d))
-
- self.assertEquals(metric.render(), [
- 'size 0',
- ])
-
- d["key"] = "value"
-
- self.assertEquals(metric.render(), [
- 'size 1',
- ])
-
- def test_vector(self):
- vals = dict()
-
- metric = CallbackMetric("values", lambda: vals, labels=["type"])
-
- self.assertEquals(metric.render(), [])
-
- # Keys have to be tuples, even if they're 1-element
- vals[("foo",)] = 1
- vals[("bar",)] = 2
-
- self.assertEquals(metric.render(), [
- 'values{type="bar"} 2',
- 'values{type="foo"} 1',
- ])
-
-
-class DistributionMetricTestCase(unittest.TestCase):
-
- def test_scalar(self):
- metric = DistributionMetric("thing")
-
- self.assertEquals(metric.render(), [
- 'thing:count 0',
- 'thing:total 0',
- ])
-
- metric.inc_by(500)
-
- self.assertEquals(metric.render(), [
- 'thing:count 1',
- 'thing:total 500',
- ])
-
- def test_vector(self):
- metric = DistributionMetric("queries", labels=["verb"])
-
- self.assertEquals(metric.render(), [])
-
- metric.inc_by(300, "SELECT")
- metric.inc_by(200, "SELECT")
- metric.inc_by(800, "INSERT")
-
- self.assertEquals(metric.render(), [
- 'queries:count{verb="INSERT"} 1',
- 'queries:count{verb="SELECT"} 2',
- 'queries:total{verb="INSERT"} 800',
- 'queries:total{verb="SELECT"} 500',
- ])
-
-
-class CacheMetricTestCase(unittest.TestCase):
-
- def test_cache(self):
- d = dict()
-
- metric = CacheMetric("cache", lambda: len(d), "cache_name")
-
- self.assertEquals(metric.render(), [
- 'cache:hits{name="cache_name"} 0',
- 'cache:total{name="cache_name"} 0',
- 'cache:size{name="cache_name"} 0',
- 'cache:evicted_size{name="cache_name"} 0',
- ])
-
- metric.inc_misses()
- d["key"] = "value"
-
- self.assertEquals(metric.render(), [
- 'cache:hits{name="cache_name"} 0',
- 'cache:total{name="cache_name"} 1',
- 'cache:size{name="cache_name"} 1',
- 'cache:evicted_size{name="cache_name"} 0',
- ])
-
- metric.inc_hits()
-
- self.assertEquals(metric.render(), [
- 'cache:hits{name="cache_name"} 1',
- 'cache:total{name="cache_name"} 2',
- 'cache:size{name="cache_name"} 1',
- 'cache:evicted_size{name="cache_name"} 0',
- ])
-
- metric.inc_evictions(2)
-
- self.assertEquals(metric.render(), [
- 'cache:hits{name="cache_name"} 1',
- 'cache:total{name="cache_name"} 2',
- 'cache:size{name="cache_name"} 1',
- 'cache:evicted_size{name="cache_name"} 2',
- ])
-
-
-class LabelValueEscapeTestCase(unittest.TestCase):
- def test_simple(self):
- string = "safjhsdlifhyskljfksdfh"
- self.assertEqual(string, _escape_label_value(string))
-
- def test_escape(self):
- self.assertEqual(
- "abc\\\"def\\nghi\\\\",
- _escape_label_value("abc\"def\nghi\\"),
- )
-
- def test_sequence_of_escapes(self):
- self.assertEqual(
- "abc\\\"def\\nghi\\\\\\n",
- _escape_label_value("abc\"def\nghi\\\n"),
- )
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index cc637dda1c..f863b75846 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -49,6 +49,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"is_guest": 0,
"consent_version": None,
"consent_server_notice_sent": None,
+ "appservice_id": None,
},
(yield self.store.get_user_by_id(self.user_id))
)
diff --git a/tests/test_dns.py b/tests/test_dns.py
index af607d626f..3b360a0fc7 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -62,7 +62,7 @@ class DnsTestCase(unittest.TestCase):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
- service_name = "test_service.examle.com"
+ service_name = "test_service.example.com"
entry = Mock(spec_set=["expires"])
entry.expires = 0
@@ -87,7 +87,7 @@ class DnsTestCase(unittest.TestCase):
dns_client_mock = Mock(spec_set=['lookupService'])
dns_client_mock.lookupService = Mock(spec_set=[])
- service_name = "test_service.examle.com"
+ service_name = "test_service.example.com"
entry = Mock(spec_set=["expires"])
entry.expires = 999999999
@@ -111,7 +111,7 @@ class DnsTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
- service_name = "test_service.examle.com"
+ service_name = "test_service.example.com"
cache = {}
@@ -126,7 +126,7 @@ class DnsTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
- service_name = "test_service.examle.com"
+ service_name = "test_service.example.com"
cache = {}
diff --git a/tests/unittest.py b/tests/unittest.py
index 7b478c4294..184fe880f3 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -12,23 +12,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import logging
+
import twisted
+import twisted.logger
from twisted.trial import unittest
-import logging
+from synapse.util.logcontext import LoggingContextFilter
+
+# Set up routing Synapse's logs into Trial's.
+rootLogger = logging.getLogger()
+
+log_format = (
+ "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
+)
+
+
+class ToTwistedHandler(logging.Handler):
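+    """A logging handler which forwards stdlib log records to Twisted's
+    logger, so that they end up in Trial's captured test logs."""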
+ tx_log = twisted.logger.Logger()
+
+ def emit(self, record):
+ log_entry = self.format(record)
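+        # twisted.logger names the stdlib "warning" level "warn"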
+ log_level = record.levelname.lower().replace('warning', 'warn')
+ self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry)
-# logging doesn't have a "don't log anything at all EVARRRR setting,
-# but since the highest value is 50, 1000000 should do ;)
-NEVER = 1000000
-handler = logging.StreamHandler()
-handler.setFormatter(logging.Formatter(
- "%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]"
-))
-logging.getLogger().addHandler(handler)
-logging.getLogger().setLevel(NEVER)
-logging.getLogger("synapse.storage.SQL").setLevel(NEVER)
-logging.getLogger("synapse.storage.txn").setLevel(NEVER)
+handler = ToTwistedHandler()
+formatter = logging.Formatter(log_format)
+handler.setFormatter(formatter)
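+# The filter fills in the %(request)s key used in log_format above, defaulting
+# to "" when there is no active logging context.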
+handler.addFilter(LoggingContextFilter(request=""))
+rootLogger.addHandler(handler)
def around(target):
@@ -61,7 +75,7 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName)
- level = getattr(method, "loglevel", getattr(self, "loglevel", NEVER))
+ level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
@around(self)
def setUp(orig):
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
new file mode 100644
index 0000000000..67ece166c7
--- /dev/null
+++ b/tests/util/test_stream_change_cache.py
@@ -0,0 +1,200 @@
+from tests import unittest
+from mock import patch
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class StreamChangeCacheTests(unittest.TestCase):
+ """
+ Tests for StreamChangeCache.
+ """
+
+ def test_prefilled_cache(self):
+ """
+        Providing a prefilled cache to StreamChangeCache results in a cache
+        with those entries already loaded in.
+ """
+ cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
+
+ def test_has_entity_changed(self):
+ """
+ StreamChangeCache.entity_has_changed will mark entities as changed, and
+ has_entity_changed will observe the changed entities.
+ """
+ cache = StreamChangeCache("#test", 3)
+
+ cache.entity_has_changed("user@foo.com", 6)
+ cache.entity_has_changed("bar@baz.net", 7)
+
+ # If it's been changed after that stream position, return True
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
+ self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+
+ # If it's been changed at that stream position, return False
+ self.assertFalse(cache.has_entity_changed("user@foo.com", 6))
+
+        # If there are no changes after that stream position, return False
+ self.assertFalse(cache.has_entity_changed("user@foo.com", 7))
+
+ # If the entity does not exist, return False.
+ self.assertFalse(cache.has_entity_changed("not@here.website", 7))
+
+ # If we request before the stream cache's earliest known position,
+ # return True, whether it's a known entity or not.
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
+ self.assertTrue(cache.has_entity_changed("not@here.website", 0))
+
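+    # Pin CACHE_SIZE_FACTOR at 1.0 so that max_size is applied unscaled and
+    # the cache evicts at exactly two entries.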
+ @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
+ def test_has_entity_changed_pops_off_start(self):
+ """
+ StreamChangeCache.entity_has_changed will respect the max size and
+ purge the oldest items upon reaching that max size.
+ """
+ cache = StreamChangeCache("#test", 1, max_size=2)
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ # The cache is at the max size, 2
+ self.assertEqual(len(cache._cache), 2)
+
+ # The oldest item has been popped off
+ self.assertTrue("user@foo.com" not in cache._entity_to_key)
+
+ # If we update an existing entity, it keeps the two existing entities
+ cache.entity_has_changed("bar@baz.net", 5)
+ self.assertEqual(
+ set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key)
+ )
+
+ def test_get_all_entities_changed(self):
+ """
+ StreamChangeCache.get_all_entities_changed will return all changed
+ entities since the given position. If the position is before the start
+ of the known stream, it returns None instead.
+ """
+ cache = StreamChangeCache("#test", 1)
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ self.assertEqual(
+ cache.get_all_entities_changed(1),
+ ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
+ )
+ self.assertEqual(
+ cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"]
+ )
+ self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
+ self.assertEqual(cache.get_all_entities_changed(0), None)
+
+ def test_has_any_entity_changed(self):
+ """
+ StreamChangeCache.has_any_entity_changed will return True if any
+ entities have been changed since the provided stream position, and
+        False if they have not. If the provided stream position is before the
+        earliest known position, it returns True when the cache has any
+        entries, and False when it is empty.
+ """
+ cache = StreamChangeCache("#test", 1)
+
+ # With no entities, it returns False for the past, present, and future.
+ self.assertFalse(cache.has_any_entity_changed(0))
+ self.assertFalse(cache.has_any_entity_changed(1))
+ self.assertFalse(cache.has_any_entity_changed(2))
+
+ # We add an entity
+ cache.entity_has_changed("user@foo.com", 2)
+
+        # With an entity, it returns True for the past and for the stream
+        # start position, and False for the stream position the entity was
+        # changed at and any after it.
+ self.assertTrue(cache.has_any_entity_changed(0))
+ self.assertTrue(cache.has_any_entity_changed(1))
+ self.assertFalse(cache.has_any_entity_changed(2))
+ self.assertFalse(cache.has_any_entity_changed(3))
+
+ def test_get_entities_changed(self):
+ """
+ StreamChangeCache.get_entities_changed will return the entities in the
+ given list that have changed since the provided stream ID. If the
+ stream position is earlier than the earliest known position, it will
+ return all of the entities queried for.
+ """
+ cache = StreamChangeCache("#test", 1)
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ # Query all the entries, but mid-way through the stream. We should only
+ # get the ones after that point.
+ self.assertEqual(
+ cache.get_entities_changed(
+ ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
+ ),
+ set(["bar@baz.net", "user@elsewhere.org"]),
+ )
+
+ # Query all the entries mid-way through the stream, but include one
+ # that doesn't exist in it. We should get back the one that doesn't
+ # exist, too.
+ self.assertEqual(
+ cache.get_entities_changed(
+ [
+ "user@foo.com",
+ "bar@baz.net",
+ "user@elsewhere.org",
+ "not@here.website",
+ ],
+ stream_pos=2,
+ ),
+ set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
+ )
+
+ # Query all the entries, but before the first known point. We will get
+ # all the entries we queried for, including ones that don't exist.
+ self.assertEqual(
+ cache.get_entities_changed(
+ [
+ "user@foo.com",
+ "bar@baz.net",
+ "user@elsewhere.org",
+ "not@here.website",
+ ],
+ stream_pos=0,
+ ),
+ set(
+ [
+ "user@foo.com",
+ "bar@baz.net",
+ "user@elsewhere.org",
+ "not@here.website",
+ ]
+ ),
+ )
+
+ def test_max_pos(self):
+ """
+ StreamChangeCache.get_max_pos_of_last_change will return the most
+ recent point where the entity could have changed. If the entity is not
+        known, the stream start is returned instead.
+ """
+ cache = StreamChangeCache("#test", 1)
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ # Known entities will return the point where they were changed.
+ self.assertEqual(cache.get_max_pos_of_last_change("user@foo.com"), 2)
+ self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3)
+ self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4)
+
+ # Unknown entities will return the stream start position.
+ self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1)
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index c44567e52e..fdb24a48b0 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -33,7 +33,7 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(156), [obj])
self.assertListEqual(wheel.fetch(170), [])
- def test_mutli_insert(self):
+ def test_multi_insert(self):
wheel = WheelTimer(bucket_size=5)
obj1 = object()
@@ -58,7 +58,7 @@ class WheelTimerTestCase(unittest.TestCase):
wheel.insert(100, obj, 50)
self.assertListEqual(wheel.fetch(120), [obj])
- def test_insert_past_mutli(self):
+ def test_insert_past_multi(self):
wheel = WheelTimer(bucket_size=5)
obj1 = object()
|