diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 7e7b0b4b1d..2cf262bb46 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -20,7 +20,7 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.types import UserID
-from tests.utils import setup_test_homeserver
+from tests.utils import setup_test_homeserver, mock_getRawHeaders
import pymacaroons
@@ -45,12 +45,13 @@ class AuthTestCase(unittest.TestCase):
user_info = {
"name": self.test_user,
"token_id": "ditto",
+ "device_id": "device",
}
self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
- request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
@@ -60,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
- request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@@ -73,7 +74,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={})
- request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@@ -85,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
- request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
@@ -95,7 +96,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
- request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@@ -105,7 +106,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
- request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@@ -120,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
- request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), masquerading_user_id)
@@ -134,7 +135,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
- request.requestHeaders.getRawHeaders = Mock(return_value=[""])
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@@ -143,7 +144,10 @@ class AuthTestCase(unittest.TestCase):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
- return_value={"name": "@baldrick:matrix.org"}
+ return_value={
+ "name": "@baldrick:matrix.org",
+ "device_id": "device",
+ }
)
user_id = "@baldrick:matrix.org"
@@ -158,6 +162,10 @@ class AuthTestCase(unittest.TestCase):
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
+ # TODO: device_id should come from the macaroon, but currently comes
+ # from the db.
+ self.assertEqual(user_info["device_id"], "device")
+
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
user_id = "@baldrick:matrix.org"
@@ -281,15 +289,44 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
- macaroon.add_first_party_caveat("time < 1") # ms
+ macaroon.add_first_party_caveat("time < -2000") # ms
self.hs.clock.now = 5000 # seconds
-
- yield self.auth.get_user_from_macaroon(macaroon.serialize())
+ self.hs.config.expire_access_token = True
+ # yield self.auth.get_user_from_macaroon(macaroon.serialize())
# TODO(daniel): Turn on the check that we validate expiration, when we
# validate expiration (and remove the above line, which will start
# throwing).
- # with self.assertRaises(AuthError) as cm:
- # yield self.auth.get_user_from_macaroon(macaroon.serialize())
- # self.assertEqual(401, cm.exception.code)
- # self.assertIn("Invalid macaroon", cm.exception.msg)
+ with self.assertRaises(AuthError) as cm:
+ yield self.auth.get_user_from_macaroon(macaroon.serialize())
+ self.assertEqual(401, cm.exception.code)
+ self.assertIn("Invalid macaroon", cm.exception.msg)
+
+ @defer.inlineCallbacks
+ def test_get_user_from_macaroon_with_valid_duration(self):
+ # TODO(danielwh): Remove this mock when we remove the
+ # get_user_by_access_token fallback.
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ self.store.get_user_by_access_token = Mock(
+ return_value={"name": "@baldrick:matrix.org"}
+ )
+
+ user_id = "@baldrick:matrix.org"
+ macaroon = pymacaroons.Macaroon(
+ location=self.hs.config.server_name,
+ identifier="key",
+ key=self.hs.config.macaroon_secret_key)
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = access")
+ macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+ macaroon.add_first_party_caveat("time < 900000000") # ms
+
+ self.hs.clock.now = 5000 # seconds
+ self.hs.config.expire_access_token = True
+
+ user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
+ user = user_info["user"]
+ self.assertEqual(UserID.from_string(user_id), user)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index d6cc1881e9..aa8cc50550 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -14,6 +14,8 @@
# limitations under the License.
from synapse.appservice import ApplicationService
+from twisted.internet import defer
+
from mock import Mock
from tests import unittest
@@ -42,20 +44,25 @@ class ApplicationServiceTestCase(unittest.TestCase):
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
)
+ self.store = Mock()
+
+ @defer.inlineCallbacks
def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
- self.assertTrue(self.service.is_interested(self.event))
+ self.assertTrue((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.sender = "@someone_else:matrix.org"
- self.assertFalse(self.service.is_interested(self.event))
+ self.assertFalse((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
@@ -63,30 +70,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
- self.assertTrue(self.service.is_interested(self.event))
+ self.assertTrue((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_room_id_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
- self.assertTrue(self.service.is_interested(self.event))
+ self.assertTrue((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_room_id_no_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
- self.assertFalse(self.service.is_interested(self.event))
+ self.assertFalse((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.assertTrue(self.service.is_interested(
- self.event,
- aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
- ))
+ self.store.get_aliases_for_room.return_value = [
+ "#irc_foobar:matrix.org", "#athing:matrix.org"
+ ]
+ self.store.get_users_in_room.return_value = []
+ self.assertTrue((yield self.service.is_interested(
+ self.event, self.store
+ )))
def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
@@ -136,15 +149,20 @@ class ApplicationServiceTestCase(unittest.TestCase):
"!irc_foobar:matrix.org"
))
+ @defer.inlineCallbacks
def test_regex_alias_no_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.assertFalse(self.service.is_interested(
- self.event,
- aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
- ))
+ self.store.get_aliases_for_room.return_value = [
+ "#xmpp_foobar:matrix.org", "#athing:matrix.org"
+ ]
+ self.store.get_users_in_room.return_value = []
+ self.assertFalse((yield self.service.is_interested(
+ self.event, self.store
+ )))
+ @defer.inlineCallbacks
def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
@@ -153,53 +171,13 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org"
- self.assertTrue(self.service.is_interested(
- self.event,
- aliases_for_event=["#irc_barfoo:matrix.org"]
- ))
-
- def test_restrict_to_rooms(self):
- self.service.namespaces[ApplicationService.NS_ROOMS].append(
- _regex("!flibble_.*:matrix.org")
- )
- self.service.namespaces[ApplicationService.NS_USERS].append(
- _regex("@irc_.*")
- )
- self.event.sender = "@irc_foobar:matrix.org"
- self.event.room_id = "!wibblewoo:matrix.org"
- self.assertFalse(self.service.is_interested(
- self.event,
- restrict_to=ApplicationService.NS_ROOMS
- ))
-
- def test_restrict_to_aliases(self):
- self.service.namespaces[ApplicationService.NS_ALIASES].append(
- _regex("#xmpp_.*:matrix.org")
- )
- self.service.namespaces[ApplicationService.NS_USERS].append(
- _regex("@irc_.*")
- )
- self.event.sender = "@irc_foobar:matrix.org"
- self.assertFalse(self.service.is_interested(
- self.event,
- restrict_to=ApplicationService.NS_ALIASES,
- aliases_for_event=["#irc_barfoo:matrix.org"]
- ))
-
- def test_restrict_to_senders(self):
- self.service.namespaces[ApplicationService.NS_ALIASES].append(
- _regex("#xmpp_.*:matrix.org")
- )
- self.service.namespaces[ApplicationService.NS_USERS].append(
- _regex("@irc_.*")
- )
- self.event.sender = "@xmpp_foobar:matrix.org"
- self.assertFalse(self.service.is_interested(
- self.event,
- restrict_to=ApplicationService.NS_USERS,
- aliases_for_event=["#xmpp_barfoo:matrix.org"]
- ))
+ self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
+ self.store.get_users_in_room.return_value = []
+ self.assertTrue((yield self.service.is_interested(
+ self.event, self.store
+ )))
+ @defer.inlineCallbacks
def test_interested_in_self(self):
# make sure invites get through
self.service.sender = "@appservice:name"
@@ -211,20 +189,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
"membership": "invite"
}
self.event.state_key = self.service.sender
- self.assertTrue(self.service.is_interested(self.event))
+ self.assertTrue((yield self.service.is_interested(self.event)))
+ @defer.inlineCallbacks
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
- join_list = [
+ self.store.get_users_in_room.return_value = [
"@alice:here",
"@irc_fo:here", # AS user
"@bob:here",
]
+ self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org"
- self.assertTrue(self.service.is_interested(
- event=self.event,
- member_list=join_list
- ))
+ self.assertTrue((yield self.service.is_interested(
+ event=self.event, store=self.store
+ )))
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 631a229332..e5a902f734 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -193,7 +193,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self):
self.txn_ctrl = Mock()
- self.queuer = _ServiceQueuer(self.txn_ctrl)
+ self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately.
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 4329d73974..8f57fbeb23 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -30,7 +30,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
shutil.rmtree(self.dir)
def test_generate_config_generates_files(self):
- HomeServerConfig.load_config("", [
+ HomeServerConfig.load_or_generate_config("", [
"--generate-config",
"-c", self.file,
"--report-stats=yes",
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index bf46233c5c..161a87d7e3 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -34,6 +34,8 @@ class ConfigLoadingTestCase(unittest.TestCase):
self.generate_config_and_remove_lines_containing("server_name")
with self.assertRaises(Exception):
HomeServerConfig.load_config("", ["-c", self.file])
+ with self.assertRaises(Exception):
+ HomeServerConfig.load_or_generate_config("", ["-c", self.file])
def test_generates_and_loads_macaroon_secret_key(self):
self.generate_config()
@@ -54,11 +56,24 @@ class ConfigLoadingTestCase(unittest.TestCase):
"was: %r" % (config.macaroon_secret_key,)
)
+ config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
+ self.assertTrue(
+ hasattr(config, "macaroon_secret_key"),
+ "Want config to have attr macaroon_secret_key"
+ )
+ if len(config.macaroon_secret_key) < 5:
+ self.fail(
+ "Want macaroon secret key to be string of at least length 5,"
+ "was: %r" % (config.macaroon_secret_key,)
+ )
+
def test_load_succeeds_if_macaroon_secret_key_missing(self):
self.generate_config_and_remove_lines_containing("macaroon")
config1 = HomeServerConfig.load_config("", ["-c", self.file])
config2 = HomeServerConfig.load_config("", ["-c", self.file])
+ config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
+ self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key)
def test_disable_registration(self):
self.generate_config()
@@ -70,14 +85,17 @@ class ConfigLoadingTestCase(unittest.TestCase):
config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertFalse(config.enable_registration)
+ config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
+ self.assertFalse(config.enable_registration)
+
# Check that either config value is clobbered by the command line.
- config = HomeServerConfig.load_config("", [
+ config = HomeServerConfig.load_or_generate_config("", [
"-c", self.file, "--enable-registration"
])
self.assertTrue(config.enable_registration)
def generate_config(self):
- HomeServerConfig.load_config("", [
+ HomeServerConfig.load_or_generate_config("", [
"--generate-config",
"-c", self.file,
"--report-stats=yes",
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index fb0953c4ec..29f068d1f1 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -17,7 +17,11 @@
from .. import unittest
from synapse.events import FrozenEvent
-from synapse.events.utils import prune_event
+from synapse.events.utils import prune_event, serialize_event
+
+
+def MockEvent(**kwargs):
+ return FrozenEvent(kwargs)
class PruneEventTestCase(unittest.TestCase):
@@ -114,3 +118,167 @@ class PruneEventTestCase(unittest.TestCase):
'unsigned': {},
}
)
+
+
+class SerializeEventTestCase(unittest.TestCase):
+
+ def serialize(self, ev, fields):
+ return serialize_event(ev, 1479807801915, only_event_fields=fields)
+
+ def test_event_fields_works_with_keys(self):
+ self.assertEquals(
+ self.serialize(
+ MockEvent(
+ sender="@alice:localhost",
+ room_id="!foo:bar"
+ ),
+ ["room_id"]
+ ),
+ {
+ "room_id": "!foo:bar",
+ }
+ )
+
+ def test_event_fields_works_with_nested_keys(self):
+ self.assertEquals(
+ self.serialize(
+ MockEvent(
+ sender="@alice:localhost",
+ room_id="!foo:bar",
+ content={
+ "body": "A message",
+ },
+ ),
+ ["content.body"]
+ ),
+ {
+ "content": {
+ "body": "A message",
+ }
+ }
+ )
+
+ def test_event_fields_works_with_dot_keys(self):
+ self.assertEquals(
+ self.serialize(
+ MockEvent(
+ sender="@alice:localhost",
+ room_id="!foo:bar",
+ content={
+ "key.with.dots": {},
+ },
+ ),
+ ["content.key\.with\.dots"]
+ ),
+ {
+ "content": {
+ "key.with.dots": {},
+ }
+ }
+ )
+
+ def test_event_fields_works_with_nested_dot_keys(self):
+ self.assertEquals(
+ self.serialize(
+ MockEvent(
+ sender="@alice:localhost",
+ room_id="!foo:bar",
+ content={
+ "not_me": 1,
+ "nested.dot.key": {
+ "leaf.key": 42,
+ "not_me_either": 1,
+ },
+ },
+ ),
+ ["content.nested\.dot\.key.leaf\.key"]
+ ),
+ {
+ "content": {
+ "nested.dot.key": {
+ "leaf.key": 42,
+ },
+ }
+ }
+ )
+
+ def test_event_fields_nops_with_unknown_keys(self):
+ self.assertEquals(
+ self.serialize(
+ MockEvent(
+ sender="@alice:localhost",
+ room_id="!foo:bar",
+ content={
+ "foo": "bar",
+ },
+ ),
+ ["content.foo", "content.notexists"]
+ ),
+ {
+ "content": {
+ "foo": "bar",
+ }
+ }
+ )
+
+ def test_event_fields_nops_with_non_dict_keys(self):
+ self.assertEquals(
+ self.serialize(
+ MockEvent(
+ sender="@alice:localhost",
+ room_id="!foo:bar",
+ content={
+ "foo": ["I", "am", "an", "array"],
+ },
+ ),
+ ["content.foo.am"]
+ ),
+ {}
+ )
+
+ def test_event_fields_nops_with_array_keys(self):
+ self.assertEquals(
+ self.serialize(
+ MockEvent(
+ sender="@alice:localhost",
+ room_id="!foo:bar",
+ content={
+ "foo": ["I", "am", "an", "array"],
+ },
+ ),
+ ["content.foo.1"]
+ ),
+ {}
+ )
+
+ def test_event_fields_all_fields_if_empty(self):
+ self.assertEquals(
+ self.serialize(
+ MockEvent(
+ room_id="!foo:bar",
+ content={
+ "foo": "bar",
+ },
+ ),
+ []
+ ),
+ {
+ "room_id": "!foo:bar",
+ "content": {
+ "foo": "bar",
+ },
+ "unsigned": {}
+ }
+ )
+
+ def test_event_fields_fail_if_fields_not_str(self):
+ with self.assertRaises(TypeError):
+ self.serialize(
+ MockEvent(
+ room_id="!foo:bar",
+ content={
+ "foo": "bar",
+ },
+ ),
+ ["room_id", 4]
+ )
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 7ddbbb9b4a..7fe88172c0 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
from .. import unittest
+from tests.utils import MockClock
from synapse.handlers.appservice import ApplicationServicesHandler
@@ -30,9 +31,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastore = Mock(return_value=self.mock_store)
- self.handler = ApplicationServicesHandler(
- hs, self.mock_as_api, self.mock_scheduler
- )
+ hs.get_application_service_api = Mock(return_value=self.mock_as_api)
+ hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
+ hs.get_clock.return_value = MockClock()
+ self.handler = ApplicationServicesHandler(hs)
@defer.inlineCallbacks
def test_notify_interested_services(self):
@@ -51,8 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
type="m.room.message",
room_id="!foo:bar"
)
+ self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
self.mock_as_api.push = Mock()
- yield self.handler.notify_interested_services(event)
+ yield self.handler.notify_interested_services(0)
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
@@ -72,7 +75,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
- yield self.handler.notify_interested_services(event)
+ self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+ yield self.handler.notify_interested_services(0)
self.mock_as_api.query_user.assert_called_once_with(
services[0], user_id
)
@@ -94,7 +98,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
- yield self.handler.notify_interested_services(event)
+ self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+ yield self.handler.notify_interested_services(0)
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been."
@@ -108,11 +113,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet"
servers = ["aperture"]
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice_alias(is_interested_in_alias=True)
services = [
- self._mkservice(is_interested=False),
+ self._mkservice_alias(is_interested_in_alias=False),
interested_service,
- self._mkservice(is_interested=False)
+ self._mkservice_alias(is_interested_in_alias=False)
]
self.mock_store.get_app_services = Mock(return_value=services)
@@ -135,3 +140,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
+
+ def _mkservice_alias(self, is_interested_in_alias):
+ service = Mock()
+ service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
+ service.token = "mock_service_token"
+ service.url = "mock_service_url"
+ return service
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 21077cbe9a..9d013e5ca7 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -14,11 +14,13 @@
# limitations under the License.
import pymacaroons
+from twisted.internet import defer
+import synapse
+import synapse.api.errors
from synapse.handlers.auth import AuthHandler
from tests import unittest
from tests.utils import setup_test_homeserver
-from twisted.internet import defer
class AuthHandlers(object):
@@ -31,11 +33,12 @@ class AuthTestCase(unittest.TestCase):
def setUp(self):
self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = AuthHandlers(self.hs)
+ self.auth_handler = self.hs.handlers.auth_handler
def test_token_is_a_macaroon(self):
self.hs.config.macaroon_secret_key = "this key is a huge secret"
- token = self.hs.handlers.auth_handler.generate_access_token("some_user")
+ token = self.auth_handler.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks
@@ -46,7 +49,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.macaroon_secret_key = "this key is a massive secret"
self.hs.clock.now = 5000
- token = self.hs.handlers.auth_handler.generate_access_token("a_user")
+ token = self.auth_handler.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat):
@@ -58,12 +61,55 @@ class AuthTestCase(unittest.TestCase):
def verify_type(caveat):
return caveat == "type = access"
- def verify_expiry(caveat):
- return caveat == "time < 8600000"
+ def verify_nonce(caveat):
+ return caveat.startswith("nonce =")
v = pymacaroons.Verifier()
v.satisfy_general(verify_gen)
v.satisfy_general(verify_user)
v.satisfy_general(verify_type)
- v.satisfy_general(verify_expiry)
+ v.satisfy_general(verify_nonce)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+ def test_short_term_login_token_gives_user_id(self):
+ self.hs.clock.now = 1000
+
+ token = self.auth_handler.generate_short_term_login_token(
+ "a_user", 5000
+ )
+
+ self.assertEqual(
+ "a_user",
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
+ )
+ )
+
+ # when we advance the clock, the token should be rejected
+ self.hs.clock.now = 6000
+ with self.assertRaises(synapse.api.errors.AuthError):
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
+ )
+
+ def test_short_term_login_token_cannot_replace_user_id(self):
+ token = self.auth_handler.generate_short_term_login_token(
+ "a_user", 5000
+ )
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+
+ self.assertEqual(
+ "a_user",
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
+ )
+
+ # add another "user_id" caveat, which might allow us to override the
+ # user_id.
+ macaroon.add_first_party_caveat("user_id = b_user")
+
+ with self.assertRaises(synapse.api.errors.AuthError):
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
new file mode 100644
index 0000000000..85a970a6c9
--- /dev/null
+++ b/tests/handlers/test_device.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import synapse.api.errors
+import synapse.handlers.device
+
+import synapse.storage
+from synapse import types
+from tests import unittest, utils
+
+user1 = "@boris:aaa"
+user2 = "@theresa:bbb"
+
+
+class DeviceTestCase(unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(DeviceTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+ self.handler = None # type: synapse.handlers.device.DeviceHandler
+ self.clock = None # type: utils.MockClock
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield utils.setup_test_homeserver(handlers=None)
+ self.handler = synapse.handlers.device.DeviceHandler(hs)
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def test_device_is_created_if_doesnt_exist(self):
+ res = yield self.handler.check_device_registered(
+ user_id="boris",
+ device_id="fco",
+ initial_device_display_name="display name"
+ )
+ self.assertEqual(res, "fco")
+
+ dev = yield self.handler.store.get_device("boris", "fco")
+ self.assertEqual(dev["display_name"], "display name")
+
+ @defer.inlineCallbacks
+ def test_device_is_preserved_if_exists(self):
+ res1 = yield self.handler.check_device_registered(
+ user_id="boris",
+ device_id="fco",
+ initial_device_display_name="display name"
+ )
+ self.assertEqual(res1, "fco")
+
+ res2 = yield self.handler.check_device_registered(
+ user_id="boris",
+ device_id="fco",
+ initial_device_display_name="new display name"
+ )
+ self.assertEqual(res2, "fco")
+
+ dev = yield self.handler.store.get_device("boris", "fco")
+ self.assertEqual(dev["display_name"], "display name")
+
+ @defer.inlineCallbacks
+ def test_device_id_is_made_up_if_unspecified(self):
+ device_id = yield self.handler.check_device_registered(
+ user_id="theresa",
+ device_id=None,
+ initial_device_display_name="display"
+ )
+
+ dev = yield self.handler.store.get_device("theresa", device_id)
+ self.assertEqual(dev["display_name"], "display")
+
+ @defer.inlineCallbacks
+ def test_get_devices_by_user(self):
+ yield self._record_users()
+
+ res = yield self.handler.get_devices_by_user(user1)
+ self.assertEqual(3, len(res))
+ device_map = {
+ d["device_id"]: d for d in res
+ }
+ self.assertDictContainsSubset({
+ "user_id": user1,
+ "device_id": "xyz",
+ "display_name": "display 0",
+ "last_seen_ip": None,
+ "last_seen_ts": None,
+ }, device_map["xyz"])
+ self.assertDictContainsSubset({
+ "user_id": user1,
+ "device_id": "fco",
+ "display_name": "display 1",
+ "last_seen_ip": "ip1",
+ "last_seen_ts": 1000000,
+ }, device_map["fco"])
+ self.assertDictContainsSubset({
+ "user_id": user1,
+ "device_id": "abc",
+ "display_name": "display 2",
+ "last_seen_ip": "ip3",
+ "last_seen_ts": 3000000,
+ }, device_map["abc"])
+
+ @defer.inlineCallbacks
+ def test_get_device(self):
+ yield self._record_users()
+
+ res = yield self.handler.get_device(user1, "abc")
+ self.assertDictContainsSubset({
+ "user_id": user1,
+ "device_id": "abc",
+ "display_name": "display 2",
+ "last_seen_ip": "ip3",
+ "last_seen_ts": 3000000,
+ }, res)
+
+ @defer.inlineCallbacks
+ def test_delete_device(self):
+ yield self._record_users()
+
+ # delete the device
+ yield self.handler.delete_device(user1, "abc")
+
+ # check the device was deleted
+ with self.assertRaises(synapse.api.errors.NotFoundError):
+ yield self.handler.get_device(user1, "abc")
+
+ # we'd like to check the access token was invalidated, but that's a
+ # bit of a PITA.
+
+ @defer.inlineCallbacks
+ def test_update_device(self):
+ yield self._record_users()
+
+ update = {"display_name": "new display"}
+ yield self.handler.update_device(user1, "abc", update)
+
+ res = yield self.handler.get_device(user1, "abc")
+ self.assertEqual(res["display_name"], "new display")
+
+ @defer.inlineCallbacks
+ def test_update_unknown_device(self):
+ update = {"display_name": "new_display"}
+ with self.assertRaises(synapse.api.errors.NotFoundError):
+ yield self.handler.update_device("user_id", "unknown_device_id",
+ update)
+
+ @defer.inlineCallbacks
+ def _record_users(self):
+ # check this works for both devices which have a recorded client_ip,
+ # and those which don't.
+ yield self._record_user(user1, "xyz", "display 0")
+ yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
+ yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
+ yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
+
+ yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
+
+ @defer.inlineCallbacks
+ def _record_user(self, user_id, device_id, display_name,
+ access_token=None, ip=None):
+ device_id = yield self.handler.check_device_registered(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=display_name
+ )
+
+ if ip is not None:
+ yield self.store.insert_client_ip(
+ types.UserID.from_string(user_id),
+ access_token, ip, "user_agent", device_id)
+ self.clock.advance_time(1000)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
new file mode 100644
index 0000000000..878a54dc34
--- /dev/null
+++ b/tests/handlers/test_e2e_keys.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mock
+from twisted.internet import defer
+
+import synapse.api.errors
+import synapse.handlers.e2e_keys
+
+import synapse.storage
+from tests import unittest, utils
+
+
+class E2eKeysHandlerTestCase(unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ self.hs = None # type: synapse.server.HomeServer
+ self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield utils.setup_test_homeserver(
+ handlers=None,
+ replication_layer=mock.Mock(),
+ )
+ self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+
+ @defer.inlineCallbacks
+ def test_query_local_devices_no_devices(self):
+ """If the user has no devices, we expect an empty list.
+ """
+ local_user = "@boris:" + self.hs.hostname
+ res = yield self.handler.query_local_devices({local_user: None})
+ self.assertDictEqual(res, {local_user: {}})
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 87c795fcfa..d9e8f634ae 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -115,6 +115,53 @@ class PresenceUpdateTestCase(unittest.TestCase):
),
], any_order=True)
+ def test_online_to_online_last_active_noop(self):
+ wheel_timer = Mock()
+ user_id = "@foo:bar"
+ now = 5000000
+
+ prev_state = UserPresenceState.default(user_id)
+ prev_state = prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10,
+ currently_active=True,
+ )
+
+ new_state = prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=now,
+ )
+
+ state, persist_and_notify, federation_ping = handle_update(
+ prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
+ )
+
+ self.assertFalse(persist_and_notify)
+ self.assertTrue(federation_ping)
+ self.assertTrue(state.currently_active)
+ self.assertEquals(new_state.state, state.state)
+ self.assertEquals(new_state.status_msg, state.status_msg)
+ self.assertEquals(state.last_federation_update_ts, now)
+
+ self.assertEquals(wheel_timer.insert.call_count, 3)
+ wheel_timer.insert.assert_has_calls([
+ call(
+ now=now,
+ obj=user_id,
+ then=new_state.last_active_ts + IDLE_TIMER
+ ),
+ call(
+ now=now,
+ obj=user_id,
+ then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
+ ),
+ call(
+ now=now,
+ obj=user_id,
+ then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
+ ),
+ ], any_order=True)
+
def test_online_to_online_last_active(self):
wheel_timer = Mock()
user_id = "@foo:bar"
@@ -264,7 +311,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=True, user_to_num_current_syncs={}, now=now
+ state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)
@@ -282,7 +329,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=True, user_to_num_current_syncs={}, now=now
+ state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)
@@ -300,9 +347,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=True, user_to_num_current_syncs={
- user_id: 1,
- }, now=now
+ state, is_mine=True, syncing_user_ids=set([user_id]), now=now
)
self.assertIsNotNone(new_state)
@@ -321,7 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=True, user_to_num_current_syncs={}, now=now
+ state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)
@@ -340,7 +385,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=True, user_to_num_current_syncs={}, now=now
+ state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNone(new_state)
@@ -358,7 +403,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=False, user_to_num_current_syncs={}, now=now
+ state, is_mine=False, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)
@@ -377,7 +422,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=True, user_to_num_current_syncs={}, now=now
+ state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 4f2c14e4ff..f1f664275f 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -19,11 +19,12 @@ from twisted.internet import defer
from mock import Mock, NonCallableMock
+import synapse.types
from synapse.api.errors import AuthError
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
-from tests.utils import setup_test_homeserver, requester_for_user
+from tests.utils import setup_test_homeserver
class ProfileHandlers(object):
@@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name(self):
yield self.handler.set_displayname(
self.frank,
- requester_for_user(self.frank),
+ synapse.types.create_requester(self.frank),
"Frank Jr."
)
@@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name_noauth(self):
d = self.handler.set_displayname(
self.frank,
- requester_for_user(self.bob),
+ synapse.types.create_requester(self.bob),
"Frank Jr."
)
@@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_avatar(self):
yield self.handler.set_avatar_url(
- self.frank, requester_for_user(self.frank), "http://my.server/pic.gif"
+ self.frank, synapse.types.create_requester(self.frank),
+ "http://my.server/pic.gif"
)
self.assertEquals(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
new file mode 100644
index 0000000000..a4380c48b4
--- /dev/null
+++ b/tests/handlers/test_register.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from .. import unittest
+
+from synapse.handlers.register import RegistrationHandler
+from synapse.types import UserID, create_requester
+
+from tests.utils import setup_test_homeserver
+
+from mock import Mock
+
+
+class RegistrationHandlers(object):
+ def __init__(self, hs):
+ self.registration_handler = RegistrationHandler(hs)
+
+
+class RegistrationTestCase(unittest.TestCase):
+ """ Tests the RegistrationHandler. """
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.mock_distributor = Mock()
+ self.mock_distributor.declare("registered_user")
+ self.mock_captcha_client = Mock()
+ self.hs = yield setup_test_homeserver(
+ handlers=None,
+ http_client=None,
+ expire_access_token=True)
+ self.auth_handler = Mock(
+ generate_access_token=Mock(return_value='secret'))
+ self.hs.handlers = RegistrationHandlers(self.hs)
+ self.handler = self.hs.get_handlers().registration_handler
+ self.hs.get_handlers().profile_handler = Mock()
+ self.mock_handler = Mock(spec=[
+ "generate_access_token",
+ ])
+ self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
+
+ @defer.inlineCallbacks
+ def test_user_is_created_and_logged_in_if_doesnt_exist(self):
+ local_part = "someone"
+ display_name = "someone"
+ user_id = "@someone:test"
+ requester = create_requester("@as:test")
+ result_user_id, result_token = yield self.handler.get_or_create_user(
+ requester, local_part, display_name)
+ self.assertEquals(result_user_id, user_id)
+ self.assertEquals(result_token, 'secret')
+
+ @defer.inlineCallbacks
+ def test_if_user_exists(self):
+ store = self.hs.get_datastore()
+ frank = UserID.from_string("@frank:test")
+ yield store.register(
+ user_id=frank.to_string(),
+ token="jkv;g498752-43gj['eamb!-5",
+ password_hash=None)
+ local_part = "frank"
+ display_name = "Frank"
+ user_id = "@frank:test"
+ requester = create_requester("@as:test")
+ result_user_id, result_token = yield self.handler.get_or_create_user(
+ requester, local_part, display_name)
+ self.assertEquals(result_user_id, user_id)
+ self.assertEquals(result_token, 'secret')
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 3955e7f5b1..c718d1f98f 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -25,8 +25,6 @@ from ..utils import (
)
from synapse.api.errors import AuthError
-from synapse.handlers.typing import TypingNotificationHandler
-
from synapse.types import UserID
@@ -49,11 +47,6 @@ def _make_edu_json(origin, edu_type, content):
return json.dumps(_expect_edu("test", edu_type, content, origin=origin))
-class JustTypingNotificationHandlers(object):
- def __init__(self, hs):
- self.typing_notification_handler = TypingNotificationHandler(hs)
-
-
class TypingNotificationsTestCase(unittest.TestCase):
"""Tests typing notifications to rooms."""
@defer.inlineCallbacks
@@ -69,8 +62,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[])
+ self.state_handler = Mock()
hs = yield setup_test_homeserver(
+ "test",
auth=self.auth,
clock=self.clock,
datastore=Mock(spec=[
@@ -81,6 +76,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"set_received_txn_response",
"get_destination_retry_timings",
]),
+ state_handler=self.state_handler,
handlers=None,
notifier=mock_notifier,
resource_for_client=Mock(),
@@ -88,9 +84,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
http_client=self.mock_http_client,
keyring=Mock(),
)
- hs.handlers = JustTypingNotificationHandlers(hs)
- self.handler = hs.get_handlers().typing_notification_handler
+ self.handler = hs.get_typing_handler()
self.event_source = hs.get_event_sources().sources["typing"]
@@ -110,58 +105,30 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.room_id = "a-room"
- # Mock the RoomMemberHandler
- hs.handlers.room_member_handler = Mock(spec=[])
- self.room_member_handler = hs.handlers.room_member_handler
-
self.room_members = []
- def get_rooms_for_user(user):
- if user in self.room_members:
- return defer.succeed([self.room_id])
- else:
- return defer.succeed([])
- self.room_member_handler.get_rooms_for_user = get_rooms_for_user
-
- def get_room_members(room_id):
- if room_id == self.room_id:
- return defer.succeed(self.room_members)
- else:
- return defer.succeed([])
- self.room_member_handler.get_room_members = get_room_members
-
- def get_joined_rooms_for_user(user):
- if user in self.room_members:
- return defer.succeed([self.room_id])
- else:
- return defer.succeed([])
- self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user
-
- @defer.inlineCallbacks
- def fetch_room_distributions_into(
- room_id, localusers=None, remotedomains=None, ignore_user=None
- ):
- members = yield get_room_members(room_id)
- for member in members:
- if ignore_user is not None and member == ignore_user:
- continue
-
- if hs.is_mine(member):
- if localusers is not None:
- localusers.add(member)
- else:
- if remotedomains is not None:
- remotedomains.add(member.domain)
- self.room_member_handler.fetch_room_distributions_into = (
- fetch_room_distributions_into
- )
-
def check_joined_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
+ def get_joined_hosts_for_room(room_id):
+ return set(member.domain for member in self.room_members)
+ self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
+
+ def get_current_user_in_room(room_id):
+ return set(str(u) for u in self.room_members)
+ self.state_handler.get_current_user_in_room = get_current_user_in_room
+
self.auth.check_joined_room = check_joined_room
+ self.datastore.get_to_device_stream_token = lambda: 0
+ self.datastore.get_new_device_msgs_for_remote = (
+ lambda *args, **kargs: ([], 0)
+ )
+ self.datastore.delete_device_msgs_for_remote = (
+ lambda *args, **kargs: None
+ )
+
# Some local users to test with
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
@@ -252,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
"user_id": self.u_onion.to_string(),
"typing": True,
}
- )
+ ),
+ federation_auth=True,
)
self.on_new_event.assert_has_calls([
@@ -298,12 +266,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
# Gut-wrenching
from synapse.handlers.typing import RoomMember
- member = RoomMember(self.room_id, self.u_apple)
+ member = RoomMember(self.room_id, self.u_apple.to_string())
self.handler._member_typing_until[member] = 1002000
- self.handler._member_typing_timer[member] = (
- self.clock.call_later(1002, lambda: 0)
- )
- self.handler._room_typing[self.room_id] = set((self.u_apple,))
+ self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
self.assertEquals(self.event_source.get_current_key(), 0)
@@ -363,7 +328,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
},
}])
- self.clock.advance_time(11)
+ self.clock.advance_time(16)
self.on_new_event.assert_has_calls([
call('typing_key', 2, rooms=[self.room_id]),
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index f3c1927ce1..f85455a5af 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -61,9 +61,6 @@ class CounterMetricTestCase(unittest.TestCase):
'vector{method="PUT"} 1',
])
- # Check that passing too few values errors
- self.assertRaises(ValueError, counter.inc)
-
class CallbackMetricTestCase(unittest.TestCase):
@@ -138,27 +135,27 @@ class CacheMetricTestCase(unittest.TestCase):
def test_cache(self):
d = dict()
- metric = CacheMetric("cache", lambda: len(d))
+ metric = CacheMetric("cache", lambda: len(d), "cache_name")
self.assertEquals(metric.render(), [
- 'cache:hits 0',
- 'cache:total 0',
- 'cache:size 0',
+ 'cache:hits{name="cache_name"} 0',
+ 'cache:total{name="cache_name"} 0',
+ 'cache:size{name="cache_name"} 0',
])
metric.inc_misses()
d["key"] = "value"
self.assertEquals(metric.render(), [
- 'cache:hits 0',
- 'cache:total 1',
- 'cache:size 1',
+ 'cache:hits{name="cache_name"} 0',
+ 'cache:total{name="cache_name"} 1',
+ 'cache:size{name="cache_name"} 1',
])
metric.inc_hits()
self.assertEquals(metric.render(), [
- 'cache:hits 1',
- 'cache:total 2',
- 'cache:size 1',
+ 'cache:hits{name="cache_name"} 1',
+ 'cache:total{name="cache_name"} 2',
+ 'cache:size{name="cache_name"} 1',
])
diff --git a/tests/replication/slave/__init__.py b/tests/replication/slave/__init__.py
new file mode 100644
index 0000000000..b7df13c9ee
--- /dev/null
+++ b/tests/replication/slave/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/replication/slave/storage/__init__.py b/tests/replication/slave/storage/__init__.py
new file mode 100644
index 0000000000..b7df13c9ee
--- /dev/null
+++ b/tests/replication/slave/storage/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
new file mode 100644
index 0000000000..b82868054d
--- /dev/null
+++ b/tests/replication/slave/storage/_base.py
@@ -0,0 +1,56 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+from tests import unittest
+
+from mock import Mock, NonCallableMock
+from tests.utils import setup_test_homeserver
+from synapse.replication.resource import ReplicationResource
+
+
+class BaseSlavedStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver(
+ "blue",
+ http_client=None,
+ replication_layer=Mock(),
+ ratelimiter=NonCallableMock(spec_set=[
+ "send_message",
+ ]),
+ )
+ self.hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+ self.replication = ReplicationResource(self.hs)
+
+ self.master_store = self.hs.get_datastore()
+ self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
+ self.event_id = 0
+
+ @defer.inlineCallbacks
+ def replicate(self):
+ streams = self.slaved_store.stream_positions()
+ writer = yield self.replication.replicate(streams, 100)
+ result = writer.finish()
+ yield self.slaved_store.process_replication(result)
+
+ @defer.inlineCallbacks
+ def check(self, method, args, expected_result=None):
+ master_result = yield getattr(self.master_store, method)(*args)
+ slaved_result = yield getattr(self.slaved_store, method)(*args)
+ if expected_result is not None:
+ self.assertEqual(master_result, expected_result)
+ self.assertEqual(slaved_result, expected_result)
+ self.assertEqual(master_result, slaved_result)
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
new file mode 100644
index 0000000000..da54d478ce
--- /dev/null
+++ b/tests/replication/slave/storage/test_account_data.py
@@ -0,0 +1,56 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ._base import BaseSlavedStoreTestCase
+
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+
+from twisted.internet import defer
+
+USER_ID = "@feeling:blue"
+TYPE = "my.type"
+
+
+class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
+
+ STORE_TYPE = SlavedAccountDataStore
+
+ @defer.inlineCallbacks
+ def test_user_account_data(self):
+ yield self.master_store.add_account_data_for_user(
+ USER_ID, TYPE, {"a": 1}
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_global_account_data_by_type_for_user",
+ [TYPE, USER_ID], {"a": 1}
+ )
+ yield self.check(
+ "get_global_account_data_by_type_for_users",
+ [TYPE, [USER_ID]], {USER_ID: {"a": 1}}
+ )
+
+ yield self.master_store.add_account_data_for_user(
+ USER_ID, TYPE, {"a": 2}
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_global_account_data_by_type_for_user",
+ [TYPE, USER_ID], {"a": 2}
+ )
+ yield self.check(
+ "get_global_account_data_by_type_for_users",
+ [TYPE, [USER_ID]], {USER_ID: {"a": 2}}
+ )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
new file mode 100644
index 0000000000..44e859b5d1
--- /dev/null
+++ b/tests/replication/slave/storage/test_events.py
@@ -0,0 +1,333 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStoreTestCase
+
+from synapse.events import FrozenEvent, _EventInternalMetadata
+from synapse.events.snapshot import EventContext
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.storage.roommember import RoomsForUser
+
+from twisted.internet import defer
+
+
+USER_ID = "@feeling:blue"
+USER_ID_2 = "@bright:blue"
+OUTLIER = {"outlier": True}
+ROOM_ID = "!room:blue"
+
+
+def dict_equals(self, other):
+ return self.__dict__ == other.__dict__
+
+
+def patch__eq__(cls):
+ eq = getattr(cls, "__eq__", None)
+ cls.__eq__ = dict_equals
+
+ def unpatch():
+ if eq is not None:
+ cls.__eq__ = eq
+ return unpatch
+
+
+class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
+
+ STORE_TYPE = SlavedEventStore
+
+ def setUp(self):
+ # Patch up the equality operator for events so that we can check
+ # whether lists of events match using assertEquals
+ self.unpatches = [
+ patch__eq__(_EventInternalMetadata),
+ patch__eq__(FrozenEvent),
+ ]
+ return super(SlavedEventStoreTestCase, self).setUp()
+
+ def tearDown(self):
+ [unpatch() for unpatch in self.unpatches]
+
+ @defer.inlineCallbacks
+ def test_room_members(self):
+ create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.replicate()
+ yield self.check("get_rooms_for_user", (USER_ID,), [])
+ yield self.check("get_users_in_room", (ROOM_ID,), [])
+
+ # Join the room.
+ join = yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+ yield self.replicate()
+ yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser(
+ room_id=ROOM_ID,
+ sender=USER_ID,
+ membership="join",
+ event_id=join.event_id,
+ stream_ordering=join.internal_metadata.stream_ordering,
+ )])
+ yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
+
+ # Leave the room.
+ yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
+ yield self.replicate()
+ yield self.check("get_rooms_for_user", (USER_ID,), [])
+ yield self.check("get_users_in_room", (ROOM_ID,), [])
+
+ # Add some other user to the room.
+ join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join")
+ yield self.replicate()
+ yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser(
+ room_id=ROOM_ID,
+ sender=USER_ID,
+ membership="join",
+ event_id=join.event_id,
+ stream_ordering=join.internal_metadata.stream_ordering,
+ )])
+ yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
+
+ # Join the room clobbering the state.
+ # This should remove any evidence of the other user being in the room.
+ yield self.persist(
+ type="m.room.member", key=USER_ID, membership="join",
+ reset_state=[create]
+ )
+ yield self.replicate()
+ yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
+ yield self.check("get_rooms_for_user", (USER_ID_2,), [])
+
+ @defer.inlineCallbacks
+ def test_get_latest_event_ids_in_room(self):
+ create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.replicate()
+ yield self.check(
+ "get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]
+ )
+
+ join = yield self.persist(
+ type="m.room.member", key=USER_ID, membership="join",
+ prev_events=[(create.event_id, {})],
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
+ )
+
+ @defer.inlineCallbacks
+ def test_get_current_state(self):
+ # Create the room.
+ create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.replicate()
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
+ )
+
+ # Join the room.
+ join1 = yield self.persist(
+ type="m.room.member", key=USER_ID, membership="join",
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
+ [join1]
+ )
+
+ # Add some other user to the room.
+ join2 = yield self.persist(
+ type="m.room.member", key=USER_ID_2, membership="join",
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
+ [join2]
+ )
+
+ # Leave the room, then rejoin the room clobbering state.
+ yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
+ join3 = yield self.persist(
+ type="m.room.member", key=USER_ID, membership="join",
+ reset_state=[create]
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
+ []
+ )
+ yield self.check(
+ "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
+ [join3]
+ )
+
+ @defer.inlineCallbacks
+ def test_redactions(self):
+ yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+ msg = yield self.persist(
+ type="m.room.message", msgtype="m.text", body="Hello"
+ )
+ yield self.replicate()
+ yield self.check("get_event", [msg.event_id], msg)
+
+ redaction = yield self.persist(
+ type="m.room.redaction", redacts=msg.event_id
+ )
+ yield self.replicate()
+
+ msg_dict = msg.get_dict()
+ msg_dict["content"] = {}
+ msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+ msg_dict["unsigned"]["redacted_because"] = redaction
+ redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+ yield self.check("get_event", [msg.event_id], redacted)
+
+ @defer.inlineCallbacks
+ def test_backfilled_redactions(self):
+ yield self.persist(type="m.room.create", key="", creator=USER_ID)
+ yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+
+ msg = yield self.persist(
+ type="m.room.message", msgtype="m.text", body="Hello"
+ )
+ yield self.replicate()
+ yield self.check("get_event", [msg.event_id], msg)
+
+ redaction = yield self.persist(
+ type="m.room.redaction", redacts=msg.event_id, backfill=True
+ )
+ yield self.replicate()
+
+ msg_dict = msg.get_dict()
+ msg_dict["content"] = {}
+ msg_dict["unsigned"]["redacted_by"] = redaction.event_id
+ msg_dict["unsigned"]["redacted_because"] = redaction
+ redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
+ yield self.check("get_event", [msg.event_id], redacted)
+
+ @defer.inlineCallbacks
+ def test_invites(self):
+ yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+ event = yield self.persist(
+ type="m.room.member", key=USER_ID_2, membership="invite"
+ )
+ yield self.replicate()
+ yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser(
+ ROOM_ID, USER_ID, "invite", event.event_id,
+ event.internal_metadata.stream_ordering
+ )])
+
+ @defer.inlineCallbacks
+ def test_push_actions_for_user(self):
+ yield self.persist(type="m.room.create", creator=USER_ID)
+ yield self.persist(type="m.room.join", key=USER_ID, membership="join")
+ yield self.persist(
+ type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
+ )
+ event1 = yield self.persist(
+ type="m.room.message", msgtype="m.text", body="hello"
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2, event1.event_id],
+ {"highlight_count": 0, "notify_count": 0}
+ )
+
+ yield self.persist(
+ type="m.room.message", msgtype="m.text", body="world",
+ push_actions=[(USER_ID_2, ["notify"])],
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2, event1.event_id],
+ {"highlight_count": 0, "notify_count": 1}
+ )
+
+ yield self.persist(
+ type="m.room.message", msgtype="m.text", body="world",
+ push_actions=[(USER_ID_2, [
+ "notify", {"set_tweak": "highlight", "value": True}
+ ])],
+ )
+ yield self.replicate()
+ yield self.check(
+ "get_unread_event_push_actions_by_room_for_user",
+ [ROOM_ID, USER_ID_2, event1.event_id],
+ {"highlight_count": 1, "notify_count": 2}
+ )
+
+ event_id = 0
+
+ @defer.inlineCallbacks
+ def persist(
+ self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
+ state=None, reset_state=False, backfill=False,
+ depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
+ push_actions=[],
+ **content
+ ):
+ """
+ Returns:
+ synapse.events.FrozenEvent: The event that was persisted.
+ """
+ if depth is None:
+ depth = self.event_id
+
+ event_dict = {
+ "sender": sender,
+ "type": type,
+ "content": content,
+ "event_id": "$%d:blue" % (self.event_id,),
+ "room_id": room_id,
+ "depth": depth,
+ "origin_server_ts": self.event_id,
+ "prev_events": prev_events,
+ "auth_events": auth_events,
+ }
+ if key is not None:
+ event_dict["state_key"] = key
+ event_dict["prev_state"] = prev_state
+
+ if redacts is not None:
+ event_dict["redacts"] = redacts
+
+ event = FrozenEvent(event_dict, internal_metadata_dict=internal)
+
+ self.event_id += 1
+
+ if state is not None:
+ state_ids = {
+ key: e.event_id for key, e in state.items()
+ }
+ else:
+ state_ids = None
+
+ context = EventContext()
+ context.current_state_ids = state_ids
+ context.prev_state_ids = state_ids
+ context.push_actions = push_actions
+
+ ordering = None
+ if backfill:
+ yield self.master_store.persist_events(
+ [(event, context)], backfilled=True
+ )
+ else:
+ ordering, _ = yield self.master_store.persist_event(
+ event, context, current_state=reset_state
+ )
+
+ if ordering:
+ event.internal_metadata.stream_ordering = ordering
+
+ defer.returnValue(event)
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
new file mode 100644
index 0000000000..6624fe4eea
--- /dev/null
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -0,0 +1,39 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStoreTestCase
+
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+
+from twisted.internet import defer
+
+USER_ID = "@feeling:blue"
+ROOM_ID = "!room:blue"
+EVENT_ID = "$event:blue"
+
+
+class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
+
+ STORE_TYPE = SlavedReceiptsStore
+
+ @defer.inlineCallbacks
+ def test_receipt(self):
+ yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+ yield self.master_store.insert_receipt(
+ ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
+ )
+ yield self.replicate()
+ yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {
+ ROOM_ID: EVENT_ID
+ })
diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py
index f4b5fb3328..93b9fad012 100644
--- a/tests/replication/test_resource.py
+++ b/tests/replication/test_resource.py
@@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.resource import ReplicationResource
-from synapse.types import Requester, UserID
+import contextlib
+import json
+from mock import Mock, NonCallableMock
from twisted.internet import defer
+
+import synapse.types
+from synapse.replication.resource import ReplicationResource
+from synapse.types import UserID
from tests import unittest
-from tests.utils import setup_test_homeserver, requester_for_user
-from mock import Mock, NonCallableMock
-import json
-import contextlib
+from tests.utils import setup_test_homeserver
class ReplicationResourceCase(unittest.TestCase):
@@ -61,18 +63,18 @@ class ReplicationResourceCase(unittest.TestCase):
def test_events(self):
get = self.get(events="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room(
- Requester(self.user, "", False), {}
+ synapse.types.create_requester(self.user), {}
)
code, body = yield get
self.assertEquals(code, 200)
self.assertEquals(body["events"]["field_names"], [
- "position", "internal", "json"
+ "position", "internal", "json", "state_group"
])
@defer.inlineCallbacks
def test_presence(self):
get = self.get(presence="-1")
- yield self.hs.get_handlers().presence_handler.set_state(
+ yield self.hs.get_presence_handler().set_state(
self.user, {"presence": "online"}
)
code, body = yield get
@@ -87,7 +89,7 @@ class ReplicationResourceCase(unittest.TestCase):
def test_typing(self):
room_id = yield self.create_room()
get = self.get(typing="-1")
- yield self.hs.get_handlers().typing_notification_handler.started_typing(
+ yield self.hs.get_typing_handler().started_typing(
self.user, self.user, room_id, timeout=2
)
code, body = yield get
@@ -101,7 +103,7 @@ class ReplicationResourceCase(unittest.TestCase):
room_id = yield self.create_room()
event_id = yield self.send_text_message(room_id, "Hello, World")
get = self.get(receipts="-1")
- yield self.hs.get_handlers().receipts_handler.received_client_receipt(
+ yield self.hs.get_receipts_handler().received_client_receipt(
room_id, "m.read", self.user_id, event_id
)
code, body = yield get
@@ -118,7 +120,7 @@ class ReplicationResourceCase(unittest.TestCase):
self.hs.clock.advance_time_msec(1)
code, body = yield get
self.assertEquals(code, 200)
- self.assertEquals(body, {})
+ self.assertEquals(body.get("rows", []), [])
test_timeout.__name__ = "test_timeout_%s" % (stream)
return test_timeout
@@ -132,12 +134,13 @@ class ReplicationResourceCase(unittest.TestCase):
test_timeout_backfill = _test_timeout("backfill")
test_timeout_push_rules = _test_timeout("push_rules")
test_timeout_pushers = _test_timeout("pushers")
+ test_timeout_state = _test_timeout("state")
@defer.inlineCallbacks
def send_text_message(self, room_id, message):
handler = self.hs.get_handlers().message_handler
event = yield handler.create_and_send_nonmember_event(
- requester_for_user(self.user),
+ synapse.types.create_requester(self.user),
{
"type": "m.room.message",
"content": {"body": "message", "msgtype": "m.text"},
@@ -150,7 +153,7 @@ class ReplicationResourceCase(unittest.TestCase):
@defer.inlineCallbacks
def create_room(self):
result = yield self.hs.get_handlers().room_creation_handler.create_room(
- Requester(self.user, "", False), {}
+ synapse.types.create_requester(self.user), {}
)
defer.returnValue(result["room_id"])
@@ -182,4 +185,20 @@ class ReplicationResourceCase(unittest.TestCase):
)
response_body = json.loads(response_json)
+ if response_code == 200:
+ self.check_response(response_body)
+
defer.returnValue((response_code, response_body))
+
+ def check_response(self, response_body):
+ for name, stream in response_body.items():
+ self.assertIn("field_names", stream)
+ field_names = stream["field_names"]
+ self.assertIn("rows", stream)
+ for row in stream["rows"]:
+ self.assertEquals(
+ len(row), len(field_names),
+ "%s: len(row = %r) == len(field_names = %r)" % (
+ name, row, field_names
+ )
+ )
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
new file mode 100644
index 0000000000..d7cea30260
--- /dev/null
+++ b/tests/rest/client/test_transactions.py
@@ -0,0 +1,69 @@
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
+from twisted.internet import defer
+from mock import Mock, call
+from tests import unittest
+from tests.utils import MockClock
+
+
+class HttpTransactionCacheTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.clock = MockClock()
+ self.cache = HttpTransactionCache(self.clock)
+
+ self.mock_http_response = (200, "GOOD JOB!")
+ self.mock_key = "foo"
+
+ @defer.inlineCallbacks
+ def test_executes_given_function(self):
+ cb = Mock(
+ return_value=defer.succeed(self.mock_http_response)
+ )
+ res = yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "some_arg", keyword="arg"
+ )
+ cb.assert_called_once_with("some_arg", keyword="arg")
+ self.assertEqual(res, self.mock_http_response)
+
+ @defer.inlineCallbacks
+ def test_deduplicates_based_on_key(self):
+ cb = Mock(
+ return_value=defer.succeed(self.mock_http_response)
+ )
+ for i in range(3): # invoke multiple times
+ res = yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
+ )
+ self.assertEqual(res, self.mock_http_response)
+ # expect only a single call to do the work
+ cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
+
+ @defer.inlineCallbacks
+ def test_cleans_up(self):
+ cb = Mock(
+ return_value=defer.succeed(self.mock_http_response)
+ )
+ yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "an arg"
+ )
+ # should NOT have cleaned up yet
+ self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
+
+ yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "an arg"
+ )
+ # still using cache
+ cb.assert_called_once_with("an arg")
+
+ self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
+
+ yield self.cache.fetch_or_execute(
+ self.mock_key, cb, "an arg"
+ )
+ # no longer using cache
+ self.assertEqual(cb.call_count, 2)
+ self.assertEqual(
+ cb.call_args_list,
+ [call("an arg",), call("an arg",)]
+ )
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index af02fce8fb..1e95e97538 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,17 +14,14 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
-from tests import unittest
-from twisted.internet import defer
-
from mock import Mock
+from twisted.internet import defer
-from ....utils import MockHttpResource, setup_test_homeserver
-
+import synapse.types
from synapse.api.errors import SynapseError, AuthError
-from synapse.types import Requester, UserID
-
from synapse.rest.client.v1 import profile
+from tests import unittest
+from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/api/v1"
@@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None, allow_guest=False):
- return Requester(UserID.from_string(myid), "", False)
+ return synapse.types.create_requester(myid)
hs.get_v1auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
new file mode 100644
index 0000000000..a6a4e2ffe0
--- /dev/null
+++ b/tests/rest/client/v1/test_register.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client.v1.register import CreateUserRestServlet
+from twisted.internet import defer
+from mock import Mock
+from tests import unittest
+from tests.utils import mock_getRawHeaders
+import json
+
+
+class CreateUserServletTestCase(unittest.TestCase):
+
+ def setUp(self):
+ # do the dance to hook up request data to self.request_data
+ self.request_data = ""
+ self.request = Mock(
+ content=Mock(read=Mock(side_effect=lambda: self.request_data)),
+ path='/_matrix/client/api/v1/createUser'
+ )
+ self.request.args = {}
+ self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+ self.registration_handler = Mock()
+
+ self.appservice = Mock(sender="@as:test")
+ self.datastore = Mock(
+ get_app_service_by_token=Mock(return_value=self.appservice)
+ )
+
+ # do the dance to hook things up to the hs global
+ handlers = Mock(
+ registration_handler=self.registration_handler,
+ )
+ self.hs = Mock()
+ self.hs.hostname = "superbig~testing~thing.com"
+ self.hs.get_datastore = Mock(return_value=self.datastore)
+ self.hs.get_handlers = Mock(return_value=handlers)
+ self.servlet = CreateUserRestServlet(self.hs)
+
+ @defer.inlineCallbacks
+ def test_POST_createuser_with_valid_user(self):
+ user_id = "@someone:interesting"
+ token = "my token"
+ self.request.args = {
+ "access_token": "i_am_an_app_service"
+ }
+ self.request_data = json.dumps({
+ "localpart": "someone",
+ "displayname": "someone interesting",
+ "duration_seconds": 200
+ })
+
+ self.registration_handler.get_or_create_user = Mock(
+ return_value=(user_id, token)
+ )
+
+ (code, result) = yield self.servlet.on_POST(self.request)
+ self.assertEquals(code, 200)
+
+ det_data = {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname
+ }
+ self.assertDictContainsSubset(det_data, result)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 4ab8b35e6b..4fe99ebc0b 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]:
- yield self.join(room=room, user=usr, expect_code=404)
- yield self.leave(room=room, user=usr, expect_code=404)
+ yield self.join(room=room, user=usr, expect_code=403)
+ yield self.leave(room=room, user=usr, expect_code=403)
@defer.inlineCallbacks
def test_membership_private_room_perms(self):
@@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_topo_token_is_accepted(self):
- token = "t1-0_0_0_0_0_0"
+ token = "t1-0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
@@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self):
- token = "s0_0_0_0_0_0"
+ token = "s0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index d0037a53ef..a269e6f56e 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase):
# Need another user to make notifications actually work
yield self.join(self.room_id, user="@jim:red")
- def tearDown(self):
- self.hs.get_handlers().typing_notification_handler.tearDown()
-
@defer.inlineCallbacks
def test_set_typing(self):
(code, _) = yield self.mock_resource.trigger(
@@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
- self.clock.advance_time(31)
+ self.clock.advance_time(36)
self.assertEquals(self.event_source.get_current_key(), 2)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index d1442aafac..3d27d03cbf 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -15,78 +15,125 @@
from twisted.internet import defer
-from . import V2AlphaRestTestCase
+from tests import unittest
from synapse.rest.client.v2_alpha import filter
-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes
+import synapse.types
+
+from synapse.types import UserID
+
+from ....utils import MockHttpResource, setup_test_homeserver
+
+PATH_PREFIX = "/_matrix/client/v2_alpha"
+
+
+class FilterTestCase(unittest.TestCase):
-class FilterTestCase(V2AlphaRestTestCase):
USER_ID = "@apple:test"
+ EXAMPLE_FILTER = {"type": ["m.*"]}
+ EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
TO_REGISTER = [filter]
- def make_datastore_mock(self):
- datastore = super(FilterTestCase, self).make_datastore_mock()
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self._user_filters = {}
+ self.hs = yield setup_test_homeserver(
+ http_client=None,
+ resource_for_client=self.mock_resource,
+ resource_for_federation=self.mock_resource,
+ )
- def add_user_filter(user_localpart, definition):
- filters = self._user_filters.setdefault(user_localpart, [])
- filter_id = len(filters)
- filters.append(definition)
- return defer.succeed(filter_id)
- datastore.add_user_filter = add_user_filter
+ self.auth = self.hs.get_auth()
- def get_user_filter(user_localpart, filter_id):
- if user_localpart not in self._user_filters:
- raise StoreError(404, "No user")
- filters = self._user_filters[user_localpart]
- if filter_id >= len(filters):
- raise StoreError(404, "No filter")
- return defer.succeed(filters[filter_id])
- datastore.get_user_filter = get_user_filter
+ def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.USER_ID),
+ "token_id": 1,
+ "is_guest": False,
+ }
- return datastore
+ def get_user_by_req(request, allow_guest=False, rights="access"):
+ return synapse.types.create_requester(
+ UserID.from_string(self.USER_ID), 1, False, None)
+
+ self.auth.get_user_by_access_token = get_user_by_access_token
+ self.auth.get_user_by_req = get_user_by_req
+
+ self.store = self.hs.get_datastore()
+ self.filtering = self.hs.get_filtering()
+
+ for r in self.TO_REGISTER:
+ r.register_servlets(self.hs, self.mock_resource)
@defer.inlineCallbacks
def test_add_filter(self):
(code, response) = yield self.mock_resource.trigger(
- "POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}'
+ "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
)
self.assertEquals(200, code)
self.assertEquals({"filter_id": "0"}, response)
+ filter = yield self.store.get_user_filter(
+ user_localpart='apple',
+ filter_id=0,
+ )
+ self.assertEquals(filter, self.EXAMPLE_FILTER)
- self.assertIn("apple", self._user_filters)
- self.assertEquals(len(self._user_filters["apple"]), 1)
- self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])
+ @defer.inlineCallbacks
+ def test_add_filter_for_other_user(self):
+ (code, response) = yield self.mock_resource.trigger(
+ "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
+ )
+ self.assertEquals(403, code)
+ self.assertEquals(response['errcode'], Codes.FORBIDDEN)
@defer.inlineCallbacks
- def test_get_filter(self):
- self._user_filters["apple"] = [
- {"type": ["m.*"]}
- ]
+ def test_add_filter_non_local_user(self):
+ _is_mine = self.hs.is_mine
+ self.hs.is_mine = lambda target_user: False
+ (code, response) = yield self.mock_resource.trigger(
+ "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
+ )
+ self.hs.is_mine = _is_mine
+ self.assertEquals(403, code)
+ self.assertEquals(response['errcode'], Codes.FORBIDDEN)
+ @defer.inlineCallbacks
+ def test_get_filter(self):
+ filter_id = yield self.filtering.add_user_filter(
+ user_localpart='apple',
+ user_filter=self.EXAMPLE_FILTER
+ )
(code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/0" % (self.USER_ID)
+ "/user/%s/filter/%s" % (self.USER_ID, filter_id)
)
self.assertEquals(200, code)
- self.assertEquals({"type": ["m.*"]}, response)
+ self.assertEquals(self.EXAMPLE_FILTER, response)
@defer.inlineCallbacks
- def test_get_filter_no_id(self):
- self._user_filters["apple"] = [
- {"type": ["m.*"]}
- ]
+ def test_get_filter_non_existant(self):
+ (code, response) = yield self.mock_resource.trigger_get(
+ "/user/%s/filter/12382148321" % (self.USER_ID)
+ )
+ self.assertEquals(400, code)
+ self.assertEquals(response['errcode'], Codes.NOT_FOUND)
+ # Currently invalid params do not have an appropriate errcode
+ # in errors.py
+ @defer.inlineCallbacks
+ def test_get_filter_invalid_id(self):
(code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/2" % (self.USER_ID)
+ "/user/%s/filter/foobar" % (self.USER_ID)
)
- self.assertEquals(404, code)
+ self.assertEquals(400, code)
+ # No ID also returns an invalid_id error
@defer.inlineCallbacks
- def test_get_filter_no_user(self):
+ def test_get_filter_no_id(self):
(code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/0" % (self.USER_ID)
+ "/user/%s/filter/" % (self.USER_ID)
)
- self.assertEquals(404, code)
+ self.assertEquals(400, code)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index affd42c015..b6173ab2ee 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError
from twisted.internet import defer
from mock import Mock
from tests import unittest
+from tests.utils import mock_getRawHeaders
import json
@@ -16,10 +17,11 @@ class RegisterRestServletTestCase(unittest.TestCase):
path='/_matrix/api/v2_alpha/register'
)
self.request.args = {}
+ self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
- side_effect=lambda x: defer.succeed(self.appservice))
+ side_effect=lambda x: self.appservice)
)
self.auth_result = (False, None, None, None)
@@ -30,10 +32,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
+ self.device_handler = Mock()
# do the dance to hook it up to the hs global
self.handlers = Mock(
- auth_handler=self.auth_handler,
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
@@ -42,6 +44,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
+ self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
+ self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.config.enable_registration = True
# init the thing we're testing
@@ -61,8 +65,12 @@ class RegisterRestServletTestCase(unittest.TestCase):
"id": "1234"
}
self.registration_handler.appservice_register = Mock(
- return_value=(user_id, token)
+ return_value=user_id
)
+ self.auth_handler.get_access_token_for_user_id = Mock(
+ return_value=token
+ )
+
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
@@ -71,7 +79,6 @@ class RegisterRestServletTestCase(unittest.TestCase):
"home_server": self.hs.hostname
}
self.assertDictContainsSubset(det_data, result)
- self.assertIn("refresh_token", result)
@defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self):
@@ -105,26 +112,35 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_user_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
+ device_id = "frogfone"
self.request_data = json.dumps({
"username": "kermit",
- "password": "monkey"
+ "password": "monkey",
+ "device_id": device_id,
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
}, None)
- self.registration_handler.register = Mock(return_value=(user_id, token))
+ self.registration_handler.register = Mock(return_value=(user_id, None))
+ self.auth_handler.get_access_token_for_user_id = Mock(
+ return_value=token
+ )
+ self.device_handler.check_device_registered = \
+ Mock(return_value=device_id)
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
- "home_server": self.hs.hostname
+ "home_server": self.hs.hostname,
+ "device_id": device_id,
}
self.assertDictContainsSubset(det_data, result)
- self.assertIn("refresh_token", result)
+ self.auth_handler.get_login_tuple_for_user_id(
+ user_id, device_id=device_id, initial_device_display_name=None)
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py
index f22ba8db89..38556da9a7 100644
--- a/tests/storage/event_injector.py
+++ b/tests/storage/event_injector.py
@@ -30,6 +30,7 @@ class EventInjector:
def create_room(self, room):
builder = self.event_builder_factory.new({
"type": EventTypes.Create,
+ "sender": "",
"room_id": room.to_string(),
"content": {},
})
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96b7dba5fe..ab6095564a 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
from tests import unittest
from twisted.internet import defer
+from mock import Mock
+
from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
@@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
cache.get(3)
def test_eviction_lru(self):
- cache = Cache("test", max_entries=2, lru=True)
+ cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
@@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
+
+ @defer.inlineCallbacks
+ def test_invalidate_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func.invalidate(("foo",))
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 1)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ @defer.inlineCallbacks
+ def test_eviction_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached(max_entries=2)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+ yield a.func2("foo2")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func("foo3")
+
+ self.assertEquals(callcount[0], 3)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 4)
+ self.assertEquals(callcount2[0], 3)
+
+ @defer.inlineCallbacks
+ def test_double_get(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A(object):
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+ yield a.func2("foo")
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 2)
+
+ a.func.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 3)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 5734198121..9ff1abcd80 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -37,8 +37,9 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
config = Mock(
app_service_config_files=self.as_yaml_files,
event_cache_size=1,
+ password_providers=[],
)
- hs = yield setup_test_homeserver(config=config)
+ hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
self.as_token = "token1"
self.as_url = "some_url"
@@ -71,14 +72,12 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
- @defer.inlineCallbacks
def test_retrieve_unknown_service_token(self):
- service = yield self.store.get_app_service_by_token("invalid_token")
+ service = self.store.get_app_service_by_token("invalid_token")
self.assertEquals(service, None)
- @defer.inlineCallbacks
def test_retrieval_of_service(self):
- stored_service = yield self.store.get_app_service_by_token(
+ stored_service = self.store.get_app_service_by_token(
self.as_token
)
self.assertEquals(stored_service.token, self.as_token)
@@ -97,9 +96,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
[]
)
- @defer.inlineCallbacks
def test_retrieval_of_all_services(self):
- services = yield self.store.get_app_services()
+ services = self.store.get_app_services()
self.assertEquals(len(services), 3)
@@ -112,8 +110,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
config = Mock(
app_service_config_files=self.as_yaml_files,
event_cache_size=1,
+ password_providers=[],
)
- hs = yield setup_test_homeserver(config=config)
+ hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
self.db_pool = hs.get_db_pool()
self.as_list = [
@@ -357,7 +356,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store._get_events_txn = Mock(return_value=events)
+ self.store._get_events = Mock(return_value=events)
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
@@ -440,8 +439,15 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2")
- config = Mock(app_service_config_files=[f1, f2], event_cache_size=1)
- hs = yield setup_test_homeserver(config=config, datastore=Mock())
+ config = Mock(
+ app_service_config_files=[f1, f2], event_cache_size=1,
+ password_providers=[]
+ )
+ hs = yield setup_test_homeserver(
+ config=config,
+ datastore=Mock(),
+ federation_sender=Mock()
+ )
ApplicationServiceStore(hs)
@@ -450,8 +456,15 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2")
- config = Mock(app_service_config_files=[f1, f2], event_cache_size=1)
- hs = yield setup_test_homeserver(config=config, datastore=Mock())
+ config = Mock(
+ app_service_config_files=[f1, f2], event_cache_size=1,
+ password_providers=[]
+ )
+ hs = yield setup_test_homeserver(
+ config=config,
+ datastore=Mock(),
+ federation_sender=Mock()
+ )
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
@@ -466,8 +479,15 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2")
- config = Mock(app_service_config_files=[f1, f2], event_cache_size=1)
- hs = yield setup_test_homeserver(config=config, datastore=Mock())
+ config = Mock(
+ app_service_config_files=[f1, f2], event_cache_size=1,
+ password_providers=[]
+ )
+ hs = yield setup_test_homeserver(
+ config=config,
+ datastore=Mock(),
+ federation_sender=Mock()
+ )
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 6e4d9b1373..1286b4ce2d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -10,7 +10,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver()
+ hs = yield setup_test_homeserver() # type: synapse.server.HomeServer
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -20,11 +20,20 @@ class BackgroundUpdateTestCase(unittest.TestCase):
"test_update", self.update_handler
)
+ # run the real background updates, to get them out the way
+ # (perhaps we should run them as part of the test HS setup, since we
+ # run all of the other schema setup stuff there?)
+ while True:
+ res = yield self.store.do_next_background_update(1000)
+ if res is None:
+ break
+
@defer.inlineCallbacks
def test_do_background_update(self):
desired_count = 1000
duration_ms = 42
+ # first step: make a bit of progress
@defer.inlineCallbacks
def update(progress, count):
self.clock.advance_time_msec(count * duration_ms)
@@ -42,7 +51,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
yield self.store.start_background_update("test_update", {"my_key": 1})
self.update_handler.reset_mock()
- result = yield self.store.do_background_update(
+ result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
@@ -50,15 +59,15 @@ class BackgroundUpdateTestCase(unittest.TestCase):
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
)
+ # second step: complete the update
@defer.inlineCallbacks
def update(progress, count):
yield self.store._end_background_update("test_update")
defer.returnValue(count)
self.update_handler.side_effect = update
-
self.update_handler.reset_mock()
- result = yield self.store.do_background_update(
+ result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
@@ -66,8 +75,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
{"my_key": 2}, desired_count
)
+ # third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = yield self.store.do_background_update(
+ result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNone(result)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 2e33beb07c..afbefb2e2d 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
"test",
db_pool=self.db_pool,
config=config,
- database_engine=create_engine(config),
+ database_engine=create_engine(config.database_config),
)
self.datastore = SQLBaseStore(hs)
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
new file mode 100644
index 0000000000..1f0c0e7c37
--- /dev/null
+++ b/tests/storage/test_client_ips.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import synapse.server
+import synapse.storage
+import synapse.types
+import tests.unittest
+import tests.utils
+
+
+class ClientIpStoreTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(ClientIpStoreTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+ self.clock = None # type: tests.utils.MockClock
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def test_insert_new_client_ip(self):
+ self.clock.now = 12345678
+ user_id = "@user:id"
+ yield self.store.insert_client_ip(
+ synapse.types.UserID.from_string(user_id),
+ "access_token", "ip", "user_agent", "device_id",
+ )
+
+ # deliberately use an iterable here to make sure that the lookup
+ # method doesn't iterate it twice
+ device_list = iter(((user_id, "device_id"),))
+ result = yield self.store.get_last_client_ip_by_device(device_list)
+
+ r = result[(user_id, "device_id")]
+ self.assertDictContainsSubset(
+ {
+ "user_id": user_id,
+ "device_id": "device_id",
+ "access_token": "access_token",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 12345678000,
+ },
+ r
+ )
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
new file mode 100644
index 0000000000..f8725acea0
--- /dev/null
+++ b/tests/storage/test_devices.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import synapse.api.errors
+import tests.unittest
+import tests.utils
+
+
+class DeviceStoreTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(DeviceStoreTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_store_new_device(self):
+ yield self.store.store_device(
+ "user_id", "device_id", "display_name"
+ )
+
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertDictContainsSubset({
+ "user_id": "user_id",
+ "device_id": "device_id",
+ "display_name": "display_name",
+ }, res)
+
+ @defer.inlineCallbacks
+ def test_get_devices_by_user(self):
+ yield self.store.store_device(
+ "user_id", "device1", "display_name 1"
+ )
+ yield self.store.store_device(
+ "user_id", "device2", "display_name 2"
+ )
+ yield self.store.store_device(
+ "user_id2", "device3", "display_name 3"
+ )
+
+ res = yield self.store.get_devices_by_user("user_id")
+ self.assertEqual(2, len(res.keys()))
+ self.assertDictContainsSubset({
+ "user_id": "user_id",
+ "device_id": "device1",
+ "display_name": "display_name 1",
+ }, res["device1"])
+ self.assertDictContainsSubset({
+ "user_id": "user_id",
+ "device_id": "device2",
+ "display_name": "display_name 2",
+ }, res["device2"])
+
+ @defer.inlineCallbacks
+ def test_update_device(self):
+ yield self.store.store_device(
+ "user_id", "device_id", "display_name 1"
+ )
+
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertEqual("display_name 1", res["display_name"])
+
+ # do a no-op first
+ yield self.store.update_device(
+ "user_id", "device_id",
+ )
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertEqual("display_name 1", res["display_name"])
+
+ # do the update
+ yield self.store.update_device(
+ "user_id", "device_id",
+ new_display_name="display_name 2",
+ )
+
+ # check it worked
+ res = yield self.store.get_device("user_id", "device_id")
+ self.assertEqual("display_name 2", res["display_name"])
+
+ @defer.inlineCallbacks
+ def test_update_unknown_device(self):
+ with self.assertRaises(synapse.api.errors.StoreError) as cm:
+ yield self.store.update_device(
+ "user_id", "unknown_device_id",
+ new_display_name="display_name 2",
+ )
+ self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
new file mode 100644
index 0000000000..453bc61438
--- /dev/null
+++ b/tests/storage/test_end_to_end_keys.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+
+
+class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_key_without_device_name(self):
+ now = 1470174257070
+ json = '{ "key": "value" }'
+
+ yield self.store.set_e2e_device_keys(
+ "user", "device", now, json)
+
+ res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ self.assertIn("user", res)
+ self.assertIn("device", res["user"])
+ dev = res["user"]["device"]
+ self.assertDictContainsSubset({
+ "key_json": json,
+ "device_display_name": None,
+ }, dev)
+
+ @defer.inlineCallbacks
+ def test_get_key_with_device_name(self):
+ now = 1470174257070
+ json = '{ "key": "value" }'
+
+ yield self.store.set_e2e_device_keys(
+ "user", "device", now, json)
+ yield self.store.store_device(
+ "user", "device", "display_name"
+ )
+
+ res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ self.assertIn("user", res)
+ self.assertIn("device", res["user"])
+ dev = res["user"]["device"]
+ self.assertDictContainsSubset({
+ "key_json": json,
+ "device_display_name": "display_name",
+ }, dev)
+
+ @defer.inlineCallbacks
+ def test_multiple_devices(self):
+ now = 1470174257070
+
+ yield self.store.set_e2e_device_keys(
+ "user1", "device1", now, 'json11')
+ yield self.store.set_e2e_device_keys(
+ "user1", "device2", now, 'json12')
+ yield self.store.set_e2e_device_keys(
+ "user2", "device1", now, 'json21')
+ yield self.store.set_e2e_device_keys(
+ "user2", "device2", now, 'json22')
+
+ res = yield self.store.get_e2e_device_keys((("user1", "device1"),
+ ("user2", "device2")))
+ self.assertIn("user1", res)
+ self.assertIn("device1", res["user1"])
+ self.assertNotIn("device2", res["user1"])
+ self.assertIn("user2", res)
+ self.assertNotIn("device1", res["user2"])
+ self.assertIn("device2", res["user2"])
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
new file mode 100644
index 0000000000..e9044afa2e
--- /dev/null
+++ b/tests/storage/test_event_push_actions.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import tests.unittest
+import tests.utils
+
+USER_ID = "@user:example.com"
+
+
+class EventPushActionsStoreTestCase(tests.unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_get_unread_push_actions_for_user_in_range_for_http(self):
+ yield self.store.get_unread_push_actions_for_user_in_range_for_http(
+ USER_ID, 0, 1000, 20
+ )
+
+ @defer.inlineCallbacks
+ def test_get_unread_push_actions_for_user_in_range_for_email(self):
+ yield self.store.get_unread_push_actions_for_user_in_range_for_email(
+ USER_ID, 0, 1000, 20
+ )
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 18a6cff0c7..3762b38e37 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -37,7 +37,7 @@ class EventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_count_daily_messages(self):
- self.db_pool.runQuery("DELETE FROM stats_reporting")
+ yield self.db_pool.runQuery("DELETE FROM stats_reporting")
self.hs.clock.now = 100
@@ -60,7 +60,7 @@ class EventsStoreTestCase(unittest.TestCase):
# it isn't old enough.
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
- self._assert_stats_reporting(1, self.hs.clock.now)
+ yield self._assert_stats_reporting(1, self.hs.clock.now)
# Already reported yesterday, two new events from today.
yield self.event_injector.inject_message(room, user, "Yeah they are!")
@@ -68,21 +68,21 @@ class EventsStoreTestCase(unittest.TestCase):
self.hs.clock.now += 60 * 60 * 24
count = yield self.store.count_daily_messages()
self.assertEqual(2, count) # 2 since yesterday
- self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
+ yield self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
# Last reported too recently.
yield self.event_injector.inject_message(room, user, "Who could disagree?")
self.hs.clock.now += 60 * 60 * 22
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
- self._assert_stats_reporting(4, self.hs.clock.now)
+ yield self._assert_stats_reporting(4, self.hs.clock.now)
# Last reported too long ago
yield self.event_injector.inject_message(room, user, "No one.")
self.hs.clock.now += 60 * 60 * 26
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
- self._assert_stats_reporting(5, self.hs.clock.now)
+ yield self._assert_stats_reporting(5, self.hs.clock.now)
# And now let's actually report something
yield self.event_injector.inject_message(room, user, "Indeed.")
@@ -92,7 +92,7 @@ class EventsStoreTestCase(unittest.TestCase):
self.hs.clock.now += (60 * 60 * 24) + 50
count = yield self.store.count_daily_messages()
self.assertEqual(3, count)
- self._assert_stats_reporting(8, self.hs.clock.now)
+ yield self._assert_stats_reporting(8, self.hs.clock.now)
@defer.inlineCallbacks
def _get_last_stream_token(self):
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index ec78f007ca..63203cea35 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -35,33 +35,6 @@ class PresenceStoreTestCase(unittest.TestCase):
self.u_banana = UserID.from_string("@banana:test")
@defer.inlineCallbacks
- def test_visibility(self):
- self.assertFalse((yield self.store.is_presence_visible(
- observed_localpart=self.u_apple.localpart,
- observer_userid=self.u_banana.to_string(),
- )))
-
- yield self.store.allow_presence_visible(
- observed_localpart=self.u_apple.localpart,
- observer_userid=self.u_banana.to_string(),
- )
-
- self.assertTrue((yield self.store.is_presence_visible(
- observed_localpart=self.u_apple.localpart,
- observer_userid=self.u_banana.to_string(),
- )))
-
- yield self.store.disallow_presence_visible(
- observed_localpart=self.u_apple.localpart,
- observer_userid=self.u_banana.to_string(),
- )
-
- self.assertFalse((yield self.store.is_presence_visible(
- observed_localpart=self.u_apple.localpart,
- observer_userid=self.u_banana.to_string(),
- )))
-
- @defer.inlineCallbacks
def test_presence_list(self):
self.assertEquals(
[],
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 5880409867..6afaca3a61 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -110,22 +110,10 @@ class RedactionTestCase(unittest.TestCase):
self.room1, self.u_alice, Membership.JOIN
)
- start = yield self.store.get_room_events_max_id()
-
msg_event = yield self.inject_message(self.room1, self.u_alice, u"t")
- end = yield self.store.get_room_events_max_id()
-
- results, _ = yield self.store.get_room_events_stream(
- self.u_alice.to_string(),
- start,
- end,
- )
-
- self.assertEqual(1, len(results))
-
# Check event has not been redacted:
- event = results[0]
+ event = yield self.store.get_event(msg_event.event_id)
self.assertObjectHasAttributes(
{
@@ -144,17 +132,7 @@ class RedactionTestCase(unittest.TestCase):
self.room1, msg_event.event_id, self.u_alice, reason
)
- results, _ = yield self.store.get_room_events_stream(
- self.u_alice.to_string(),
- start,
- end,
- )
-
- self.assertEqual(1, len(results))
-
- # Check redaction
-
- event = results[0]
+ event = yield self.store.get_event(msg_event.event_id)
self.assertEqual(msg_event.event_id, event.event_id)
@@ -184,25 +162,12 @@ class RedactionTestCase(unittest.TestCase):
self.room1, self.u_alice, Membership.JOIN
)
- start = yield self.store.get_room_events_max_id()
-
msg_event = yield self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN,
extra_content={"blue": "red"},
)
- end = yield self.store.get_room_events_max_id()
-
- results, _ = yield self.store.get_room_events_stream(
- self.u_alice.to_string(),
- start,
- end,
- )
-
- self.assertEqual(1, len(results))
-
- # Check event has not been redacted:
- event = results[0]
+ event = yield self.store.get_event(msg_event.event_id)
self.assertObjectHasAttributes(
{
@@ -221,17 +186,9 @@ class RedactionTestCase(unittest.TestCase):
self.room1, msg_event.event_id, self.u_alice, reason
)
- results, _ = yield self.store.get_room_events_stream(
- self.u_alice.to_string(),
- start,
- end,
- )
-
- self.assertEqual(1, len(results))
-
# Check redaction
- event = results[0]
+ event = yield self.store.get_event(msg_event.event_id)
self.assertTrue("redacted_because" in event.unsigned)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index b8384c98d8..316ecdb32d 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,9 +17,6 @@
from tests import unittest
from twisted.internet import defer
-from synapse.api.errors import StoreError
-from synapse.util import stringutils
-
from tests.utils import setup_test_homeserver
@@ -38,6 +35,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"BcDeFgHiJkLmNoPqRsTuVwXyZa"
]
self.pwhash = "{xx1}123456789"
+ self.device_id = "akgjhdjklgshg"
@defer.inlineCallbacks
def test_register(self):
@@ -64,13 +62,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_add_tokens(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
- yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
+ yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
+ self.device_id)
result = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertDictContainsSubset(
{
"name": self.user_id,
+ "device_id": self.device_id,
},
result
)
@@ -78,48 +78,31 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.assertTrue("token_id" in result)
@defer.inlineCallbacks
- def test_exchange_refresh_token_valid(self):
- uid = stringutils.random_string(32)
- generator = TokenGenerator()
- last_token = generator.generate(uid)
-
- self.db_pool.runQuery(
- "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
- (uid, last_token,))
-
- (found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
- last_token, generator.generate)
- self.assertEqual(uid, found_user_id)
-
- rows = yield self.db_pool.runQuery(
- "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
- self.assertEqual([(refresh_token,)], rows)
- # We issued token 1, then exchanged it for token 2
- expected_refresh_token = u"%s-%d" % (uid, 2,)
- self.assertEqual(expected_refresh_token, refresh_token)
+ def test_user_delete_access_tokens(self):
+ # add some tokens
+ yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
+ yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
+ self.device_id)
- @defer.inlineCallbacks
- def test_exchange_refresh_token_none(self):
- uid = stringutils.random_string(32)
- generator = TokenGenerator()
- last_token = generator.generate(uid)
+ # now delete some
+ yield self.store.user_delete_access_tokens(
+ self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
- with self.assertRaises(StoreError):
- yield self.store.exchange_refresh_token(last_token, generator.generate)
+ # check they were deleted
+ user = yield self.store.get_user_by_access_token(self.tokens[1])
+ self.assertIsNone(user, "access token was not deleted by device_id")
- @defer.inlineCallbacks
- def test_exchange_refresh_token_invalid(self):
- uid = stringutils.random_string(32)
- generator = TokenGenerator()
- last_token = generator.generate(uid)
- wrong_token = "%s-wrong" % (last_token,)
-
- self.db_pool.runQuery(
- "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
- (uid, wrong_token,))
-
- with self.assertRaises(StoreError):
- yield self.store.exchange_refresh_token(last_token, generator.generate)
+ # check the one not associated with the device was not deleted
+ user = yield self.store.get_user_by_access_token(self.tokens[0])
+ self.assertEqual(self.user_id, user["name"])
+
+ # now delete the rest
+ yield self.store.user_delete_access_tokens(
+ self.user_id, delete_refresh_tokens=True)
+
+ user = yield self.store.get_user_by_access_token(self.tokens[0])
+ self.assertIsNone(user,
+ "access token was not deleted without device_id")
class TokenGenerator:
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index b029ff0584..1be7d932f6 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -71,19 +71,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
self.assertEquals(
- Membership.JOIN,
- (yield self.store.get_room_member(
- user_id=self.u_alice.to_string(),
- room_id=self.room.to_string(),
- )).membership
- )
- self.assertEquals(
- [self.u_alice.to_string()],
- [m.user_id for m in (
- yield self.store.get_room_members(self.room.to_string())
- )]
- )
- self.assertEquals(
[self.room.to_string()],
[m.room_id for m in (
yield self.store.get_rooms_for_user_where_membership_is(
@@ -91,56 +78,3 @@ class RoomMemberStoreTestCase(unittest.TestCase):
)
)]
)
-
- @defer.inlineCallbacks
- def test_two_members(self):
- yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
- yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
-
- self.assertEquals(
- {self.u_alice.to_string(), self.u_bob.to_string()},
- {m.user_id for m in (
- yield self.store.get_room_members(self.room.to_string())
- )}
- )
-
- @defer.inlineCallbacks
- def test_room_hosts(self):
- yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
-
- self.assertEquals(
- {"test"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
-
- # Should still have just one host after second join from it
- yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
-
- self.assertEquals(
- {"test"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
-
- # Should now have two hosts after join from other host
- yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)
-
- self.assertEquals(
- {"test", "elsewhere"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
-
- # Should still have both hosts
- yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)
-
- self.assertEquals(
- {"test", "elsewhere"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
-
- # Should have only one host after other leaves
- yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
-
- self.assertEquals(
- {"test"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
deleted file mode 100644
index da322152c7..0000000000
--- a/tests/storage/test_stream.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from tests import unittest
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes, Membership
-from synapse.types import UserID, RoomID
-from tests.storage.event_injector import EventInjector
-
-from tests.utils import setup_test_homeserver
-
-from mock import Mock
-
-
-class StreamStoreTestCase(unittest.TestCase):
-
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(
- resource_for_federation=Mock(),
- http_client=None,
- )
-
- self.store = hs.get_datastore()
- self.event_builder_factory = hs.get_event_builder_factory()
- self.event_injector = EventInjector(hs)
- self.handlers = hs.get_handlers()
- self.message_handler = self.handlers.message_handler
-
- self.u_alice = UserID.from_string("@alice:test")
- self.u_bob = UserID.from_string("@bob:test")
-
- self.room1 = RoomID.from_string("!abc123:test")
- self.room2 = RoomID.from_string("!xyx987:test")
-
- @defer.inlineCallbacks
- def test_event_stream_get_other(self):
- # Both bob and alice joins the room
- yield self.event_injector.inject_room_member(
- self.room1, self.u_alice, Membership.JOIN
- )
- yield self.event_injector.inject_room_member(
- self.room1, self.u_bob, Membership.JOIN
- )
-
- # Initial stream key:
- start = yield self.store.get_room_events_max_id()
-
- yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
-
- end = yield self.store.get_room_events_max_id()
-
- results, _ = yield self.store.get_room_events_stream(
- self.u_bob.to_string(),
- start,
- end,
- )
-
- self.assertEqual(1, len(results))
-
- event = results[0]
-
- self.assertObjectHasAttributes(
- {
- "type": EventTypes.Message,
- "user_id": self.u_alice.to_string(),
- "content": {"body": "test", "msgtype": "message"},
- },
- event,
- )
-
- @defer.inlineCallbacks
- def test_event_stream_get_own(self):
- # Both bob and alice joins the room
- yield self.event_injector.inject_room_member(
- self.room1, self.u_alice, Membership.JOIN
- )
- yield self.event_injector.inject_room_member(
- self.room1, self.u_bob, Membership.JOIN
- )
-
- # Initial stream key:
- start = yield self.store.get_room_events_max_id()
-
- yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
-
- end = yield self.store.get_room_events_max_id()
-
- results, _ = yield self.store.get_room_events_stream(
- self.u_alice.to_string(),
- start,
- end,
- )
-
- self.assertEqual(1, len(results))
-
- event = results[0]
-
- self.assertObjectHasAttributes(
- {
- "type": EventTypes.Message,
- "user_id": self.u_alice.to_string(),
- "content": {"body": "test", "msgtype": "message"},
- },
- event,
- )
-
- @defer.inlineCallbacks
- def test_event_stream_join_leave(self):
- # Both bob and alice joins the room
- yield self.event_injector.inject_room_member(
- self.room1, self.u_alice, Membership.JOIN
- )
- yield self.event_injector.inject_room_member(
- self.room1, self.u_bob, Membership.JOIN
- )
-
- # Then bob leaves again.
- yield self.event_injector.inject_room_member(
- self.room1, self.u_bob, Membership.LEAVE
- )
-
- # Initial stream key:
- start = yield self.store.get_room_events_max_id()
-
- yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
-
- end = yield self.store.get_room_events_max_id()
-
- results, _ = yield self.store.get_room_events_stream(
- self.u_bob.to_string(),
- start,
- end,
- )
-
- # We should not get the message, as it happened *after* bob left.
- self.assertEqual(0, len(results))
-
- @defer.inlineCallbacks
- def test_event_stream_prev_content(self):
- yield self.event_injector.inject_room_member(
- self.room1, self.u_bob, Membership.JOIN
- )
-
- yield self.event_injector.inject_room_member(
- self.room1, self.u_alice, Membership.JOIN
- )
-
- start = yield self.store.get_room_events_max_id()
-
- yield self.event_injector.inject_room_member(
- self.room1, self.u_alice, Membership.JOIN,
- )
-
- end = yield self.store.get_room_events_max_id()
-
- results, _ = yield self.store.get_room_events_stream(
- self.u_bob.to_string(),
- start,
- end,
- )
-
- # We should not get the message, as it happened *after* bob left.
- self.assertEqual(1, len(results))
-
- event = results[0]
-
- self.assertTrue(
- "prev_content" in event.unsigned,
- msg="No prev_content key"
- )
diff --git a/tests/test_dns.py b/tests/test_dns.py
index 637b1606f8..c394c57ee7 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -21,6 +21,8 @@ from mock import Mock
from synapse.http.endpoint import resolve_service
+from tests.utils import MockClock
+
class DnsTestCase(unittest.TestCase):
@@ -63,14 +65,17 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(servers[0].host, ip_address)
@defer.inlineCallbacks
- def test_from_cache(self):
+ def test_from_cache_expired_and_dns_fail(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.examle.com"
+ entry = Mock(spec_set=["expires"])
+ entry.expires = 0
+
cache = {
- service_name: [object()]
+ service_name: [entry]
}
servers = yield resolve_service(
@@ -83,6 +88,31 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks
+ def test_from_cache(self):
+ clock = MockClock()
+
+ dns_client_mock = Mock(spec_set=['lookupService'])
+ dns_client_mock.lookupService = Mock(spec_set=[])
+
+ service_name = "test_service.examle.com"
+
+ entry = Mock(spec_set=["expires"])
+ entry.expires = 999999999
+
+ cache = {
+ service_name: [entry]
+ }
+
+ servers = yield resolve_service(
+ service_name, dns_client=dns_client_mock, cache=cache, clock=clock,
+ )
+
+ self.assertFalse(dns_client_mock.lookupService.called)
+
+ self.assertEquals(len(servers), 1)
+ self.assertEquals(servers, cache[service_name])
+
+ @defer.inlineCallbacks
def test_empty_cache(self):
dns_client_mock = Mock()
diff --git a/tests/test_preview.py b/tests/test_preview.py
new file mode 100644
index 0000000000..ffa52e5dd4
--- /dev/null
+++ b/tests/test_preview.py
@@ -0,0 +1,217 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import unittest
+
+from synapse.rest.media.v1.preview_url_resource import (
+ summarize_paragraphs, decode_and_calc_og
+)
+
+
+class PreviewTestCase(unittest.TestCase):
+
+ def test_long_summarize(self):
+ example_paras = [
+ u"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
+ Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
+ Troms county, Norway. The administrative centre of the municipality is
+ the city of Tromsø. Outside of Norway, Tromso and Tromsö are
+ alternative spellings of the city.Tromsø is considered the northernmost
+ city in the world with a population above 50,000. The most populous town
+ north of it is Alta, Norway, with a population of 14,272 (2013).""",
+
+ u"""Tromsø lies in Northern Norway. The municipality has a population of
+ (2015) 72,066, but with an annual influx of students it has over 75,000
+ most of the year. It is the largest urban area in Northern Norway and the
+ third largest north of the Arctic Circle (following Murmansk and Norilsk).
+ Most of Tromsø, including the city centre, is located on the island of
+ Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012,
+ Tromsøya had a population of 36,088. Substantial parts of the urban area
+ are also situated on the mainland to the east, and on parts of Kvaløya—a
+ large island to the west. Tromsøya is connected to the mainland by the Tromsø
+ Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the
+ Sandnessund Bridge. Tromsø Airport connects the city to many destinations
+ in Europe. The city is warmer than most other places located on the same
+ latitude, due to the warming effect of the Gulf Stream.""",
+
+ u"""The city centre of Tromsø contains the highest number of old wooden
+ houses in Northern Norway, the oldest house dating from 1789. The Arctic
+ Cathedral, a modern church from 1965, is probably the most famous landmark
+ in Tromsø. The city is a cultural centre for its region, with several
+ festivals taking place in the summer. Some of Norway's best-known
+ musicians, Torbjørn Brundtland and Svein Berge of the electronica duo
+ Röyksopp and Lene Marlin grew up and started their careers in Tromsø.
+ Noted electronic musician Geir Jenssen also hails from Tromsø.""",
+ ]
+
+ desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
+
+ self.assertEquals(
+ desc,
+ u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ u" Troms county, Norway. The administrative centre of the municipality is"
+ u" the city of Tromsø. Outside of Norway, Tromso and Tromsö are"
+ u" alternative spellings of the city.Tromsø is considered the northernmost"
+ u" city in the world with a population above 50,000. The most populous town"
+ u" north of it is Alta, Norway, with a population of 14,272 (2013)."
+ )
+
+ desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
+
+ self.assertEquals(
+ desc,
+ u"Tromsø lies in Northern Norway. The municipality has a population of"
+ u" (2015) 72,066, but with an annual influx of students it has over 75,000"
+ u" most of the year. It is the largest urban area in Northern Norway and the"
+ u" third largest north of the Arctic Circle (following Murmansk and Norilsk)."
+ u" Most of Tromsø, including the city centre, is located on the island of"
+ u" Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012,"
+ u" Tromsøya had a population of 36,088. Substantial parts of the urban…"
+ )
+
+ def test_short_summarize(self):
+ example_paras = [
+ u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ u" Troms county, Norway.",
+
+ u"Tromsø lies in Northern Norway. The municipality has a population of"
+ u" (2015) 72,066, but with an annual influx of students it has over 75,000"
+ u" most of the year.",
+
+ u"The city centre of Tromsø contains the highest number of old wooden"
+ u" houses in Northern Norway, the oldest house dating from 1789. The Arctic"
+ u" Cathedral, a modern church from 1965, is probably the most famous landmark"
+ u" in Tromsø.",
+ ]
+
+ desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
+
+ self.assertEquals(
+ desc,
+ u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ u" Troms county, Norway.\n"
+ u"\n"
+ u"Tromsø lies in Northern Norway. The municipality has a population of"
+ u" (2015) 72,066, but with an annual influx of students it has over 75,000"
+ u" most of the year."
+ )
+
+ def test_small_then_large_summarize(self):
+ example_paras = [
+ u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ u" Troms county, Norway.",
+
+ u"Tromsø lies in Northern Norway. The municipality has a population of"
+ u" (2015) 72,066, but with an annual influx of students it has over 75,000"
+ u" most of the year."
+ u" The city centre of Tromsø contains the highest number of old wooden"
+ u" houses in Northern Norway, the oldest house dating from 1789. The Arctic"
+ u" Cathedral, a modern church from 1965, is probably the most famous landmark"
+ u" in Tromsø.",
+ ]
+
+ desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
+ self.assertEquals(
+ desc,
+ u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
+ u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
+ u" Troms county, Norway.\n"
+ u"\n"
+ u"Tromsø lies in Northern Norway. The municipality has a population of"
+ u" (2015) 72,066, but with an annual influx of students it has over 75,000"
+ u" most of the year. The city centre of Tromsø contains the highest number"
+ u" of old wooden houses in Northern Norway, the oldest house dating from"
+ u" 1789. The Arctic Cathedral, a modern church from…"
+ )
+
+
+class PreviewUrlTestCase(unittest.TestCase):
+ def test_simple(self):
+ html = u"""
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+
+ self.assertEquals(og, {
+ u"og:title": u"Foo",
+ u"og:description": u"Some text."
+ })
+
+ def test_comment(self):
+ html = u"""
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ <!-- HTML comment -->
+ Some text.
+ </body>
+ </html>
+ """
+
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+
+ self.assertEquals(og, {
+ u"og:title": u"Foo",
+ u"og:description": u"Some text."
+ })
+
+ def test_comment2(self):
+ html = u"""
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ Some text.
+ <!-- HTML comment -->
+ Some more text.
+ <p>Text</p>
+ More text
+ </body>
+ </html>
+ """
+
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+
+ self.assertEquals(og, {
+ u"og:title": u"Foo",
+ u"og:description": u"Some text.\n\nSome more text.\n\nText\n\nMore text"
+ })
+
+ def test_script(self):
+ html = u"""
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ <script> (function() {})() </script>
+ Some text.
+ </body>
+ </html>
+ """
+
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+
+ self.assertEquals(og, {
+ u"og:title": u"Foo",
+ u"og:description": u"Some text."
+ })
diff --git a/tests/test_state.py b/tests/test_state.py
index a1ea7ef672..6454f994e3 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -67,9 +67,11 @@ class StateGroupStore(object):
self._event_to_state_group = {}
self._group_to_state = {}
+ self._event_id_to_event = {}
+
self._next_group = 1
- def get_state_groups(self, room_id, event_ids):
+ def get_state_groups_ids(self, room_id, event_ids):
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
@@ -79,22 +81,23 @@ class StateGroupStore(object):
return defer.succeed(groups)
def store_state_groups(self, event, context):
- if context.current_state is None:
+ if context.current_state_ids is None:
return
- state_events = context.current_state
-
- if event.is_state():
- state_events[(event.type, event.state_key)] = event
+ state_events = dict(context.current_state_ids)
- state_group = context.state_group
- if not state_group:
- state_group = self._next_group
- self._next_group += 1
+ self._group_to_state[context.state_group] = state_events
+ self._event_to_state_group[event.event_id] = context.state_group
- self._group_to_state[state_group] = state_events.values()
+ def get_events(self, event_ids, **kwargs):
+ return {
+ e_id: self._event_id_to_event[e_id] for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
- self._event_to_state_group[event.event_id] = state_group
+ def register_events(self, events):
+ for e in events:
+ self._event_id_to_event[e.event_id] = e
class DictObj(dict):
@@ -136,17 +139,21 @@ class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = Mock(
spec_set=[
- "get_state_groups",
+ "get_state_groups_ids",
"add_event_hashes",
+ "get_events",
+ "get_next_state_group",
]
)
- hs = Mock(spec=[
+ hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock",
])
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
- hs.get_auth.return_value = Auth(hs)
hs.get_clock.return_value = MockClock()
+ hs.get_auth.return_value = Auth(hs)
+
+ self.store.get_next_state_group.side_effect = Mock
self.state = StateHandler(hs)
self.event_id = 0
@@ -187,7 +194,7 @@ class StateTestCase(unittest.TestCase):
)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
context_store = {}
@@ -196,7 +203,7 @@ class StateTestCase(unittest.TestCase):
store.store_state_groups(event, context)
context_store[event.event_id] = context
- self.assertEqual(2, len(context_store["D"].current_state))
+ self.assertEqual(2, len(context_store["D"].prev_state_ids))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
@@ -239,7 +246,9 @@ class StateTestCase(unittest.TestCase):
)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.get_events = store.get_events
+ store.register_events(graph.walk())
context_store = {}
@@ -250,7 +259,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"START", "A", "C"},
- {e.event_id for e in context_store["D"].current_state.values()}
+ {e_id for e_id in context_store["D"].prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -303,7 +312,9 @@ class StateTestCase(unittest.TestCase):
)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.get_events = store.get_events
+ store.register_events(graph.walk())
context_store = {}
@@ -314,7 +325,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"START", "A", "B", "C"},
- {e.event_id for e in context_store["E"].current_state.values()}
+ {e for e in context_store["E"].prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -384,7 +395,9 @@ class StateTestCase(unittest.TestCase):
graph = Graph(nodes, edges)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.get_events = store.get_events
+ store.register_events(graph.walk())
context_store = {}
@@ -395,7 +408,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
- {e.event_id for e in context_store["D"].current_state.values()}
+ {e for e in context_store["D"].prev_state_ids.values()}
)
def _add_depths(self, nodes, edges):
@@ -424,16 +437,11 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
- set(old_state), set(context.current_state.values())
+ set(e.event_id for e in old_state), set(context.current_state_ids.values())
)
- self.assertIsNone(context.state_group)
+ self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
@@ -449,18 +457,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
- set(old_state),
- set(context.current_state.values())
+ set(e.event_id for e in old_state), set(context.prev_state_ids.values())
)
- self.assertIsNone(context.state_group)
-
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
event = create_event(type="test_message", name="event")
@@ -473,20 +473,15 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1"
- self.store.get_state_groups.return_value = {
- group_name: old_state,
+ self.store.get_state_groups_ids.return_value = {
+ group_name: {(e.type, e.state_key): e.event_id for e in old_state},
}
context = yield self.state.compute_event_context(event)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
set([e.event_id for e in old_state]),
- set([e.event_id for e in context.current_state.values()])
+ set(context.current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@@ -503,23 +498,18 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1"
- self.store.get_state_groups.return_value = {
- group_name: old_state,
+ self.store.get_state_groups_ids.return_value = {
+ group_name: {(e.type, e.state_key): e.event_id for e in old_state},
}
context = yield self.state.compute_event_context(event)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
set([e.event_id for e in old_state]),
- set([e.event_id for e in context.current_state.values()])
+ set(context.prev_state_ids.values())
)
- self.assertIsNone(context.state_group)
+ self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_message_conflict(self):
@@ -543,11 +533,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""),
]
+ store = StateGroupStore()
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+ self.store.get_events = store.get_events
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(len(context.current_state), 6)
+ self.assertEqual(len(context.current_state_ids), 6)
- self.assertIsNone(context.state_group)
+ self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_state_conflict(self):
@@ -571,11 +566,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""),
]
+ store = StateGroupStore()
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+ self.store.get_events = store.get_events
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(len(context.current_state), 6)
+ self.assertEqual(len(context.current_state_ids), 6)
- self.assertIsNone(context.state_group)
+ self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_standard_depth_conflict(self):
@@ -606,9 +606,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=2),
]
+ store = StateGroupStore()
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+ self.store.get_events = store.get_events
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
+ self.assertEqual(
+ old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+ )
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
@@ -625,17 +632,22 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=1),
]
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
+ self.assertEqual(
+ old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+ )
def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1"
group_name_2 = "group_name_2"
- self.store.get_state_groups.return_value = {
- group_name_1: old_state_1,
- group_name_2: old_state_2,
+ self.store.get_state_groups_ids.return_value = {
+ group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
+ group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
}
return self.state.compute_event_context(event)
diff --git a/tests/unittest.py b/tests/unittest.py
index 5b22abfe74..38715972dd 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -17,13 +17,18 @@ from twisted.trial import unittest
import logging
-
# logging doesn't have a "don't log anything at all EVARRRR setting,
# but since the highest value is 50, 1000000 should do ;)
NEVER = 1000000
-logging.getLogger().addHandler(logging.StreamHandler())
+handler = logging.StreamHandler()
+handler.setFormatter(logging.Formatter(
+ "%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]"
+))
+logging.getLogger().addHandler(handler)
logging.getLogger().setLevel(NEVER)
+logging.getLogger("synapse.storage.SQL").setLevel(NEVER)
+logging.getLogger("synapse.storage.txn").setLevel(NEVER)
def around(target):
@@ -70,8 +75,6 @@ class TestCase(unittest.TestCase):
return ret
logging.getLogger().setLevel(level)
- # Don't set SQL logging
- logging.getLogger("synapse.storage").setLevel(old_level)
return orig()
def assertObjectHasAttributes(self, attrs, obj):
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
new file mode 100644
index 0000000000..afcba482f9
--- /dev/null
+++ b/tests/util/test_linearizer.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+
+from twisted.internet import defer
+
+from synapse.util.async import Linearizer
+
+
+class LinearizerTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def test_linearizer(self):
+ linearizer = Linearizer()
+
+ key = object()
+
+ d1 = linearizer.queue(key)
+ cm1 = yield d1
+
+ d2 = linearizer.queue(key)
+ self.assertFalse(d2.called)
+
+ with cm1:
+ self.assertFalse(d2.called)
+
+ self.assertTrue(d2.called)
+
+ with (yield d2):
+ pass
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index bab366fb7f..1eba5b535e 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,6 +19,8 @@ from .. import unittest
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from mock import Mock
+
class LruCacheTestCase(unittest.TestCase):
@@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.setdefault("key", 2), 1)
self.assertEquals(cache.get("key"), 1)
+ cache["key"] = 2 # Make sure overriding works.
+ self.assertEquals(cache.get("key"), 2)
def test_pop(self):
cache = LruCache(1)
@@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
cache["key"] = 1
cache.clear()
self.assertEquals(len(cache), 0)
+
+
+class LruCacheCallbacksTestCase(unittest.TestCase):
+ def test_get(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.get("key", "value")
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_multi_get(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.get("key", callback=m)
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_set(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value", m)
+ self.assertFalse(m.called)
+
+ cache.set("key", "value")
+ self.assertFalse(m.called)
+
+ cache.set("key", "value2")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ def test_pop(self):
+ m = Mock()
+ cache = LruCache(1)
+
+ cache.set("key", "value", m)
+ self.assertFalse(m.called)
+
+ cache.pop("key")
+ self.assertEquals(m.call_count, 1)
+
+ cache.set("key", "value")
+ self.assertEquals(m.call_count, 1)
+
+ cache.pop("key")
+ self.assertEquals(m.call_count, 1)
+
+ def test_del_multi(self):
+ m1 = Mock()
+ m2 = Mock()
+ m3 = Mock()
+ m4 = Mock()
+ cache = LruCache(4, 2, cache_type=TreeCache)
+
+ cache.set(("a", "1"), "value", m1)
+ cache.set(("a", "2"), "value", m2)
+ cache.set(("b", "1"), "value", m3)
+ cache.set(("b", "2"), "value", m4)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+ self.assertEquals(m4.call_count, 0)
+
+ cache.del_multi(("a",))
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 1)
+ self.assertEquals(m3.call_count, 0)
+ self.assertEquals(m4.call_count, 0)
+
+ def test_clear(self):
+ m1 = Mock()
+ m2 = Mock()
+ cache = LruCache(5)
+
+ cache.set("key1", "value", m1)
+ cache.set("key2", "value", m2)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+
+ cache.clear()
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 1)
+
+ def test_eviction(self):
+ m1 = Mock(name="m1")
+ m2 = Mock(name="m2")
+ m3 = Mock(name="m3")
+ cache = LruCache(2)
+
+ cache.set("key1", "value", m1)
+ cache.set("key2", "value", m2)
+
+ self.assertEquals(m1.call_count, 0)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key3", "value", m3)
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key3", "value")
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.get("key2")
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 0)
+
+ cache.set("key1", "value", m1)
+
+ self.assertEquals(m1.call_count, 1)
+ self.assertEquals(m2.call_count, 0)
+ self.assertEquals(m3.call_count, 1)
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
new file mode 100644
index 0000000000..1d745ae1a7
--- /dev/null
+++ b/tests/util/test_rwlock.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+
+from synapse.util.async import ReadWriteLock
+
+
+class ReadWriteLockTestCase(unittest.TestCase):
+
+ def _assert_called_before_not_after(self, lst, first_false):
+ for i, d in enumerate(lst[:first_false]):
+ self.assertTrue(d.called, msg="%d was unexpectedly false" % i)
+
+ for i, d in enumerate(lst[first_false:]):
+ self.assertFalse(
+ d.called, msg="%d was unexpectedly true" % (i + first_false)
+ )
+
+ def test_rwlock(self):
+ rwlock = ReadWriteLock()
+
+ key = object()
+
+ ds = [
+ rwlock.read(key), # 0
+ rwlock.read(key), # 1
+ rwlock.write(key), # 2
+ rwlock.write(key), # 3
+ rwlock.read(key), # 4
+ rwlock.read(key), # 5
+ rwlock.write(key), # 6
+ ]
+
+ self._assert_called_before_not_after(ds, 2)
+
+ with ds[0].result:
+ self._assert_called_before_not_after(ds, 2)
+ self._assert_called_before_not_after(ds, 2)
+
+ with ds[1].result:
+ self._assert_called_before_not_after(ds, 2)
+ self._assert_called_before_not_after(ds, 3)
+
+ with ds[2].result:
+ self._assert_called_before_not_after(ds, 3)
+ self._assert_called_before_not_after(ds, 4)
+
+ with ds[3].result:
+ self._assert_called_before_not_after(ds, 4)
+ self._assert_called_before_not_after(ds, 6)
+
+ with ds[5].result:
+ self._assert_called_before_not_after(ds, 6)
+ self._assert_called_before_not_after(ds, 6)
+
+ with ds[4].result:
+ self._assert_called_before_not_after(ds, 6)
+ self._assert_called_before_not_after(ds, 7)
+
+ with ds[6].result:
+ pass
+
+ d = rwlock.write(key)
+ self.assertTrue(d.called)
+ with d.result:
+ pass
+
+ d = rwlock.read(key)
+ self.assertTrue(d.called)
+ with d.result:
+ pass
diff --git a/tests/utils.py b/tests/utils.py
index 52405502e9..2d0bd205fd 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
from synapse.federation.transport import server
-from synapse.types import Requester
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
@@ -49,11 +48,17 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.event_cache_size = 1
config.enable_registration = True
config.macaroon_secret_key = "not even a little secret"
- config.server_name = "server.under.test"
+ config.expire_access_token = False
+ config.server_name = name
config.trusted_third_party_id_servers = []
config.room_invite_state_types = []
+ config.password_providers = []
+ config.worker_replication_url = ""
+ config.worker_app = None
+ config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"}
+ config.ldap_enabled = False
if "clock" not in kargs:
kargs["clock"] = MockClock()
@@ -64,8 +69,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config),
+ database_engine=create_engine(config.database_config),
get_db_conn=db_pool.get_db_conn,
+ room_list_handler=object(),
+ tls_server_context_factory=Mock(),
**kargs
)
hs.setup()
@@ -73,21 +80,18 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config),
+ database_engine=create_engine(config.database_config),
+ room_list_handler=object(),
+ tls_server_context_factory=Mock(),
**kargs
)
# bcrypt is far too slow to be doing in unit tests
- def swap_out_hash_for_testing(old_build_handlers):
- def build_handlers():
- handlers = old_build_handlers()
- auth_handler = handlers.auth_handler
- auth_handler.hash = lambda p: hashlib.md5(p).hexdigest()
- auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
- return handlers
- return build_handlers
-
- hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
+ # Need to let the HS build an auth handler and then mess with it
+ # because AuthHandler's constructor requires the HS, so we can't make one
+ # beforehand and pass it in to the HS's constructor (chicken / egg)
+ hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
+ hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
fed = kargs.get("resource_for_federation", None)
if fed:
@@ -116,6 +120,15 @@ def get_mock_call_args(pattern_func, mock_func):
return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
+def mock_getRawHeaders(headers=None):
+ headers = headers if headers is not None else {}
+
+ def getRawHeaders(name, default=None):
+ return headers.get(name, default)
+
+ return getRawHeaders
+
+
# This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer):
@@ -128,7 +141,7 @@ class MockHttpResource(HttpServer):
@patch('twisted.web.http.Request')
@defer.inlineCallbacks
- def trigger(self, http_method, path, content, mock_request):
+ def trigger(self, http_method, path, content, mock_request, federation_auth=False):
""" Fire an HTTP event.
Args:
@@ -156,9 +169,10 @@ class MockHttpResource(HttpServer):
mock_request.getClientIP.return_value = "-"
- mock_request.requestHeaders.getRawHeaders.return_value = [
- "X-Matrix origin=test,key=,sig="
- ]
+ headers = {}
+ if federation_auth:
+ headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
+ mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
mock_request.path = path
@@ -189,7 +203,7 @@ class MockHttpResource(HttpServer):
)
defer.returnValue((code, response))
except CodeMessageException as e:
- defer.returnValue((e.code, cs_error(e.msg)))
+ defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
raise KeyError("No event can handle %s" % path)
@@ -221,6 +235,7 @@ class MockClock(object):
# list of lists of [absolute_time, callback, expired] in no particular
# order
self.timers = []
+ self.loopers = []
def time(self):
return self.now
@@ -241,7 +256,7 @@ class MockClock(object):
return t
def looping_call(self, function, interval):
- pass
+ self.loopers.append([function, interval / 1000., self.now])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@@ -270,6 +285,12 @@ class MockClock(object):
else:
self.timers.append(t)
+ for looped in self.loopers:
+ func, interval, last = looped
+ if last + interval < self.now:
+ func()
+ looped[2] = self.now
+
def advance_time_msec(self, ms):
self.advance_time(ms / 1000.)
@@ -298,7 +319,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
return conn
def create_engine(self):
- return create_engine(self.config)
+ return create_engine(self.config.database_config)
class MemoryDataStore(object):
@@ -512,7 +533,3 @@ class DeferredMockCallable(object):
"call(%s)" % _format_call(c[0], c[1]) for c in calls
])
)
-
-
-def requester_for_user(user):
- return Requester(user, None, False)
|