diff --git a/tests/__init__.py b/tests/__init__.py
index aab20e8e02..24006c949e 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -14,4 +14,5 @@
# limitations under the License.
from twisted.trial import util
+
util.DEFAULT_TIMEOUT_DURATION = 10
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 4575dd9834..5f158ec4b9 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -13,16 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import pymacaroons
from mock import Mock
+
+import pymacaroons
+
from twisted.internet import defer
import synapse.handlers.auth
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.types import UserID
+
from tests import unittest
-from tests.utils import setup_test_homeserver, mock_getRawHeaders
+from tests.utils import mock_getRawHeaders, setup_test_homeserver
class TestHandlers(object):
@@ -86,16 +89,53 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self):
- app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=None,
+ )
+ self.store.get_app_service_by_token = Mock(return_value=app_service)
+ self.store.get_user_by_access_token = Mock(return_value=None)
+
+ request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
+ request.args["access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = yield self.auth.get_user_by_req(request)
+ self.assertEquals(requester.user.to_string(), self.test_user)
+
+ @defer.inlineCallbacks
+ def test_get_user_by_req_appservice_valid_token_good_ip(self):
+ from netaddr import IPSet
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=IPSet(["192.168/16"]),
+ )
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
+ request.getClientIP.return_value = "192.168.10.10"
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
+ def test_get_user_by_req_appservice_valid_token_bad_ip(self):
+ from netaddr import IPSet
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=IPSet(["192.168/16"]),
+ )
+ self.store.get_app_service_by_token = Mock(return_value=app_service)
+ self.store.get_user_by_access_token = Mock(return_value=None)
+
+ request = Mock(args={})
+ request.getClientIP.return_value = "131.111.8.42"
+ request.args["access_token"] = [self.test_token]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ d = self.auth.get_user_by_req(request)
+ self.failureResultOf(d, AuthError)
+
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = Mock(return_value=None)
@@ -119,12 +159,16 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org"
- app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=None,
+ )
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -133,12 +177,16 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org"
- app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
+ app_service = Mock(
+ token="foobar", url="a_url", sender=self.test_user,
+ ip_range_whitelist=None,
+ )
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
+ request.getClientIP.return_value = "127.0.0.1"
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index dcceca7f3e..836a23fb54 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -13,19 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests import unittest
-from twisted.internet import defer
-
from mock import Mock
-from tests.utils import (
- MockHttpResource, DeferredMockCallable, setup_test_homeserver
-)
+import jsonschema
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
-from synapse.api.errors import SynapseError
-import jsonschema
+from tests import unittest
+from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
user_localpart = "test_user"
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 5b2b95860a..891e0cc973 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -12,14 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.appservice import ApplicationService
+import re
+
+from mock import Mock
from twisted.internet import defer
-from mock import Mock
-from tests import unittest
+from synapse.appservice import ApplicationService
-import re
+from tests import unittest
def _regex(regex, exclusive=True):
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 9181692771..b9f4863e9a 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -12,17 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
+
+from twisted.internet import defer
+
from synapse.appservice import ApplicationServiceState
from synapse.appservice.scheduler import (
- _ServiceQueuer, _TransactionController, _Recoverer
+ _Recoverer,
+ _ServiceQueuer,
+ _TransactionController,
)
-from twisted.internet import defer
-
from synapse.util.logcontext import make_deferred_yieldable
-from ..utils import MockClock
-from mock import Mock
+
from tests import unittest
+from ..utils import MockClock
+
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 879159ccea..eb7f0ab12a 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -19,6 +19,7 @@ import shutil
import tempfile
from synapse.config.homeserver import HomeServerConfig
+
from tests import unittest
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 772afd2cf9..5c422eff38 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -15,8 +15,11 @@
import os.path
import shutil
import tempfile
+
import yaml
+
from synapse.config.homeserver import HomeServerConfig
+
from tests import unittest
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index 47cb328a01..cd11871b80 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -14,15 +14,13 @@
# limitations under the License.
-from tests import unittest
-
-from synapse.events.builder import EventBuilder
-from synapse.crypto.event_signing import add_hashes_and_signatures
-
+import nacl.signing
from unpaddedbase64 import decode_base64
-import nacl.signing
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events.builder import EventBuilder
+from tests import unittest
# Perform these tests using given secret key so we get entirely deterministic
# signatures output that we can test against.
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 149e443022..a9d37fe084 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -14,15 +14,19 @@
# limitations under the License.
import time
+from mock import Mock
+
import signedjson.key
import signedjson.sign
-from mock import Mock
+
+from twisted.internet import defer, reactor
+
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.util import async, logcontext
+from synapse.util import Clock, logcontext
from synapse.util.logcontext import LoggingContext
+
from tests import unittest, utils
-from twisted.internet import defer
class MockPerspectiveServer(object):
@@ -118,6 +122,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_verify_json_objects_for_server_awaits_previous_requests(self):
+ clock = Clock(reactor)
key1 = signedjson.key.generate_signing_key(1)
kr = keyring.Keyring(self.hs)
@@ -167,7 +172,7 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
- yield async.sleep(1) # XXX find out why this takes so long!
+ yield clock.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
@@ -183,7 +188,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
- yield async.sleep(1)
+ yield clock.sleep(1)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index dfc870066e..f51d99419e 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -14,11 +14,11 @@
# limitations under the License.
-from .. import unittest
-
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event, serialize_event
+from .. import unittest
+
def MockEvent(**kwargs):
if "event_id" not in kwargs:
diff --git a/tests/federation/__init__.py b/tests/federation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/federation/__init__.py
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
new file mode 100644
index 0000000000..c91e25f54f
--- /dev/null
+++ b/tests/federation/test_federation_server.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.events import FrozenEvent
+from synapse.federation.federation_server import server_matches_acl_event
+
+from tests import unittest
+
+
+@unittest.DEBUG
+class ServerACLsTestCase(unittest.TestCase):
+ def test_blacklisted_server(self):
+ e = _create_acl_event({
+ "allow": ["*"],
+ "deny": ["evil.com"],
+ })
+ logging.info("ACL event: %s", e.content)
+
+ self.assertFalse(server_matches_acl_event("evil.com", e))
+ self.assertFalse(server_matches_acl_event("EVIL.COM", e))
+
+ self.assertTrue(server_matches_acl_event("evil.com.au", e))
+ self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
+
+ def test_block_ip_literals(self):
+ e = _create_acl_event({
+ "allow_ip_literals": False,
+ "allow": ["*"],
+ })
+ logging.info("ACL event: %s", e.content)
+
+ self.assertFalse(server_matches_acl_event("1.2.3.4", e))
+ self.assertTrue(server_matches_acl_event("1a.2.3.4", e))
+ self.assertFalse(server_matches_acl_event("[1:2::]", e))
+ self.assertTrue(server_matches_acl_event("1:2:3:4", e))
+
+
+def _create_acl_event(content):
+ return FrozenEvent({
+ "room_id": "!a:b",
+ "event_id": "$a:b",
+ "type": "m.room.server_acls",
+ "sender": "@a:b",
+ "content": content
+ })
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index b753455943..57c0771cf3 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
+
from twisted.internet import defer
-from .. import unittest
-from tests.utils import MockClock
from synapse.handlers.appservice import ApplicationServicesHandler
-from mock import Mock
+from tests.utils import MockClock
+
+from .. import unittest
class AppServiceHandlerTestCase(unittest.TestCase):
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 1822dcf1e0..2e5e8e4dec 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -14,11 +14,13 @@
# limitations under the License.
import pymacaroons
+
from twisted.internet import defer
import synapse
import synapse.api.errors
from synapse.handlers.auth import AuthHandler
+
from tests import unittest
from tests.utils import setup_test_homeserver
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 778ff2f6e9..633a0b7f36 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -17,8 +17,8 @@ from twisted.internet import defer
import synapse.api.errors
import synapse.handlers.device
-
import synapse.storage
+
from tests import unittest, utils
user1 = "@boris:aaa"
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 7e5332e272..a353070316 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -14,14 +14,14 @@
# limitations under the License.
-from tests import unittest
-from twisted.internet import defer
-
from mock import Mock
+from twisted.internet import defer
+
from synapse.handlers.directory import DirectoryHandler
from synapse.types import RoomAlias
+from tests import unittest
from tests.utils import setup_test_homeserver
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index d1bd87b898..ca1542236d 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -14,13 +14,14 @@
# limitations under the License.
import mock
-from synapse.api import errors
+
from twisted.internet import defer
import synapse.api.errors
import synapse.handlers.e2e_keys
-
import synapse.storage
+from synapse.api import errors
+
from tests import unittest, utils
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index de06a6ad30..121ce78634 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -14,18 +14,22 @@
# limitations under the License.
-from tests import unittest
-
from mock import Mock, call
from synapse.api.constants import PresenceState
from synapse.handlers.presence import (
- handle_update, handle_timeout,
- IDLE_TIMER, SYNC_ONLINE_TIMEOUT, LAST_ACTIVE_GRANULARITY, FEDERATION_TIMEOUT,
FEDERATION_PING_INTERVAL,
+ FEDERATION_TIMEOUT,
+ IDLE_TIMER,
+ LAST_ACTIVE_GRANULARITY,
+ SYNC_ONLINE_TIMEOUT,
+ handle_timeout,
+ handle_update,
)
from synapse.storage.presence import UserPresenceState
+from tests import unittest
+
class PresenceUpdateTestCase(unittest.TestCase):
def test_offline_to_online(self):
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 458296ee4c..dc17918a3d 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -14,16 +14,16 @@
# limitations under the License.
-from tests import unittest
-from twisted.internet import defer
-
from mock import Mock, NonCallableMock
+from twisted.internet import defer
+
import synapse.types
from synapse.api.errors import AuthError
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
+from tests import unittest
from tests.utils import setup_test_homeserver
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e990e45220..025fa1be81 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
+
from twisted.internet import defer
-from .. import unittest
from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester
from tests.utils import setup_test_homeserver
-from mock import Mock
+from .. import unittest
class RegistrationHandlers(object):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index a433bbfa8a..b08856f763 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -14,19 +14,24 @@
# limitations under the License.
-from tests import unittest
-from twisted.internet import defer
-
-from mock import Mock, call, ANY
import json
-from ..utils import (
- MockHttpResource, MockClock, DeferredMockCallable, setup_test_homeserver
-)
+from mock import ANY, Mock, call
+
+from twisted.internet import defer
from synapse.api.errors import AuthError
from synapse.types import UserID
+from tests import unittest
+
+from ..utils import (
+ DeferredMockCallable,
+ MockClock,
+ MockHttpResource,
+ setup_test_homeserver,
+)
+
def _expect_edu(destination, edu_type, content, origin="test"):
return {
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/http/__init__.py
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
new file mode 100644
index 0000000000..60e6a75953
--- /dev/null
+++ b/tests/http/test_endpoint.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name
+
+from tests import unittest
+
+
+class ServerNameTestCase(unittest.TestCase):
+ def test_parse_server_name(self):
+ test_data = {
+ 'localhost': ('localhost', None),
+ 'my-example.com:1234': ('my-example.com', 1234),
+ '1.2.3.4': ('1.2.3.4', None),
+ '[0abc:1def::1234]': ('[0abc:1def::1234]', None),
+ '1.2.3.4:1': ('1.2.3.4', 1),
+ '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080),
+ }
+
+ for i, o in test_data.items():
+ self.assertEqual(parse_server_name(i), o)
+
+ def test_validate_bad_server_names(self):
+ test_data = [
+ "", # empty
+ "localhost:http", # non-numeric port
+ "1234]", # smells like ipv6 literal but isn't
+ "[1234",
+ "underscore_.com",
+ "percent%65.com",
+ "1234:5678:80", # too many colons
+ ]
+ for i in test_data:
+ try:
+ parse_and_validate_server_name(i)
+ self.fail(
+ "Expected parse_and_validate_server_name('%s') to throw" % (
+ i,
+ ),
+ )
+ except ValueError:
+ pass
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 64e07a8c93..8708c8a196 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -12,17 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer, reactor
-from tests import unittest
-
import tempfile
from mock import Mock, NonCallableMock
-from tests.utils import setup_test_homeserver
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+
+from twisted.internet import defer, reactor
+
from synapse.replication.tcp.client import (
- ReplicationClientHandler, ReplicationClientFactory,
+ ReplicationClientFactory,
+ ReplicationClientHandler,
)
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+
+from tests import unittest
+from tests.utils import setup_test_homeserver
class BaseSlavedStoreTestCase(unittest.TestCase):
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
index da54d478ce..adf226404e 100644
--- a/tests/replication/slave/storage/test_account_data.py
+++ b/tests/replication/slave/storage/test_account_data.py
@@ -13,11 +13,11 @@
# limitations under the License.
-from ._base import BaseSlavedStoreTestCase
+from twisted.internet import defer
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from twisted.internet import defer
+from ._base import BaseSlavedStoreTestCase
USER_ID = "@feeling:blue"
TYPE = "my.type"
@@ -37,10 +37,6 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
"get_global_account_data_by_type_for_user",
[TYPE, USER_ID], {"a": 1}
)
- yield self.check(
- "get_global_account_data_by_type_for_users",
- [TYPE, [USER_ID]], {USER_ID: {"a": 1}}
- )
yield self.master_store.add_account_data_for_user(
USER_ID, TYPE, {"a": 2}
@@ -50,7 +46,3 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
"get_global_account_data_by_type_for_user",
[TYPE, USER_ID], {"a": 2}
)
- yield self.check(
- "get_global_account_data_by_type_for_users",
- [TYPE, [USER_ID]], {USER_ID: {"a": 2}}
- )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index cb058d3142..cea01d93eb 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStoreTestCase
+from twisted.internet import defer
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
-from twisted.internet import defer
-
+from ._base import BaseSlavedStoreTestCase
USER_ID = "@feeling:blue"
USER_ID_2 = "@bright:blue"
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index 6624fe4eea..e6d670cc1f 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStoreTestCase
+from twisted.internet import defer
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from twisted.internet import defer
+from ._base import BaseSlavedStoreTestCase
USER_ID = "@feeling:blue"
ROOM_ID = "!room:blue"
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index b5bc2fa255..34e68ae82f 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,10 +1,11 @@
-from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
-from twisted.internet import defer
from mock import Mock, call
-from synapse.util import async
+from twisted.internet import defer, reactor
+
+from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
+from synapse.util import Clock
from synapse.util.logcontext import LoggingContext
+
from tests import unittest
from tests.utils import MockClock
@@ -13,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
- self.cache = HttpTransactionCache(self.clock)
+ self.hs = Mock()
+ self.hs.get_clock = Mock(return_value=self.clock)
+ self.hs.get_auth = Mock()
+ self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"
@@ -46,7 +50,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_logcontexts_with_async_result(self):
@defer.inlineCallbacks
def cb():
- yield async.sleep(0)
+ yield Clock(reactor).sleep(0)
defer.returnValue("yay")
@defer.inlineCallbacks
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index f5a7258e68..50418153fa 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -14,102 +14,30 @@
# limitations under the License.
""" Tests REST events for /events paths."""
-from tests import unittest
-# twisted imports
-from twisted.internet import defer
-
-import synapse.rest.client.v1.events
-import synapse.rest.client.v1.register
-import synapse.rest.client.v1.room
+from mock import Mock, NonCallableMock
+from six import PY3
+from twisted.internet import defer
from ....utils import MockHttpResource, setup_test_homeserver
from .utils import RestTestCase
-from mock import Mock, NonCallableMock
-
-
PATH_PREFIX = "/_matrix/client/api/v1"
-class EventStreamPaginationApiTestCase(unittest.TestCase):
- """ Tests event streaming query parameters and start/end keys used in the
- Pagination stream API. """
- user_id = "sid1"
-
- def setUp(self):
- # configure stream and inject items
- pass
-
- def tearDown(self):
- pass
-
- def TODO_test_long_poll(self):
- # stream from 'end' key, send (self+other) message, expect message.
-
- # stream from 'END', send (self+other) message, expect message.
-
- # stream from 'end' key, send (self+other) topic, expect topic.
-
- # stream from 'END', send (self+other) topic, expect topic.
-
- # stream from 'end' key, send (self+other) invite, expect invite.
-
- # stream from 'END', send (self+other) invite, expect invite.
-
- pass
-
- def TODO_test_stream_forward(self):
- # stream from START, expect injected items
-
- # stream from 'start' key, expect same content
-
- # stream from 'end' key, expect nothing
-
- # stream from 'END', expect nothing
-
- # The following is needed for cases where content is removed e.g. you
- # left a room, so the token you're streaming from is > the one that
- # would be returned naturally from START>END.
- # stream from very new token (higher than end key), expect same token
- # returned as end key
- pass
-
- def TODO_test_limits(self):
- # stream from a key, expect limit_num items
-
- # stream from START, expect limit_num items
-
- pass
-
- def TODO_test_range(self):
- # stream from key to key, expect X items
-
- # stream from key to END, expect X items
-
- # stream from START to key, expect X items
-
- # stream from START to END, expect all items
- pass
-
- def TODO_test_direction(self):
- # stream from END to START and fwds, expect newest first
-
- # stream from END to START and bwds, expect oldest first
-
- # stream from START to END and fwds, expect oldest first
-
- # stream from START to END and bwds, expect newest first
-
- pass
-
-
class EventStreamPermissionsTestCase(RestTestCase):
""" Tests event streaming (GET /events). """
+ if PY3:
+ skip = "Skip on Py3 until ported to use not V1 only register."
+
@defer.inlineCallbacks
def setUp(self):
+ import synapse.rest.client.v1.events
+ import synapse.rest.client.v1_only.register
+ import synapse.rest.client.v1.room
+
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver(
@@ -127,7 +55,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource)
+ synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource)
synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource)
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index dc94b8bd19..d71cc8e0db 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -15,12 +15,15 @@
"""Tests REST events for /profile paths."""
from mock import Mock
+
from twisted.internet import defer
import synapse.types
-from synapse.api.errors import SynapseError, AuthError
+from synapse.api.errors import AuthError, SynapseError
from synapse.rest.client.v1 import profile
+
from tests import unittest
+
from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index a6a4e2ffe0..83a23cd8fe 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -13,26 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.rest.client.v1.register import CreateUserRestServlet
-from twisted.internet import defer
+import json
+
from mock import Mock
+from six import PY3
+
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.http.server import JsonResource
+from synapse.rest.client.v1_only.register import register_servlets
+from synapse.util import Clock
+
from tests import unittest
-from tests.utils import mock_getRawHeaders
-import json
+from tests.server import make_request, setup_test_homeserver
class CreateUserServletTestCase(unittest.TestCase):
+ """
+ Tests for CreateUserRestServlet.
+ """
+ if PY3:
+ skip = "Not ported to Python 3."
def setUp(self):
- # do the dance to hook up request data to self.request_data
- self.request_data = ""
- self.request = Mock(
- content=Mock(read=Mock(side_effect=lambda: self.request_data)),
- path='/_matrix/client/api/v1/createUser'
- )
- self.request.args = {}
- self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-
self.registration_handler = Mock()
self.appservice = Mock(sender="@as:test")
@@ -40,39 +43,49 @@ class CreateUserServletTestCase(unittest.TestCase):
get_app_service_by_token=Mock(return_value=self.appservice)
)
- # do the dance to hook things up to the hs global
- handlers = Mock(
- registration_handler=self.registration_handler,
+ handlers = Mock(registration_handler=self.registration_handler)
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+
+ self.hs = self.hs = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.clock
)
- self.hs = Mock()
- self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=handlers)
- self.servlet = CreateUserRestServlet(self.hs)
- @defer.inlineCallbacks
def test_POST_createuser_with_valid_user(self):
+
+ res = JsonResource(self.hs)
+ register_servlets(self.hs, res)
+
+ request_data = json.dumps(
+ {
+ "localpart": "someone",
+ "displayname": "someone interesting",
+ "duration_seconds": 200,
+ }
+ )
+
+ url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service'
+
user_id = "@someone:interesting"
token = "my token"
- self.request.args = {
- "access_token": "i_am_an_app_service"
- }
- self.request_data = json.dumps({
- "localpart": "someone",
- "displayname": "someone interesting",
- "duration_seconds": 200
- })
self.registration_handler.get_or_create_user = Mock(
return_value=(user_id, token)
)
- (code, result) = yield self.servlet.on_POST(self.request)
- self.assertEquals(code, 200)
+ request, channel = make_request(b"POST", url, request_data)
+ request.render(res)
+
+ # Advance the clock because it waits
+ self.clock.advance(1)
+
+ self.assertEquals(channel.result["code"], b"200")
det_data = {
"user_id": user_id,
"access_token": token,
- "home_server": self.hs.hostname
+ "home_server": self.hs.hostname,
}
- self.assertDictContainsSubset(det_data, result)
+ self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 61d737725b..00fc796787 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -15,960 +15,782 @@
"""Tests REST events for /rooms paths."""
-# twisted imports
+import json
+
+from mock import Mock, NonCallableMock
+from six.moves.urllib import parse as urlparse
+
from twisted.internet import defer
import synapse.rest.client.v1.room
from synapse.api.constants import Membership
-
+from synapse.http.server import JsonResource
from synapse.types import UserID
+from synapse.util import Clock
-import json
-from six.moves.urllib import parse as urlparse
-
-from ....utils import MockHttpResource, setup_test_homeserver
-from .utils import RestTestCase
+from tests import unittest
+from tests.server import (
+ ThreadedMemoryReactorClock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
-from mock import Mock, NonCallableMock
+from .utils import RestHelper
-PATH_PREFIX = "/_matrix/client/api/v1"
+PATH_PREFIX = b"/_matrix/client/api/v1"
-class RoomPermissionsTestCase(RestTestCase):
- """ Tests room permissions. """
- user_id = "@sid1:red"
- rmcreator_id = "@notme:red"
+class RoomBase(unittest.TestCase):
+ rmcreator_id = None
- @defer.inlineCallbacks
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- hs = yield setup_test_homeserver(
+ self.clock = ThreadedMemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+
+ self.hs = setup_test_homeserver(
"red",
http_client=None,
+ clock=self.hs_clock,
+ reactor=self.clock,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
- self.ratelimiter = hs.get_ratelimiter()
+ self.ratelimiter = self.hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
- hs.get_handlers().federation_handler = Mock()
+ self.hs.get_federation_handler = Mock(return_value=Mock())
def get_user_by_access_token(token=None, allow_guest=False):
return {
- "user": UserID.from_string(self.auth_user_id),
+ "user": UserID.from_string(self.helper.auth_user_id),
"token_id": 1,
"is_guest": False,
}
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
+
+ def get_user_by_req(request, allow_guest=False, rights="access"):
+ return synapse.types.create_requester(
+ UserID.from_string(self.helper.auth_user_id), 1, False, None
+ )
+
+ self.hs.get_auth().get_user_by_req = get_user_by_req
+ self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
+ self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234")
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
- self.auth_user_id = self.rmcreator_id
+ self.hs.get_datastore().insert_client_ip = _insert_client_ip
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ self.resource = JsonResource(self.hs)
+ synapse.rest.client.v1.room.register_servlets(self.hs, self.resource)
+ synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource)
+ self.helper = RestHelper(self.hs, self.resource, self.user_id)
- self.auth = hs.get_auth()
- # create some rooms under the name rmcreator_id
- self.uncreated_rmid = "!aa:test"
+class RoomPermissionsTestCase(RoomBase):
+ """ Tests room permissions. """
+
+ user_id = b"@sid1:red"
+ rmcreator_id = b"@notme:red"
- self.created_rmid = yield self.create_room_as(self.rmcreator_id,
- is_public=False)
+ def setUp(self):
+
+ super(RoomPermissionsTestCase, self).setUp()
- self.created_public_rmid = yield self.create_room_as(self.rmcreator_id,
- is_public=True)
+ self.helper.auth_user_id = self.rmcreator_id
+ # create some rooms under the name rmcreator_id
+ self.uncreated_rmid = "!aa:test"
+ self.created_rmid = self.helper.create_room_as(
+ self.rmcreator_id, is_public=False
+ )
+ self.created_public_rmid = self.helper.create_room_as(
+ self.rmcreator_id, is_public=True
+ )
# send a message in one of the rooms
self.created_rmid_msg_path = (
- "/rooms/%s/send/m.room.message/a1" % (self.created_rmid)
- )
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
+ "rooms/%s/send/m.room.message/a1" % (self.created_rmid)
+ ).encode('ascii')
+ request, channel = make_request(
+ b"PUT",
self.created_rmid_msg_path,
- '{"msgtype":"m.text","body":"test msg"}'
+ b'{"msgtype":"m.text","body":"test msg"}',
)
- self.assertEquals(200, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
# set topic for public room
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
- '{"topic":"Public Room Topic"}'
+ request, channel = make_request(
+ b"PUT",
+ ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'),
+ b'{"topic":"Public Room Topic"}',
)
- self.assertEquals(200, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
# auth as user_id now
- self.auth_user_id = self.user_id
-
- def tearDown(self):
- pass
+ self.helper.auth_user_id = self.user_id
- @defer.inlineCallbacks
def test_send_message(self):
- msg_content = '{"msgtype":"m.text","body":"hello"}'
- send_msg_path = (
- "/rooms/%s/send/m.room.message/mid1" % (self.created_rmid,)
- )
+ msg_content = b'{"msgtype":"m.text","body":"hello"}'
+
+ seq = iter(range(100))
+
+ def send_msg_path():
+ return b"/rooms/%s/send/m.room.message/mid%s" % (
+ self.created_rmid,
+ str(next(seq)).encode('ascii'),
+ )
# send message in uncreated room, expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
- msg_content
+ request, channel = make_request(
+ b"PUT",
+ b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
+ msg_content,
)
- self.assertEquals(403, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- send_msg_path,
- msg_content
- )
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"PUT", send_msg_path(), msg_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# send message in created room and invited, expect 403
- yield self.invite(
- room=self.created_rmid,
- src=self.rmcreator_id,
- targ=self.user_id
- )
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- send_msg_path,
- msg_content
+ self.helper.invite(
+ room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"PUT", send_msg_path(), msg_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# send message in created room and joined, expect 200
- yield self.join(room=self.created_rmid, user=self.user_id)
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- send_msg_path,
- msg_content
- )
- self.assertEquals(200, code, msg=str(response))
+ self.helper.join(room=self.created_rmid, user=self.user_id)
+ request, channel = make_request(b"PUT", send_msg_path(), msg_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# send message in created room and left, expect 403
- yield self.leave(room=self.created_rmid, user=self.user_id)
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- send_msg_path,
- msg_content
- )
- self.assertEquals(403, code, msg=str(response))
+ self.helper.leave(room=self.created_rmid, user=self.user_id)
+ request, channel = make_request(b"PUT", send_msg_path(), msg_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_topic_perms(self):
- topic_content = '{"topic":"My Topic Name"}'
- topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
+ topic_content = b'{"topic":"My Topic Name"}'
+ topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid
# set/get topic in uncreated room, expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid,
- topic_content
+ request, channel = make_request(
+ b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
- self.assertEquals(403, code, msg=str(response))
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = make_request(
+ b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
- self.assertEquals(403, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT", topic_path, topic_content
- )
- self.assertEquals(403, code, msg=str(response))
- (code, response) = yield self.mock_resource.trigger_get(topic_path)
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"PUT", topic_path, topic_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = make_request(b"GET", topic_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
- yield self.invite(
+ self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- (code, response) = yield self.mock_resource.trigger(
- "PUT", topic_path, topic_content
- )
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"PUT", topic_path, topic_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
- (code, response) = yield self.mock_resource.trigger_get(topic_path)
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"GET", topic_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
- yield self.join(room=self.created_rmid, user=self.user_id)
+ self.helper.join(room=self.created_rmid, user=self.user_id)
# Only room ops can set topic by default
- self.auth_user_id = self.rmcreator_id
- (code, response) = yield self.mock_resource.trigger(
- "PUT", topic_path, topic_content
- )
- self.assertEquals(200, code, msg=str(response))
- self.auth_user_id = self.user_id
+ self.helper.auth_user_id = self.rmcreator_id
+ request, channel = make_request(b"PUT", topic_path, topic_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.helper.auth_user_id = self.user_id
- (code, response) = yield self.mock_resource.trigger_get(topic_path)
- self.assertEquals(200, code, msg=str(response))
- self.assert_dict(json.loads(topic_content), response)
+ request, channel = make_request(b"GET", topic_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assert_dict(json.loads(topic_content), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
- yield self.leave(room=self.created_rmid, user=self.user_id)
- (code, response) = yield self.mock_resource.trigger(
- "PUT", topic_path, topic_content
- )
- self.assertEquals(403, code, msg=str(response))
- (code, response) = yield self.mock_resource.trigger_get(topic_path)
- self.assertEquals(200, code, msg=str(response))
+ self.helper.leave(room=self.created_rmid, user=self.user_id)
+ request, channel = make_request(b"PUT", topic_path, topic_content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = make_request(b"GET", topic_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/state/m.room.topic" % self.created_public_rmid
+ request, channel = make_request(
+ b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
- self.assertEquals(403, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
- topic_content
+ request, channel = make_request(
+ b"PUT",
+ b"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
+ topic_content,
)
- self.assertEquals(403, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def _test_get_membership(self, room=None, members=[], expect_code=None):
for member in members:
- path = "/rooms/%s/state/m.room.member/%s" % (room, member)
- (code, response) = yield self.mock_resource.trigger_get(path)
- self.assertEquals(expect_code, code)
+ path = b"/rooms/%s/state/m.room.member/%s" % (room, member)
+ request, channel = make_request(b"GET", path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(expect_code, int(channel.result["code"]))
- @defer.inlineCallbacks
def test_membership_basic_room_perms(self):
# === room does not exist ===
room = self.uncreated_rmid
# get membership of self, get membership of other, uncreated room
# expect all 403s
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=403)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=403
+ )
# trying to invite people to this room should 403
- yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id,
- expect_code=403)
+ self.helper.invite(
+ room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403
+ )
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]:
- yield self.join(room=room, user=usr, expect_code=404)
- yield self.leave(room=room, user=usr, expect_code=404)
+ self.helper.join(room=room, user=usr, expect_code=404)
+ self.helper.leave(room=room, user=usr, expect_code=404)
- @defer.inlineCallbacks
def test_membership_private_room_perms(self):
room = self.created_rmid
# get membership of self, get membership of other, private room + invite
# expect all 403s
- yield self.invite(room=room, src=self.rmcreator_id,
- targ=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=403)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=403
+ )
# get membership of self, get membership of other, private room + joined
# expect all 200s
- yield self.join(room=room, user=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=200)
+ self.helper.join(room=room, user=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
+ )
# get membership of self, get membership of other, private room + left
# expect all 200s
- yield self.leave(room=room, user=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=200)
+ self.helper.leave(room=room, user=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
+ )
- @defer.inlineCallbacks
def test_membership_public_room_perms(self):
room = self.created_public_rmid
# get membership of self, get membership of other, public room + invite
# expect 403
- yield self.invite(room=room, src=self.rmcreator_id,
- targ=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=403)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=403
+ )
# get membership of self, get membership of other, public room + joined
# expect all 200s
- yield self.join(room=room, user=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=200)
+ self.helper.join(room=room, user=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
+ )
# get membership of self, get membership of other, public room + left
# expect 200.
- yield self.leave(room=room, user=self.user_id)
- yield self._test_get_membership(
- members=[self.user_id, self.rmcreator_id],
- room=room, expect_code=200)
+ self.helper.leave(room=room, user=self.user_id)
+ self._test_get_membership(
+ members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
+ )
- @defer.inlineCallbacks
def test_invited_permissions(self):
room = self.created_rmid
- yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
# set [invite/join/left] of other user, expect 403s
- yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id,
- expect_code=403)
- yield self.change_membership(room=room, src=self.user_id,
- targ=self.rmcreator_id,
- membership=Membership.JOIN,
- expect_code=403)
- yield self.change_membership(room=room, src=self.user_id,
- targ=self.rmcreator_id,
- membership=Membership.LEAVE,
- expect_code=403)
-
- @defer.inlineCallbacks
+ self.helper.invite(
+ room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403
+ )
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=self.rmcreator_id,
+ membership=Membership.JOIN,
+ expect_code=403,
+ )
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=self.rmcreator_id,
+ membership=Membership.LEAVE,
+ expect_code=403,
+ )
+
def test_joined_permissions(self):
room = self.created_rmid
- yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
- yield self.join(room=room, user=self.user_id)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self.helper.join(room=room, user=self.user_id)
# set invited of self, expect 403
- yield self.invite(room=room, src=self.user_id, targ=self.user_id,
- expect_code=403)
+ self.helper.invite(
+ room=room, src=self.user_id, targ=self.user_id, expect_code=403
+ )
# set joined of self, expect 200 (NOOP)
- yield self.join(room=room, user=self.user_id)
+ self.helper.join(room=room, user=self.user_id)
other = "@burgundy:red"
# set invited of other, expect 200
- yield self.invite(room=room, src=self.user_id, targ=other,
- expect_code=200)
+ self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200)
# set joined of other, expect 403
- yield self.change_membership(room=room, src=self.user_id,
- targ=other,
- membership=Membership.JOIN,
- expect_code=403)
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=other,
+ membership=Membership.JOIN,
+ expect_code=403,
+ )
# set left of other, expect 403
- yield self.change_membership(room=room, src=self.user_id,
- targ=other,
- membership=Membership.LEAVE,
- expect_code=403)
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=other,
+ membership=Membership.LEAVE,
+ expect_code=403,
+ )
# set left of self, expect 200
- yield self.leave(room=room, user=self.user_id)
+ self.helper.leave(room=room, user=self.user_id)
- @defer.inlineCallbacks
def test_leave_permissions(self):
room = self.created_rmid
- yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
- yield self.join(room=room, user=self.user_id)
- yield self.leave(room=room, user=self.user_id)
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self.helper.join(room=room, user=self.user_id)
+ self.helper.leave(room=room, user=self.user_id)
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 403s
for usr in [self.user_id, self.rmcreator_id]:
- yield self.change_membership(
+ self.helper.change_membership(
room=room,
src=self.user_id,
targ=usr,
membership=Membership.INVITE,
- expect_code=403
+ expect_code=403,
)
- yield self.change_membership(
+ self.helper.change_membership(
room=room,
src=self.user_id,
targ=usr,
membership=Membership.JOIN,
- expect_code=403
+ expect_code=403,
)
# It is always valid to LEAVE if you've already left (currently.)
- yield self.change_membership(
+ self.helper.change_membership(
room=room,
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403
+ expect_code=403,
)
-class RoomsMemberListTestCase(RestTestCase):
+class RoomsMemberListTestCase(RoomBase):
""" Tests /rooms/$room_id/members/list REST events."""
- user_id = "@sid1:red"
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
-
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
-
- self.auth_user_id = self.user_id
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
-
- def tearDown(self):
- pass
+ user_id = b"@sid1:red"
- @defer.inlineCallbacks
def test_get_member_list(self):
- room_id = yield self.create_room_as(self.user_id)
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/members" % room_id
- )
- self.assertEquals(200, code, msg=str(response))
+ room_id = self.helper.create_room_as(self.user_id)
+ request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_get_member_list_no_room(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/roomdoesnotexist/members"
- )
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members")
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_get_member_list_no_permission(self):
- room_id = yield self.create_room_as("@some_other_guy:red")
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/members" % room_id
- )
- self.assertEquals(403, code, msg=str(response))
+ room_id = self.helper.create_room_as(b"@some_other_guy:red")
+ request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_get_member_list_mixed_memberships(self):
- room_creator = "@some_other_guy:red"
- room_id = yield self.create_room_as(room_creator)
- room_path = "/rooms/%s/members" % room_id
- yield self.invite(room=room_id, src=room_creator,
- targ=self.user_id)
+ room_creator = b"@some_other_guy:red"
+ room_id = self.helper.create_room_as(room_creator)
+ room_path = b"/rooms/%s/members" % room_id
+ self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
- (code, response) = yield self.mock_resource.trigger_get(room_path)
- self.assertEquals(403, code, msg=str(response))
+ request, channel = make_request(b"GET", room_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- yield self.join(room=room_id, user=self.user_id)
+ self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
- (code, response) = yield self.mock_resource.trigger_get(room_path)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"GET", room_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- yield self.leave(room=room_id, user=self.user_id)
+ self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
- (code, response) = yield self.mock_resource.trigger_get(room_path)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"GET", room_path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
-class RoomsCreateTestCase(RestTestCase):
+class RoomsCreateTestCase(RoomBase):
""" Tests /rooms and /rooms/$room_id REST events. """
- user_id = "@sid1:red"
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
-
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
- hs.get_handlers().federation_handler = Mock()
+ user_id = b"@sid1:red"
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
-
- @defer.inlineCallbacks
def test_post_room_no_keys(self):
# POST with no config keys, expect new room id
- (code, response) = yield self.mock_resource.trigger("POST",
- "/createRoom",
- "{}")
- self.assertEquals(200, code, response)
- self.assertTrue("room_id" in response)
+ request, channel = make_request(b"POST", b"/createRoom", b"{}")
+
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), channel.result)
+ self.assertTrue("room_id" in channel.json_body)
- @defer.inlineCallbacks
def test_post_room_visibility_key(self):
# POST with visibility config key, expect new room id
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '{"visibility":"private"}')
- self.assertEquals(200, code)
- self.assertTrue("room_id" in response)
-
- @defer.inlineCallbacks
+ request, channel = make_request(
+ b"POST", b"/createRoom", b'{"visibility":"private"}'
+ )
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("room_id" in channel.json_body)
+
def test_post_room_custom_key(self):
# POST with custom config keys, expect new room id
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '{"custom":"stuff"}')
- self.assertEquals(200, code)
- self.assertTrue("room_id" in response)
-
- @defer.inlineCallbacks
+ request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("room_id" in channel.json_body)
+
def test_post_room_known_and_unknown_keys(self):
# POST with custom + known config keys, expect new room id
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '{"visibility":"private","custom":"things"}')
- self.assertEquals(200, code)
- self.assertTrue("room_id" in response)
-
- @defer.inlineCallbacks
+ request, channel = make_request(
+ b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}'
+ )
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("room_id" in channel.json_body)
+
def test_post_room_invalid_content(self):
# POST with invalid content / paths, expect 400
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '{"visibili')
- self.assertEquals(400, code)
+ request, channel = make_request(b"POST", b"/createRoom", b'{"visibili')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]))
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/createRoom",
- '["hello"]')
- self.assertEquals(400, code)
+ request, channel = make_request(b"POST", b"/createRoom", b'["hello"]')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]))
-class RoomTopicTestCase(RestTestCase):
+class RoomTopicTestCase(RoomBase):
""" Tests /rooms/$room_id/topic REST events. """
- user_id = "@sid1:red"
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
-
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
- hs.get_handlers().federation_handler = Mock()
+ user_id = b"@sid1:red"
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
-
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
+ def setUp(self):
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ super(RoomTopicTestCase, self).setUp()
# create the room
- self.room_id = yield self.create_room_as(self.user_id)
- self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
-
- def tearDown(self):
- pass
+ self.room_id = self.helper.create_room_as(self.user_id)
+ self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,)
- @defer.inlineCallbacks
def test_invalid_puts(self):
# missing keys or invalid json
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, '{}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, '{}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, '{"_name":"bob"}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, '{"nao'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, '{"nao')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]'
+ request, channel = make_request(
+ b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]'
)
- self.assertEquals(400, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, 'text only'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, 'text only')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, ''
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, '')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
# valid key, wrong type
content = '{"topic":["Topic name"]}'
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, content
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_rooms_topic(self):
# nothing should be there
- (code, response) = yield self.mock_resource.trigger_get(self.path)
- self.assertEquals(404, code, msg=str(response))
+ request, channel = make_request(b"GET", self.path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, content
- )
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# valid get
- (code, response) = yield self.mock_resource.trigger_get(self.path)
- self.assertEquals(200, code, msg=str(response))
- self.assert_dict(json.loads(content), response)
+ request, channel = make_request(b"GET", self.path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assert_dict(json.loads(content), channel.json_body)
- @defer.inlineCallbacks
def test_rooms_topic_with_extra_keys(self):
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
- (code, response) = yield self.mock_resource.trigger(
- "PUT", self.path, content
- )
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", self.path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# valid get
- (code, response) = yield self.mock_resource.trigger_get(self.path)
- self.assertEquals(200, code, msg=str(response))
- self.assert_dict(json.loads(content), response)
+ request, channel = make_request(b"GET", self.path)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assert_dict(json.loads(content), channel.json_body)
-class RoomMemberStateTestCase(RestTestCase):
+class RoomMemberStateTestCase(RoomBase):
""" Tests /rooms/$room_id/members/$user_id/state REST events. """
- user_id = "@sid1:red"
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
+ user_id = b"@sid1:red"
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ def setUp(self):
- self.room_id = yield self.create_room_as(self.user_id)
+ super(RoomMemberStateTestCase, self).setUp()
+ self.room_id = self.helper.create_room_as(self.user_id)
def tearDown(self):
pass
- @defer.inlineCallbacks
def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
- (code, response) = yield self.mock_resource.trigger("PUT", path, '{}')
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{"_name":"bob"}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{"_name":"bob"}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{"nao'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{"nao')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]'
+ request, channel = make_request(
+ b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]'
)
- self.assertEquals(400, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, 'text only'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, 'text only')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, ''
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
# valid keys, wrong types
- content = ('{"membership":["%s","%s","%s"]}' % (
- Membership.INVITE, Membership.JOIN, Membership.LEAVE
- ))
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(400, code, msg=str(response))
+ content = '{"membership":["%s","%s","%s"]}' % (
+ Membership.INVITE,
+ Membership.JOIN,
+ Membership.LEAVE,
+ )
+ request, channel = make_request(b"PUT", path, content.encode('ascii'))
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_rooms_members_self(self):
path = "/rooms/%s/state/m.room.member/%s" % (
- urlparse.quote(self.room_id), self.user_id
+ urlparse.quote(self.room_id),
+ self.user_id,
)
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, content.encode('ascii'))
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger("GET", path, None)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"GET", path, None)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- expected_response = {
- "membership": Membership.JOIN,
- }
- self.assertEquals(expected_response, response)
+ expected_response = {"membership": Membership.JOIN}
+ self.assertEquals(expected_response, channel.json_body)
- @defer.inlineCallbacks
def test_rooms_members_other(self):
self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % (
- urlparse.quote(self.room_id), self.other_id
+ urlparse.quote(self.room_id),
+ self.other_id,
)
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger("GET", path, None)
- self.assertEquals(200, code, msg=str(response))
- self.assertEquals(json.loads(content), response)
+ request, channel = make_request(b"GET", path, None)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(json.loads(content), channel.json_body)
- @defer.inlineCallbacks
def test_rooms_members_other_custom_keys(self):
self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % (
- urlparse.quote(self.room_id), self.other_id
+ urlparse.quote(self.room_id),
+ self.other_id,
)
# valid invite message with custom key
- content = ('{"membership":"%s","invite_text":"%s"}' % (
- Membership.INVITE, "Join us!"
- ))
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
+ content = '{"membership":"%s","invite_text":"%s"}' % (
+ Membership.INVITE,
+ "Join us!",
+ )
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger("GET", path, None)
- self.assertEquals(200, code, msg=str(response))
- self.assertEquals(json.loads(content), response)
+ request, channel = make_request(b"GET", path, None)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(json.loads(content), channel.json_body)
-class RoomMessagesTestCase(RestTestCase):
+class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
+
user_id = "@sid1:red"
- @defer.inlineCallbacks
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
+ super(RoomMessagesTestCase, self).setUp()
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
+ self.room_id = self.helper.create_room_as(self.user_id)
- hs.get_handlers().federation_handler = Mock()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
-
- self.room_id = yield self.create_room_as(self.user_id)
-
- def tearDown(self):
- pass
-
- @defer.inlineCallbacks
def test_invalid_puts(self):
- path = "/rooms/%s/send/m.room.message/mid1" % (
- urlparse.quote(self.room_id))
+ path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{"_name":"bob"}'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{"_name":"bob"}')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '{"nao'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '{"nao')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]'
+ request, channel = make_request(
+ b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]'
)
- self.assertEquals(400, code, msg=str(response))
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, 'text only'
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, 'text only')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- (code, response) = yield self.mock_resource.trigger(
- "PUT", path, ''
- )
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, '')
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
- @defer.inlineCallbacks
def test_rooms_messages_sent(self):
- path = "/rooms/%s/send/m.room.message/mid1" % (
- urlparse.quote(self.room_id))
+ path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = '{"body":"test","msgtype":{"type":"a"}}'
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(400, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
# custom message types
content = '{"body":"test","msgtype":"test.custom.text"}'
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
-
-# (code, response) = yield self.mock_resource.trigger("GET", path, None)
-# self.assertEquals(200, code, msg=str(response))
-# self.assert_dict(json.loads(content), response)
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
# m.text message type
- path = "/rooms/%s/send/m.room.message/mid2" % (
- urlparse.quote(self.room_id))
+ path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = '{"body":"test2","msgtype":"m.text"}'
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(200, code, msg=str(response))
+ request, channel = make_request(b"PUT", path, content)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
-class RoomInitialSyncTestCase(RestTestCase):
+class RoomInitialSyncTestCase(RoomBase):
""" Tests /rooms/$room_id/initialSync. """
+
user_id = "@sid1:red"
- @defer.inlineCallbacks
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
-
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=[
- "send_message",
- ]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
+ super(RoomInitialSyncTestCase, self).setUp()
# create the room
- self.room_id = yield self.create_room_as(self.user_id)
+ self.room_id = self.helper.create_room_as(self.user_id)
- @defer.inlineCallbacks
def test_initial_sync(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/initialSync" % self.room_id
- )
- self.assertEquals(200, code)
+ request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id)
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
- self.assertEquals(self.room_id, response["room_id"])
- self.assertEquals("join", response["membership"])
+ self.assertEquals(self.room_id, channel.json_body["room_id"])
+ self.assertEquals("join", channel.json_body["membership"])
# Room state is easier to assert on if we unpack it into a dict
state = {}
- for event in response["state"]:
+ for event in channel.json_body["state"]:
if "state_key" not in event:
continue
t = event["type"]
@@ -978,75 +800,48 @@ class RoomInitialSyncTestCase(RestTestCase):
self.assertTrue("m.room.create" in state)
- self.assertTrue("messages" in response)
- self.assertTrue("chunk" in response["messages"])
- self.assertTrue("end" in response["messages"])
+ self.assertTrue("messages" in channel.json_body)
+ self.assertTrue("chunk" in channel.json_body["messages"])
+ self.assertTrue("end" in channel.json_body["messages"])
- self.assertTrue("presence" in response)
+ self.assertTrue("presence" in channel.json_body)
presence_by_user = {
- e["content"]["user_id"]: e for e in response["presence"]
+ e["content"]["user_id"]: e for e in channel.json_body["presence"]
}
self.assertTrue(self.user_id in presence_by_user)
self.assertEquals("m.presence", presence_by_user[self.user_id]["type"])
-class RoomMessageListTestCase(RestTestCase):
+class RoomMessageListTestCase(RoomBase):
""" Tests /rooms/$room_id/messages REST events. """
+
user_id = "@sid1:red"
- @defer.inlineCallbacks
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.auth_user_id = self.user_id
+ super(RoomMessageListTestCase, self).setUp()
+ self.room_id = self.helper.create_room_as(self.user_id)
- hs = yield setup_test_homeserver(
- "red",
- http_client=None,
- federation_client=Mock(),
- ratelimiter=NonCallableMock(spec_set=["send_message"]),
- )
- self.ratelimiter = hs.get_ratelimiter()
- self.ratelimiter.send_message.return_value = (True, 0)
-
- hs.get_handlers().federation_handler = Mock()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
- hs.get_datastore().insert_client_ip = _insert_client_ip
-
- synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
-
- self.room_id = yield self.create_room_as(self.user_id)
-
- @defer.inlineCallbacks
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0_0"
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/messages?access_token=x&from=%s" %
- (self.room_id, token))
- self.assertEquals(200, code)
- self.assertTrue("start" in response)
- self.assertEquals(token, response['start'])
- self.assertTrue("chunk" in response)
- self.assertTrue("end" in response)
-
- @defer.inlineCallbacks
+ request, channel = make_request(
+ b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
+ )
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("start" in channel.json_body)
+ self.assertEquals(token, channel.json_body['start'])
+ self.assertTrue("chunk" in channel.json_body)
+ self.assertTrue("end" in channel.json_body)
+
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0_0"
- (code, response) = yield self.mock_resource.trigger_get(
- "/rooms/%s/messages?access_token=x&from=%s" %
- (self.room_id, token))
- self.assertEquals(200, code)
- self.assertTrue("start" in response)
- self.assertEquals(token, response['start'])
- self.assertTrue("chunk" in response)
- self.assertTrue("end" in response)
+ request, channel = make_request(
+ b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
+ )
+ render(request, self.resource, self.clock)
+ self.assertEquals(200, int(channel.result["code"]))
+ self.assertTrue("start" in channel.json_body)
+ self.assertEquals(token, channel.json_body['start'])
+ self.assertTrue("chunk" in channel.json_body)
+ self.assertTrue("end" in channel.json_body)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index fe161ee5cb..bddb3302e4 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -15,18 +15,17 @@
"""Tests REST events for /rooms paths."""
+from mock import Mock, NonCallableMock
+
# twisted imports
from twisted.internet import defer
import synapse.rest.client.v1.room
from synapse.types import UserID
-from ....utils import MockHttpResource, MockClock, setup_test_homeserver
+from ....utils import MockClock, MockHttpResource, setup_test_homeserver
from .utils import RestTestCase
-from mock import Mock, NonCallableMock
-
-
PATH_PREFIX = "/_matrix/client/api/v1"
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 3bb1dd003a..41de8e0762 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# twisted imports
-from twisted.internet import defer
+import json
+import time
-# trial imports
-from tests import unittest
+import attr
+
+from twisted.internet import defer
from synapse.api.constants import Membership
-import json
-import time
+from tests import unittest
+from tests.server import make_request, wait_until_result
class RestTestCase(unittest.TestCase):
@@ -133,3 +134,113 @@ class RestTestCase(unittest.TestCase):
for key in required:
self.assertEquals(required[key], actual[key],
msg="%s mismatch. %s" % (key, actual))
+
+
+@attr.s
+class RestHelper(object):
+ """Contains extra helper functions to quickly and clearly perform a given
+ REST action, which isn't the focus of the test.
+ """
+
+ hs = attr.ib()
+ resource = attr.ib()
+ auth_user_id = attr.ib()
+
+ def create_room_as(self, room_creator, is_public=True, tok=None):
+ temp_id = self.auth_user_id
+ self.auth_user_id = room_creator
+ path = b"/_matrix/client/r0/createRoom"
+ content = {}
+ if not is_public:
+ content["visibility"] = "private"
+ if tok:
+ path = path + b"?access_token=%s" % tok.encode('ascii')
+
+ request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8'))
+ request.render(self.resource)
+ wait_until_result(self.hs.get_reactor(), channel)
+
+ assert channel.result["code"] == b"200", channel.result
+ self.auth_user_id = temp_id
+ return channel.json_body["room_id"]
+
+ def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
+ self.change_membership(
+ room=room,
+ src=src,
+ targ=targ,
+ tok=tok,
+ membership=Membership.INVITE,
+ expect_code=expect_code,
+ )
+
+ def join(self, room=None, user=None, expect_code=200, tok=None):
+ self.change_membership(
+ room=room,
+ src=user,
+ targ=user,
+ tok=tok,
+ membership=Membership.JOIN,
+ expect_code=expect_code,
+ )
+
+ def leave(self, room=None, user=None, expect_code=200, tok=None):
+ self.change_membership(
+ room=room,
+ src=user,
+ targ=user,
+ tok=tok,
+ membership=Membership.LEAVE,
+ expect_code=expect_code,
+ )
+
+ def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
+ temp_id = self.auth_user_id
+ self.auth_user_id = src
+
+ path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ data = {"membership": membership}
+
+ request, channel = make_request(
+ b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8')
+ )
+
+ request.render(self.resource)
+ wait_until_result(self.hs.get_reactor(), channel)
+
+ assert int(channel.result["code"]) == expect_code, (
+ "Expected: %d, got: %d, resp: %r"
+ % (expect_code, int(channel.result["code"]), channel.result["body"])
+ )
+
+ self.auth_user_id = temp_id
+
+ @defer.inlineCallbacks
+ def register(self, user_id):
+ (code, response) = yield self.mock_resource.trigger(
+ "POST",
+ "/_matrix/client/r0/register",
+ json.dumps(
+ {"user": user_id, "password": "test", "type": "m.login.password"}
+ ),
+ )
+ self.assertEquals(200, code)
+ defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
+ if txn_id is None:
+ txn_id = "m%s" % (str(time.time()))
+ if body is None:
+ body = "body_text_here"
+
+ path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
+ content = '{"msgtype":"m.text","body":"%s"}' % body
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ (code, response) = yield self.mock_resource.trigger("PUT", path, content)
+ self.assertEquals(expect_code, code, msg=str(response))
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index 5170217d9e..e69de29bb2 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -1,62 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from tests import unittest
-
-from mock import Mock
-
-from ....utils import MockHttpResource, setup_test_homeserver
-
-from synapse.types import UserID
-
-from twisted.internet import defer
-
-
-PATH_PREFIX = "/_matrix/client/v2_alpha"
-
-
-class V2AlphaRestTestCase(unittest.TestCase):
- # Consumer must define
- # USER_ID = <some string>
- # TO_REGISTER = [<list of REST servlets to register>]
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
-
- hs = yield setup_test_homeserver(
- datastore=self.make_datastore_mock(),
- http_client=None,
- resource_for_client=self.mock_resource,
- resource_for_federation=self.mock_resource,
- )
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.USER_ID),
- "token_id": 1,
- "is_guest": False,
- }
- hs.get_auth().get_user_by_access_token = get_user_by_access_token
-
- for r in self.TO_REGISTER:
- r.register_servlets(hs, self.mock_resource)
-
- def make_datastore_mock(self):
- store = Mock(spec=[
- "insert_client_ip",
- ])
- store.get_app_service_by_token = Mock(return_value=None)
- return store
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 76b833e119..e890f0feac 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,38 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from tests import unittest
-
-from synapse.rest.client.v2_alpha import filter
-
-from synapse.api.errors import Codes
-
import synapse.types
-
+from synapse.api.errors import Codes
+from synapse.http.server import JsonResource
+from synapse.rest.client.v2_alpha import filter
from synapse.types import UserID
+from synapse.util import Clock
-from ....utils import MockHttpResource, setup_test_homeserver
+from tests import unittest
+from tests.server import (
+ ThreadedMemoryReactorClock as MemoryReactorClock,
+ make_request,
+ setup_test_homeserver,
+ wait_until_result,
+)
PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase):
- USER_ID = "@apple:test"
+ USER_ID = b"@apple:test"
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
- EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
+ EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter]
- @defer.inlineCallbacks
def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
- self.hs = yield setup_test_homeserver(
- http_client=None,
- resource_for_client=self.mock_resource,
- resource_for_federation=self.mock_resource,
+ self.hs = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.auth = self.hs.get_auth()
@@ -58,82 +57,103 @@ class FilterTestCase(unittest.TestCase):
def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
- UserID.from_string(self.USER_ID), 1, False, None)
+ UserID.from_string(self.USER_ID), 1, False, None
+ )
self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering()
+ self.resource = JsonResource(self.hs)
for r in self.TO_REGISTER:
- r.register_servlets(self.hs, self.mock_resource)
+ r.register_servlets(self.hs, self.resource)
- @defer.inlineCallbacks
def test_add_filter(self):
- (code, response) = yield self.mock_resource.trigger(
- "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
- )
- self.assertEquals(200, code)
- self.assertEquals({"filter_id": "0"}, response)
- filter = yield self.store.get_user_filter(
- user_localpart='apple',
- filter_id=0,
+ request, channel = make_request(
+ b"POST",
+ b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ self.EXAMPLE_FILTER_JSON,
)
- self.assertEquals(filter, self.EXAMPLE_FILTER)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.json_body, {"filter_id": "0"})
+ filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
+ self.clock.advance(0)
+ self.assertEquals(filter.result, self.EXAMPLE_FILTER)
- @defer.inlineCallbacks
def test_add_filter_for_other_user(self):
- (code, response) = yield self.mock_resource.trigger(
- "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
+ request, channel = make_request(
+ b"POST",
+ b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"),
+ self.EXAMPLE_FILTER_JSON,
)
- self.assertEquals(403, code)
- self.assertEquals(response['errcode'], Codes.FORBIDDEN)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"403")
+ self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
- @defer.inlineCallbacks
def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
- (code, response) = yield self.mock_resource.trigger(
- "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
+ request, channel = make_request(
+ b"POST",
+ b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ self.EXAMPLE_FILTER_JSON,
)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
self.hs.is_mine = _is_mine
- self.assertEquals(403, code)
- self.assertEquals(response['errcode'], Codes.FORBIDDEN)
+ self.assertEqual(channel.result["code"], b"403")
+ self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
- @defer.inlineCallbacks
def test_get_filter(self):
- filter_id = yield self.filtering.add_user_filter(
- user_localpart='apple',
- user_filter=self.EXAMPLE_FILTER
+ filter_id = self.filtering.add_user_filter(
+ user_localpart="apple", user_filter=self.EXAMPLE_FILTER
)
- (code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/%s" % (self.USER_ID, filter_id)
+ self.clock.advance(1)
+ filter_id = filter_id.result
+ request, channel = make_request(
+ b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
)
- self.assertEquals(200, code)
- self.assertEquals(self.EXAMPLE_FILTER, response)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
- @defer.inlineCallbacks
def test_get_filter_non_existant(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/12382148321" % (self.USER_ID)
+ request, channel = make_request(
+ b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
)
- self.assertEquals(400, code)
- self.assertEquals(response['errcode'], Codes.NOT_FOUND)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"400")
+ self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
# in errors.py
- @defer.inlineCallbacks
def test_get_filter_invalid_id(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/foobar" % (self.USER_ID)
+ request, channel = make_request(
+ b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
)
- self.assertEquals(400, code)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error
- @defer.inlineCallbacks
def test_get_filter_no_id(self):
- (code, response) = yield self.mock_resource.trigger_get(
- "/user/%s/filter/" % (self.USER_ID)
+ request, channel = make_request(
+ b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
)
- self.assertEquals(400, code)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 8aba456510..e004d8fc73 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,163 +1,193 @@
-from twisted.python import failure
+import json
-from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
-from twisted.internet import defer
from mock import Mock
+
+from twisted.python import failure
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.http.server import JsonResource
+from synapse.rest.client.v2_alpha.register import register_servlets
+from synapse.util import Clock
+
from tests import unittest
-from tests.utils import mock_getRawHeaders
-import json
+from tests.server import make_request, setup_test_homeserver, wait_until_result
class RegisterRestServletTestCase(unittest.TestCase):
-
def setUp(self):
- # do the dance to hook up request data to self.request_data
- self.request_data = ""
- self.request = Mock(
- content=Mock(read=Mock(side_effect=lambda: self.request_data)),
- path='/_matrix/api/v2_alpha/register'
- )
- self.request.args = {}
- self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+ self.url = b"/_matrix/client/r0/register"
self.appservice = None
- self.auth = Mock(get_appservice_by_req=Mock(
- side_effect=lambda x: self.appservice)
+ self.auth = Mock(
+ get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
)
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
- get_session_data=Mock(return_value=None)
+ get_session_data=Mock(return_value=None),
)
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
self.device_handler = Mock()
+ self.device_handler.check_device_registered = Mock(return_value="FAKE")
+
+ self.datastore = Mock(return_value=Mock())
+ self.datastore.get_current_state_deltas = Mock(return_value=[])
# do the dance to hook it up to the hs global
self.handlers = Mock(
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
- login_handler=self.login_handler
+ login_handler=self.login_handler,
+ )
+ self.hs = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.clock
)
- self.hs = Mock()
- self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
+ self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
- # init the thing we're testing
- self.servlet = RegisterRestServlet(self.hs)
+ self.resource = JsonResource(self.hs)
+ register_servlets(self.hs, self.resource)
- @defer.inlineCallbacks
def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
- self.request.args = {
- "access_token": "i_am_an_app_service"
- }
- self.request_data = json.dumps({
- "username": "kermit"
- })
- self.appservice = {
- "id": "1234"
- }
- self.registration_handler.appservice_register = Mock(
- return_value=user_id
- )
- self.auth_handler.get_access_token_for_user_id = Mock(
- return_value=token
+ self.appservice = {"id": "1234"}
+ self.registration_handler.appservice_register = Mock(return_value=user_id)
+ self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
+ request_data = json.dumps({"username": "kermit"})
+
+ request, channel = make_request(
+ b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
- (code, result) = yield self.servlet.on_POST(self.request)
- self.assertEquals(code, 200)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {
"user_id": user_id,
"access_token": token,
- "home_server": self.hs.hostname
+ "home_server": self.hs.hostname,
}
- self.assertDictContainsSubset(det_data, result)
+ self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
- @defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self):
- self.request.args = {
- "access_token": "i_am_an_app_service"
- }
-
- self.request_data = json.dumps({
- "username": "kermit"
- })
self.appservice = None # no application service exists
- result = yield self.servlet.on_POST(self.request)
- self.assertEquals(result, (401, None))
+ request_data = json.dumps({"username": "kermit"})
+ request, channel = make_request(
+ b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+ )
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self):
- self.request_data = json.dumps({
- "username": "kermit",
- "password": 666
- })
- d = self.servlet.on_POST(self.request)
- return self.assertFailure(d, SynapseError)
+ request_data = json.dumps({"username": "kermit", "password": 666})
+ request, channel = make_request(b"POST", self.url, request_data)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"400", channel.result)
+ self.assertEquals(
+ json.loads(channel.result["body"])["error"], "Invalid password"
+ )
def test_POST_bad_username(self):
- self.request_data = json.dumps({
- "username": 777,
- "password": "monkey"
- })
- d = self.servlet.on_POST(self.request)
- return self.assertFailure(d, SynapseError)
-
- @defer.inlineCallbacks
+ request_data = json.dumps({"username": 777, "password": "monkey"})
+ request, channel = make_request(b"POST", self.url, request_data)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"400", channel.result)
+ self.assertEquals(
+ json.loads(channel.result["body"])["error"], "Invalid username"
+ )
+
def test_POST_user_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
device_id = "frogfone"
- self.request_data = json.dumps({
- "username": "kermit",
- "password": "monkey",
- "device_id": device_id,
- })
+ request_data = json.dumps(
+ {"username": "kermit", "password": "monkey", "device_id": device_id}
+ )
self.registration_handler.check_username = Mock(return_value=True)
- self.auth_result = (None, {
- "username": "kermit",
- "password": "monkey"
- }, None)
+ self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
- self.auth_handler.get_access_token_for_user_id = Mock(
- return_value=token
- )
- self.device_handler.check_device_registered = \
- Mock(return_value=device_id)
+ self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
+ self.device_handler.check_device_registered = Mock(return_value=device_id)
+
+ request, channel = make_request(b"POST", self.url, request_data)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
- (code, result) = yield self.servlet.on_POST(self.request)
- self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertDictContainsSubset(det_data, result)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
self.auth_handler.get_login_tuple_for_user_id(
- user_id, device_id=device_id, initial_device_display_name=None)
+ user_id, device_id=device_id, initial_device_display_name=None
+ )
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
- self.request_data = json.dumps({
- "username": "kermit",
- "password": "monkey"
- })
+ request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.registration_handler.check_username = Mock(return_value=True)
- self.auth_result = (None, {
- "username": "kermit",
- "password": "monkey"
- }, None)
+ self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
- d = self.servlet.on_POST(self.request)
- return self.assertFailure(d, SynapseError)
+
+ request, channel = make_request(b"POST", self.url, request_data)
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+ self.assertEquals(
+ json.loads(channel.result["body"])["error"],
+ "Registration has been disabled",
+ )
+
+ def test_POST_guest_registration(self):
+ user_id = "a@b"
+ self.hs.config.macaroon_secret_key = "test"
+ self.hs.config.allow_guest_access = True
+ self.registration_handler.register = Mock(return_value=(user_id, None))
+
+ request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ det_data = {
+ "user_id": user_id,
+ "home_server": self.hs.hostname,
+ "device_id": "guest_device",
+ }
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
+
+ def test_POST_disabled_guest_registration(self):
+ self.hs.config.allow_guest_access = False
+
+ request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+ self.assertEquals(
+ json.loads(channel.result["body"])["error"], "Guest access is disabled"
+ )
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
new file mode 100644
index 0000000000..03ec3993b2
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse.types
+from synapse.http.server import JsonResource
+from synapse.rest.client.v2_alpha import sync
+from synapse.types import UserID
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import (
+ ThreadedMemoryReactorClock as MemoryReactorClock,
+ make_request,
+ setup_test_homeserver,
+ wait_until_result,
+)
+
+PATH_PREFIX = "/_matrix/client/v2_alpha"
+
+
+class FilterTestCase(unittest.TestCase):
+
+ USER_ID = b"@apple:test"
+ TO_REGISTER = [sync]
+
+ def setUp(self):
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+
+ self.hs = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.clock
+ )
+
+ self.auth = self.hs.get_auth()
+
+ def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.USER_ID),
+ "token_id": 1,
+ "is_guest": False,
+ }
+
+ def get_user_by_req(request, allow_guest=False, rights="access"):
+ return synapse.types.create_requester(
+ UserID.from_string(self.USER_ID), 1, False, None
+ )
+
+ self.auth.get_user_by_access_token = get_user_by_access_token
+ self.auth.get_user_by_req = get_user_by_req
+
+ self.store = self.hs.get_datastore()
+ self.filtering = self.hs.get_filtering()
+ self.resource = JsonResource(self.hs)
+
+ for r in self.TO_REGISTER:
+ r.register_servlets(self.hs, self.resource)
+
+ def test_sync_argless(self):
+ request, channel = make_request(b"GET", b"/_matrix/client/r0/sync")
+ request.render(self.resource)
+ wait_until_result(self.clock, channel)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertTrue(
+ set(
+ [
+ "next_batch",
+ "rooms",
+ "presence",
+ "account_data",
+ "to_device",
+ "device_lists",
+ ]
+ ).issubset(set(channel.json_body.keys()))
+ )
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index eef38b6781..bf254a260d 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -14,21 +14,21 @@
# limitations under the License.
-from twisted.internet import defer
+import os
+import shutil
+import tempfile
+
+from mock import Mock
+
+from twisted.internet import defer, reactor
from synapse.rest.media.v1._base import FileInfo
-from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.filepath import MediaFilePaths
+from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
-from mock import Mock
-
from tests import unittest
-import os
-import shutil
-import tempfile
-
class MediaStorageTests(unittest.TestCase):
def setUp(self):
@@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase):
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
hs = Mock()
+ hs.get_reactor = Mock(return_value=reactor)
hs.config.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(
@@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase):
self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage(
- self.primary_base_path, self.filepaths, storage_providers,
+ hs, self.primary_base_path, self.filepaths, storage_providers,
)
def tearDown(self):
diff --git a/tests/server.py b/tests/server.py
new file mode 100644
index 0000000000..c611dd6059
--- /dev/null
+++ b/tests/server.py
@@ -0,0 +1,193 @@
+import json
+from io import BytesIO
+
+from six import text_type
+
+import attr
+
+from twisted.internet import threads
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.http.site import SynapseRequest
+
+from tests.utils import setup_test_homeserver as _sth
+
+
+@attr.s
+class FakeChannel(object):
+ """
+ A fake Twisted Web Channel (the part that interfaces with the
+ wire).
+ """
+
+ result = attr.ib(default=attr.Factory(dict))
+
+ @property
+ def json_body(self):
+ if not self.result:
+ raise Exception("No result yet.")
+ return json.loads(self.result["body"])
+
+ def writeHeaders(self, version, code, reason, headers):
+ self.result["version"] = version
+ self.result["code"] = code
+ self.result["reason"] = reason
+ self.result["headers"] = headers
+
+ def write(self, content):
+ if "body" not in self.result:
+ self.result["body"] = b""
+
+ self.result["body"] += content
+
+ def requestDone(self, _self):
+ self.result["done"] = True
+
+ def getPeer(self):
+ return None
+
+ def getHost(self):
+ return None
+
+ @property
+ def transport(self):
+ return self
+
+
+class FakeSite:
+ """
+ A fake Twisted Web Site, with mocks of the extra things that
+ Synapse adds.
+ """
+
+ server_version_string = b"1"
+ site_tag = "test"
+
+ @property
+ def access_logger(self):
+ class FakeLogger:
+ def info(self, *args, **kwargs):
+ pass
+
+ return FakeLogger()
+
+
+def make_request(method, path, content=b""):
+ """
+ Make a web request using the given method and path, feed it the
+ content, and return the Request and the Channel underneath.
+ """
+
+ # Decorate it to be the full path
+ if not path.startswith(b"/_matrix"):
+ path = b"/_matrix/client/r0/" + path
+ path = path.replace("//", "/")
+
+ if isinstance(content, text_type):
+ content = content.encode('utf8')
+
+ site = FakeSite()
+ channel = FakeChannel()
+
+ req = SynapseRequest(site, channel)
+ req.process = lambda: b""
+ req.content = BytesIO(content)
+ req.requestReceived(method, path, b"1.1")
+
+ return req, channel
+
+
+def wait_until_result(clock, channel, timeout=100):
+ """
+ Wait until the channel has a result.
+ """
+ clock.run()
+ x = 0
+
+ while not channel.result:
+ x += 1
+
+ if x > timeout:
+ raise Exception("Timed out waiting for request to finish.")
+
+ clock.advance(0.1)
+
+
+def render(request, resource, clock):
+ request.render(resource)
+ wait_until_result(clock, request._channel)
+
+
+class ThreadedMemoryReactorClock(MemoryReactorClock):
+ """
+ A MemoryReactorClock that supports callFromThread.
+ """
+ def callFromThread(self, callback, *args, **kwargs):
+ """
+ Make the callback fire in the next reactor iteration.
+ """
+ d = Deferred()
+ d.addCallback(lambda x: callback(*args, **kwargs))
+ self.callLater(0, d.callback, True)
+ return d
+
+
+def setup_test_homeserver(*args, **kwargs):
+ """
+ Set up a synchronous test server, driven by the reactor used by
+ the homeserver.
+ """
+ d = _sth(*args, **kwargs).result
+
+ # Make the thread pool synchronous.
+ clock = d.get_clock()
+ pool = d.get_db_pool()
+
+ def runWithConnection(func, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runWithConnection,
+ func,
+ *args,
+ **kwargs
+ )
+
+ def runInteraction(interaction, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runInteraction,
+ interaction,
+ *args,
+ **kwargs
+ )
+
+ pool.runWithConnection = runWithConnection
+ pool.runInteraction = runInteraction
+
+ class ThreadPool:
+ """
+ Threadless thread pool.
+ """
+ def start(self):
+ pass
+
+ def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
+ def _(res):
+ if isinstance(res, Failure):
+ onResult(False, res)
+ else:
+ onResult(True, res)
+
+ d = Deferred()
+ d.addCallback(lambda x: function(*args, **kwargs))
+ d.addBoth(_)
+ clock._reactor.callLater(0, d.callback, True)
+ return d
+
+ clock.threadpool = ThreadPool()
+ pool.threadpool = ThreadPool()
+ return d
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 3cfa21c9f8..6d6f00c5c5 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,15 +14,15 @@
# limitations under the License.
-from tests import unittest
-from twisted.internet import defer
-
from mock import Mock
-from synapse.util.async import ObservableDeferred
+from twisted.internet import defer
+from synapse.util.async import ObservableDeferred
from synapse.util.caches.descriptors import Cache, cached
+from tests import unittest
+
class CacheTestCase(unittest.TestCase):
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 00825498b1..099861b27c 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -12,21 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+import os
import tempfile
-from synapse.config._base import ConfigError
-from tests import unittest
+
+from mock import Mock
+
+import yaml
+
from twisted.internet import defer
-from tests.utils import setup_test_homeserver
from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.config._base import ConfigError
from synapse.storage.appservice import (
- ApplicationServiceStore, ApplicationServiceTransactionStore
+ ApplicationServiceStore,
+ ApplicationServiceTransactionStore,
)
-import json
-import os
-import yaml
-from mock import Mock
+from tests import unittest
+from tests.utils import setup_test_homeserver
class ApplicationServiceStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 1286b4ce2d..ab1f310572 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,10 +1,10 @@
-from tests import unittest
+from mock import Mock
+
from twisted.internet import defer
+from tests import unittest
from tests.utils import setup_test_homeserver
-from mock import Mock
-
class BackgroundUpdateTestCase(unittest.TestCase):
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 0ac910e76f..1d1234ee39 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -14,18 +14,18 @@
# limitations under the License.
-from tests import unittest
-from twisted.internet import defer
+from collections import OrderedDict
from mock import Mock
-from collections import OrderedDict
+from twisted.internet import defer
from synapse.server import HomeServer
-
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import create_engine
+from tests import unittest
+
class SQLBaseStoreTestCase(unittest.TestCase):
""" Test the "simple" SQL generating methods in SQLBaseStore. """
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index f8725acea0..a54cc6bc32 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -16,6 +16,7 @@
from twisted.internet import defer
import synapse.api.errors
+
import tests.unittest
import tests.utils
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 95709cd50a..129ebaf343 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
from synapse.storage.directory import DirectoryStore
-from synapse.types import RoomID, RoomAlias
+from synapse.types import RoomAlias, RoomID
+from tests import unittest
from tests.utils import setup_test_homeserver
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 3cbf9a78b1..8430fc7ba6 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
+
from twisted.internet import defer
import tests.unittest
import tests.utils
-from mock import Mock
USER_ID = "@user:example.com"
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 0be790d8f8..3a3d002782 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -14,6 +14,7 @@
# limitations under the License.
import signedjson.key
+
from twisted.internet import defer
import tests.unittest
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index f5fcb611d4..3276b39504 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -14,13 +14,13 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
from synapse.storage.presence import PresenceStore
from synapse.types import UserID
-from tests.utils import setup_test_homeserver, MockClock
+from tests import unittest
+from tests.utils import MockClock, setup_test_homeserver
class PresenceStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 423710c9c1..2c95e5e95a 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
from synapse.storage.profile import ProfileStore
from synapse.types import UserID
+from tests import unittest
from tests.utils import setup_test_homeserver
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 888ddfaddd..475ec900c4 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,16 +14,16 @@
# limitations under the License.
-from tests import unittest
+from mock import Mock
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.types import UserID, RoomID
+from synapse.types import RoomID, UserID
+from tests import unittest
from tests.utils import setup_test_homeserver
-from mock import Mock
-
class RedactionTestCase(unittest.TestCase):
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index f863b75846..7821ea3fa3 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -14,9 +14,9 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
+from tests import unittest
from tests.utils import setup_test_homeserver
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index ef8a4d234f..ae8ae94b6d 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from tests import unittest
from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.types import UserID, RoomID, RoomAlias
+from synapse.types import RoomAlias, RoomID, UserID
+from tests import unittest
from tests.utils import setup_test_homeserver
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 657b279e5d..c5fd54f67e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,16 +14,16 @@
# limitations under the License.
-from tests import unittest
+from mock import Mock
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.types import UserID, RoomID
+from synapse.types import RoomID, UserID
+from tests import unittest
from tests.utils import setup_test_homeserver
-from mock import Mock
-
class RoomMemberStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 0891308f25..23fad12bca 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.storage import UserDirectoryStore
from synapse.storage.roommember import ProfileInfo
+
from tests import unittest
from tests.utils import setup_test_homeserver
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 010aeaee7e..71d11cda77 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,13 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import unittest
-from twisted.internet import defer
-
from mock import Mock, patch
from synapse.util.distributor import Distributor
-from synapse.util.async import run_on_reactor
+
+from . import unittest
class DistributorTestCase(unittest.TestCase):
@@ -27,38 +26,15 @@ class DistributorTestCase(unittest.TestCase):
def setUp(self):
self.dist = Distributor()
- @defer.inlineCallbacks
def test_signal_dispatch(self):
self.dist.declare("alert")
observer = Mock()
self.dist.observe("alert", observer)
- d = self.dist.fire("alert", 1, 2, 3)
- yield d
- self.assertTrue(d.called)
+ self.dist.fire("alert", 1, 2, 3)
observer.assert_called_with(1, 2, 3)
- @defer.inlineCallbacks
- def test_signal_dispatch_deferred(self):
- self.dist.declare("whine")
-
- d_inner = defer.Deferred()
-
- def observer():
- return d_inner
-
- self.dist.observe("whine", observer)
-
- d_outer = self.dist.fire("whine")
-
- self.assertFalse(d_outer.called)
-
- d_inner.callback(None)
- yield d_outer
- self.assertTrue(d_outer.called)
-
- @defer.inlineCallbacks
def test_signal_catch(self):
self.dist.declare("alarm")
@@ -71,9 +47,7 @@ class DistributorTestCase(unittest.TestCase):
with patch(
"synapse.util.distributor.logger", spec=["warning"]
) as mock_logger:
- d = self.dist.fire("alarm", "Go")
- yield d
- self.assertTrue(d.called)
+ self.dist.fire("alarm", "Go")
observers[0].assert_called_once_with("Go")
observers[1].assert_called_once_with("Go")
@@ -83,35 +57,12 @@ class DistributorTestCase(unittest.TestCase):
mock_logger.warning.call_args[0][0], str
)
- @defer.inlineCallbacks
- def test_signal_catch_no_suppress(self):
- # Gut-wrenching
- self.dist.suppress_failures = False
-
- self.dist.declare("whail")
-
- class MyException(Exception):
- pass
-
- @defer.inlineCallbacks
- def observer():
- yield run_on_reactor()
- raise MyException("Oopsie")
-
- self.dist.observe("whail", observer)
-
- d = self.dist.fire("whail")
-
- yield self.assertFailure(d, MyException)
- self.dist.suppress_failures = True
-
- @defer.inlineCallbacks
def test_signal_prereg(self):
observer = Mock()
self.dist.observe("flare", observer)
self.dist.declare("flare")
- yield self.dist.fire("flare", 4, 5)
+ self.dist.fire("flare", 4, 5)
observer.assert_called_with(4, 5)
diff --git a/tests/test_dns.py b/tests/test_dns.py
index 3b360a0fc7..b647d92697 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import unittest
+from mock import Mock
+
from twisted.internet import defer
from twisted.names import dns, error
-from mock import Mock
-
from synapse.http.endpoint import resolve_service
from tests.utils import MockClock
+from . import unittest
+
@unittest.DEBUG
class DnsTestCase(unittest.TestCase):
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
new file mode 100644
index 0000000000..06112430e5
--- /dev/null
+++ b/tests/test_event_auth.py
@@ -0,0 +1,152 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from synapse import event_auth
+from synapse.api.errors import AuthError
+from synapse.events import FrozenEvent
+
+
+class EventAuthTestCase(unittest.TestCase):
+ def test_random_users_cannot_send_state_before_first_pl(self):
+ """
+ Check that, before the first PL lands, the creator is the only user
+ that can send a state event.
+ """
+ creator = "@creator:example.com"
+ joiner = "@joiner:example.com"
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.member", joiner): _join_event(joiner),
+ }
+
+ # creator should be able to send state
+ event_auth.check(
+ _random_state_event(creator), auth_events,
+ do_sig_check=False,
+ )
+
+ # joiner should not be able to send state
+ self.assertRaises(
+ AuthError,
+ event_auth.check,
+ _random_state_event(joiner),
+ auth_events,
+ do_sig_check=False,
+ ),
+
+ def test_state_default_level(self):
+ """
+ Check that users above the state_default level can send state and
+ those below cannot
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+ king = "@joiner2:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.power_levels", ""): _power_levels_event(creator, {
+ "state_default": "30",
+ "users": {
+ pleb: "29",
+ king: "30",
+ },
+ }),
+ ("m.room.member", pleb): _join_event(pleb),
+ ("m.room.member", king): _join_event(king),
+ }
+
+ # pleb should not be able to send state
+ self.assertRaises(
+ AuthError,
+ event_auth.check,
+ _random_state_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ ),
+
+ # king should be able to send state
+ event_auth.check(
+ _random_state_event(king), auth_events,
+ do_sig_check=False,
+ )
+
+
+# helpers for making events
+
+TEST_ROOM_ID = "!test:room"
+
+
+def _create_event(user_id):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.create",
+ "sender": user_id,
+ "content": {
+ "creator": user_id,
+ },
+ })
+
+
+def _join_event(user_id):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.member",
+ "sender": user_id,
+ "state_key": user_id,
+ "content": {
+ "membership": "join",
+ },
+ })
+
+
+def _power_levels_event(sender, content):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.power_levels",
+ "sender": sender,
+ "state_key": "",
+ "content": content,
+ })
+
+
+def _random_state_event(sender):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "test.state",
+ "sender": sender,
+ "state_key": "",
+ "content": {
+ "membership": "join",
+ },
+ })
+
+
+event_count = 0
+
+
+def _get_event_id():
+ global event_count
+ c = event_count
+ event_count += 1
+ return "!%i:example.com" % (c, )
diff --git a/tests/test_federation.py b/tests/test_federation.py
new file mode 100644
index 0000000000..159a136971
--- /dev/null
+++ b/tests/test_federation.py
@@ -0,0 +1,243 @@
+
+from mock import Mock
+
+from twisted.internet.defer import maybeDeferred, succeed
+
+from synapse.events import FrozenEvent
+from synapse.types import Requester, UserID
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
+
+
+class MessageAcceptTests(unittest.TestCase):
+ def setUp(self):
+
+ self.http_client = Mock()
+ self.reactor = ThreadedMemoryReactorClock()
+ self.hs_clock = Clock(self.reactor)
+ self.homeserver = setup_test_homeserver(
+ http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor
+ )
+
+ user_id = UserID("us", "test")
+ our_user = Requester(user_id, None, False, None, None)
+ room_creator = self.homeserver.get_room_creation_handler()
+ room = room_creator.create_room(
+ our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+ )
+ self.reactor.advance(0.1)
+ self.room_id = self.successResultOf(room)["room_id"]
+
+ # Figure out what the most recent event is
+ most_recent = self.successResultOf(
+ maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ )[0]
+
+ join_event = FrozenEvent(
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$join:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "content": {"membership": "join"},
+ "auth_events": [],
+ "prev_state": [(most_recent, {})],
+ "prev_events": [(most_recent, {})],
+ }
+ )
+
+ self.handler = self.homeserver.get_handlers().federation_handler
+ self.handler.do_auth = lambda *a, **b: succeed(True)
+ self.client = self.homeserver.get_federation_client()
+ self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
+ pdus
+ )
+
+ # Send the join, it should return None (which is not an error)
+ d = self.handler.on_receive_pdu(
+ "test.serv", join_event, sent_to_us_directly=True
+ )
+ self.reactor.advance(1)
+ self.assertEqual(self.successResultOf(d), None)
+
+ # Make sure we actually joined the room
+ self.assertEqual(
+ self.successResultOf(
+ maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ )[0],
+ "$join:test.serv",
+ )
+
+ def test_cant_hide_direct_ancestors(self):
+ """
+ If you send a message, you must be able to provide the direct
+ prev_events that said event references.
+ """
+
+ def post_json(destination, path, data, headers=None, timeout=0):
+ # If it asks us for new missing events, give them NOTHING
+ if path.startswith("/_matrix/federation/v1/get_missing_events/"):
+ return {"events": []}
+
+ self.http_client.post_json = post_json
+
+ # Figure out what the most recent event is
+ most_recent = self.successResultOf(
+ maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ )[0]
+
+ # Now lie about an event
+ lying_event = FrozenEvent(
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "event_id": "one:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.message",
+ "origin": "test.serv",
+ "content": "hewwo?",
+ "auth_events": [],
+ "prev_events": [("two:test.serv", {}), (most_recent, {})],
+ }
+ )
+
+ d = self.handler.on_receive_pdu(
+ "test.serv", lying_event, sent_to_us_directly=True
+ )
+
+ # Step the reactor, so the database fetches come back
+ self.reactor.advance(1)
+
+ # on_receive_pdu should throw an error
+ failure = self.failureResultOf(d)
+ self.assertEqual(
+ failure.value.args[0],
+ (
+ "ERROR 403: Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ )
+
+ # Make sure the invalid event isn't there
+ extrem = maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+
+ @unittest.DEBUG
+ def test_cant_hide_past_history(self):
+ """
+ If you send a message, you must be able to provide the direct
+ prev_events that said event references.
+ """
+
+ def post_json(destination, path, data, headers=None, timeout=0):
+ if path.startswith("/_matrix/federation/v1/get_missing_events/"):
+ return {
+ "events": [
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "event_id": "three:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.message",
+ "origin": "test.serv",
+ "content": "hewwo?",
+ "auth_events": [],
+ "prev_events": [("four:test.serv", {})],
+ }
+ ]
+ }
+
+ self.http_client.post_json = post_json
+
+ def get_json(destination, path, args, headers=None):
+ if path.startswith("/_matrix/federation/v1/state_ids/"):
+ d = self.successResultOf(
+ self.homeserver.datastore.get_state_ids_for_event("one:test.serv")
+ )
+
+ return succeed(
+ {
+ "pdu_ids": [
+ y
+ for x, y in d.items()
+ if x == ("m.room.member", "@us:test")
+ ],
+ "auth_chain_ids": d.values(),
+ }
+ )
+
+ self.http_client.get_json = get_json
+
+ # Figure out what the most recent event is
+ most_recent = self.successResultOf(
+ maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ )[0]
+
+ # Make a good event
+ good_event = FrozenEvent(
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "event_id": "one:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.message",
+ "origin": "test.serv",
+ "content": "hewwo?",
+ "auth_events": [],
+ "prev_events": [(most_recent, {})],
+ }
+ )
+
+ d = self.handler.on_receive_pdu(
+ "test.serv", good_event, sent_to_us_directly=True
+ )
+ self.reactor.advance(1)
+ self.assertEqual(self.successResultOf(d), None)
+
+ bad_event = FrozenEvent(
+ {
+ "room_id": self.room_id,
+ "sender": "@baduser:test.serv",
+ "event_id": "two:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.message",
+ "origin": "test.serv",
+ "content": "hewwo?",
+ "auth_events": [],
+ "prev_events": [("one:test.serv", {}), ("three:test.serv", {})],
+ }
+ )
+
+ d = self.handler.on_receive_pdu(
+ "test.serv", bad_event, sent_to_us_directly=True
+ )
+ self.reactor.advance(1)
+
+ extrem = maybeDeferred(
+ self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ )
+ self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv")
+
+ state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id)
+ self.reactor.advance(1)
+ self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys())
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 5bd36c74aa..446843367e 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import unittest
-
from synapse.rest.media.v1.preview_url_resource import (
- summarize_paragraphs, decode_and_calc_og
+ decode_and_calc_og,
+ summarize_paragraphs,
)
+from . import unittest
+
class PreviewTestCase(unittest.TestCase):
diff --git a/tests/test_server.py b/tests/test_server.py
new file mode 100644
index 0000000000..7e063c0290
--- /dev/null
+++ b/tests/test_server.py
@@ -0,0 +1,131 @@
+import json
+import re
+
+from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import JsonResource
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import make_request, setup_test_homeserver
+
+
+class JsonResourceTests(unittest.TestCase):
+ def setUp(self):
+ self.reactor = MemoryReactorClock()
+ self.hs_clock = Clock(self.reactor)
+ self.homeserver = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.reactor
+ )
+
+ def test_handler_for_request(self):
+ """
+ JsonResource.handler_for_request gives correctly decoded URL args to
+ the callback, while Twisted will give the raw bytes of URL query
+ arguments.
+ """
+ got_kwargs = {}
+
+ def _callback(request, **kwargs):
+ got_kwargs.update(kwargs)
+ return (200, kwargs)
+
+ res = JsonResource(self.homeserver)
+ res.register_paths(
+ "GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
+ )
+
+ request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
+ request.render(res)
+
+ self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
+ self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
+
+ def test_callback_direct_exception(self):
+ """
+ If the web callback raises an uncaught exception, it will be translated
+ into a 500.
+ """
+
+ def _callback(request, **kwargs):
+ raise Exception("boo")
+
+ res = JsonResource(self.homeserver)
+ res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
+
+ request, channel = make_request(b"GET", b"/_matrix/foo")
+ request.render(res)
+
+ self.assertEqual(channel.result["code"], b'500')
+
+ def test_callback_indirect_exception(self):
+ """
+ If the web callback raises an uncaught exception in a Deferred, it will
+ be translated into a 500.
+ """
+
+ def _throw(*args):
+ raise Exception("boo")
+
+ def _callback(request, **kwargs):
+ d = Deferred()
+ d.addCallback(_throw)
+ self.reactor.callLater(1, d.callback, True)
+ return d
+
+ res = JsonResource(self.homeserver)
+ res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
+
+ request, channel = make_request(b"GET", b"/_matrix/foo")
+ request.render(res)
+
+ # No error has been raised yet
+ self.assertTrue("code" not in channel.result)
+
+ # Advance time, now there's an error
+ self.reactor.advance(1)
+ self.assertEqual(channel.result["code"], b'500')
+
+ def test_callback_synapseerror(self):
+ """
+ If the web callback raises a SynapseError, it returns the appropriate
+ status code and message set in it.
+ """
+
+ def _callback(request, **kwargs):
+ raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
+
+ res = JsonResource(self.homeserver)
+ res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
+
+ request, channel = make_request(b"GET", b"/_matrix/foo")
+ request.render(res)
+
+ self.assertEqual(channel.result["code"], b'403')
+ reply_body = json.loads(channel.result["body"])
+ self.assertEqual(reply_body["error"], "Forbidden!!one!")
+ self.assertEqual(reply_body["errcode"], "M_FORBIDDEN")
+
+ def test_no_handler(self):
+ """
+ If there is no handler to process the request, Synapse will return 400.
+ """
+
+ def _callback(request, **kwargs):
+ """
+ Not ever actually called!
+ """
+ self.fail("shouldn't ever get here")
+
+ res = JsonResource(self.homeserver)
+ res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
+
+ request, channel = make_request(b"GET", b"/_matrix/foobar")
+ request.render(res)
+
+ self.assertEqual(channel.result["code"], b'400')
+ reply_body = json.loads(channel.result["body"])
+ self.assertEqual(reply_body["error"], "Unrecognized request")
+ self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED")
diff --git a/tests/test_state.py b/tests/test_state.py
index a5c5e55951..c0f2d1152d 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -13,18 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests import unittest
+from mock import Mock
+
from twisted.internet import defer
-from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
+from synapse.events import FrozenEvent
from synapse.state import StateHandler, StateResolutionHandler
-from .utils import MockClock
-
-from mock import Mock
+from tests import unittest
+from .utils import MockClock
_next_event_id = 1000
@@ -606,6 +606,14 @@ class StateTestCase(unittest.TestCase):
}
)
+ power_levels = create_event(
+ type=EventTypes.PowerLevels, state_key="",
+ content={"users": {
+ "@foo:bar": "100",
+ "@user_id:example.com": "100",
+ }}
+ )
+
creation = create_event(
type=EventTypes.Create, state_key="",
content={"creator": "@foo:bar"}
@@ -613,12 +621,14 @@ class StateTestCase(unittest.TestCase):
old_state_1 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
old_state_2 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
@@ -633,7 +643,7 @@ class StateTestCase(unittest.TestCase):
)
self.assertEqual(
- old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+ old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths
@@ -641,12 +651,14 @@ class StateTestCase(unittest.TestCase):
old_state_1 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
old_state_2 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
@@ -659,7 +671,7 @@ class StateTestCase(unittest.TestCase):
)
self.assertEqual(
- old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+ old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
)
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index d28bb726bb..bc97c12245 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -14,7 +14,6 @@
# limitations under the License.
from tests import unittest
-
from tests.utils import MockClock
diff --git a/tests/test_types.py b/tests/test_types.py
index 115def2287..729bd676c1 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests import unittest
-
from synapse.api.errors import SynapseError
from synapse.server import HomeServer
-from synapse.types import UserID, RoomAlias, GroupID
+from synapse.types import GroupID, RoomAlias, UserID
+
+from tests import unittest
mock_homeserver = HomeServer(hostname="my.domain")
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
new file mode 100644
index 0000000000..0dc1a924d3
--- /dev/null
+++ b/tests/test_visibility.py
@@ -0,0 +1,324 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+from twisted.internet.defer import succeed
+
+from synapse.events import FrozenEvent
+from synapse.visibility import filter_events_for_server
+
+import tests.unittest
+from tests.utils import setup_test_homeserver
+
+logger = logging.getLogger(__name__)
+
+TEST_ROOM_ID = "!TEST:ROOM"
+
+
+class FilterEventsForServerTestCase(tests.unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver()
+ self.event_creation_handler = self.hs.get_event_creation_handler()
+ self.event_builder_factory = self.hs.get_event_builder_factory()
+ self.store = self.hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_filtering(self):
+ #
+ # The events to be filtered consist of 10 membership events (it doesn't
+ # really matter if they are joins or leaves, so let's make them joins).
+ # One of those membership events is going to be for a user on the
+ # server we are filtering for (so we can check the filtering is doing
+ # the right thing).
+ #
+
+ # before we do that, we persist some other events to act as state.
+ self.inject_visibility("@admin:hs", "joined")
+ for i in range(0, 10):
+ yield self.inject_room_member("@resident%i:hs" % i)
+
+ events_to_filter = []
+
+ for i in range(0, 10):
+ user = "@user%i:%s" % (
+ i, "test_server" if i == 5 else "other_server"
+ )
+ evt = yield self.inject_room_member(user, extra_content={"a": "b"})
+ events_to_filter.append(evt)
+
+ filtered = yield filter_events_for_server(
+ self.store, "test_server", events_to_filter,
+ )
+
+ # the result should be 5 redacted events, and 5 unredacted events.
+ for i in range(0, 5):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertNotIn("a", filtered[i].content)
+
+ for i in range(5, 10):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertEqual(filtered[i].content["a"], "b")
+
+ @tests.unittest.DEBUG
+ @defer.inlineCallbacks
+ def test_erased_user(self):
+ # 4 message events, from erased and unerased users, with a membership
+ # change in the middle of them.
+ events_to_filter = []
+
+ evt = yield self.inject_message("@unerased:local_hs")
+ events_to_filter.append(evt)
+
+ evt = yield self.inject_message("@erased:local_hs")
+ events_to_filter.append(evt)
+
+ evt = yield self.inject_room_member("@joiner:remote_hs")
+ events_to_filter.append(evt)
+
+ evt = yield self.inject_message("@unerased:local_hs")
+ events_to_filter.append(evt)
+
+ evt = yield self.inject_message("@erased:local_hs")
+ events_to_filter.append(evt)
+
+ # the erasey user gets erased
+ self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+
+ # ... and the filtering happens.
+ filtered = yield filter_events_for_server(
+ self.store, "test_server", events_to_filter,
+ )
+
+ for i in range(0, len(events_to_filter)):
+ self.assertEqual(
+ events_to_filter[i].event_id, filtered[i].event_id,
+ "Unexpected event at result position %i" % (i, )
+ )
+
+ for i in (0, 3):
+ self.assertEqual(
+ events_to_filter[i].content["body"], filtered[i].content["body"],
+ "Unexpected event content at result position %i" % (i,)
+ )
+
+ for i in (1, 4):
+ self.assertNotIn("body", filtered[i].content)
+
+ @defer.inlineCallbacks
+ def inject_visibility(self, user_id, visibility):
+ content = {"history_visibility": visibility}
+ builder = self.event_builder_factory.new({
+ "type": "m.room.history_visibility",
+ "sender": user_id,
+ "state_key": "",
+ "room_id": TEST_ROOM_ID,
+ "content": content,
+ })
+
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder
+ )
+ yield self.hs.get_datastore().persist_event(event, context)
+ defer.returnValue(event)
+
+ @defer.inlineCallbacks
+ def inject_room_member(self, user_id, membership="join", extra_content={}):
+ content = {"membership": membership}
+ content.update(extra_content)
+ builder = self.event_builder_factory.new({
+ "type": "m.room.member",
+ "sender": user_id,
+ "state_key": user_id,
+ "room_id": TEST_ROOM_ID,
+ "content": content,
+ })
+
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder
+ )
+
+ yield self.hs.get_datastore().persist_event(event, context)
+ defer.returnValue(event)
+
+ @defer.inlineCallbacks
+ def inject_message(self, user_id, content=None):
+ if content is None:
+ content = {"body": "testytest"}
+ builder = self.event_builder_factory.new({
+ "type": "m.room.message",
+ "sender": user_id,
+ "room_id": TEST_ROOM_ID,
+ "content": content,
+ })
+
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder
+ )
+
+ yield self.hs.get_datastore().persist_event(event, context)
+ defer.returnValue(event)
+
+ @defer.inlineCallbacks
+ def test_large_room(self):
+ # see what happens when we have a large room with hundreds of thousands
+ # of membership events
+
+ # As above, the events to be filtered consist of 10 membership events,
+ # where one of them is for a user on the server we are filtering for.
+
+ import cProfile
+ import pstats
+ import time
+
+ # we stub out the store, because building up all that state the normal
+ # way is very slow.
+ test_store = _TestStore()
+
+ # our initial state is 100000 membership events and one
+ # history_visibility event.
+ room_state = []
+
+ history_visibility_evt = FrozenEvent({
+ "event_id": "$history_vis",
+ "type": "m.room.history_visibility",
+ "sender": "@resident_user_0:test.com",
+ "state_key": "",
+ "room_id": TEST_ROOM_ID,
+ "content": {"history_visibility": "joined"},
+ })
+ room_state.append(history_visibility_evt)
+ test_store.add_event(history_visibility_evt)
+
+ for i in range(0, 100000):
+ user = "@resident_user_%i:test.com" % (i, )
+ evt = FrozenEvent({
+ "event_id": "$res_event_%i" % (i, ),
+ "type": "m.room.member",
+ "state_key": user,
+ "sender": user,
+ "room_id": TEST_ROOM_ID,
+ "content": {
+ "membership": "join",
+ "extra": "zzz,"
+ },
+ })
+ room_state.append(evt)
+ test_store.add_event(evt)
+
+ events_to_filter = []
+ for i in range(0, 10):
+ user = "@user%i:%s" % (
+ i, "test_server" if i == 5 else "other_server"
+ )
+ evt = FrozenEvent({
+ "event_id": "$evt%i" % (i, ),
+ "type": "m.room.member",
+ "state_key": user,
+ "sender": user,
+ "room_id": TEST_ROOM_ID,
+ "content": {
+ "membership": "join",
+ "extra": "zzz",
+ },
+ })
+ events_to_filter.append(evt)
+ room_state.append(evt)
+
+ test_store.add_event(evt)
+ test_store.set_state_ids_for_event(evt, {
+ (e.type, e.state_key): e.event_id for e in room_state
+ })
+
+ pr = cProfile.Profile()
+ pr.enable()
+
+ logger.info("Starting filtering")
+ start = time.time()
+ filtered = yield filter_events_for_server(
+ test_store, "test_server", events_to_filter,
+ )
+ logger.info("Filtering took %f seconds", time.time() - start)
+
+ pr.disable()
+ with open("filter_events_for_server.profile", "w+") as f:
+ ps = pstats.Stats(pr, stream=f).sort_stats('cumulative')
+ ps.print_stats()
+
+ # the result should be 5 redacted events, and 5 unredacted events.
+ for i in range(0, 5):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertNotIn("extra", filtered[i].content)
+
+ for i in range(5, 10):
+ self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
+ self.assertEqual(filtered[i].content["extra"], "zzz")
+
+ test_large_room.skip = "Disabled by default because it's slow"
+
+
+class _TestStore(object):
+ """Implements a few methods of the DataStore, so that we can test
+ filter_events_for_server
+
+ """
+ def __init__(self):
+ # data for get_events: a map from event_id to event
+ self.events = {}
+
+ # data for get_state_ids_for_events mock: a map from event_id to
+ # a map from (type_state_key) -> event_id for the state at that
+ # event
+ self.state_ids_for_events = {}
+
+ def add_event(self, event):
+ self.events[event.event_id] = event
+
+ def set_state_ids_for_event(self, event, state):
+ self.state_ids_for_events[event.event_id] = state
+
+ def get_state_ids_for_events(self, events, types):
+ res = {}
+ include_memberships = False
+ for (type, state_key) in types:
+ if type == "m.room.history_visibility":
+ continue
+ if type != "m.room.member" or state_key is not None:
+ raise RuntimeError(
+ "Unimplemented: get_state_ids with type (%s, %s)" %
+ (type, state_key),
+ )
+ include_memberships = True
+
+ if include_memberships:
+ for event_id in events:
+ res[event_id] = self.state_ids_for_events[event_id]
+
+ else:
+ k = ("m.room.history_visibility", "")
+ for event_id in events:
+ hve = self.state_ids_for_events[event_id][k]
+ res[event_id] = {k: hve}
+
+ return succeed(res)
+
+ def get_events(self, events):
+ return succeed({
+ event_id: self.events[event_id] for event_id in events
+ })
+
+ def are_users_erased(self, users):
+ return succeed({u: False for u in users})
diff --git a/tests/unittest.py b/tests/unittest.py
index 184fe880f3..b15b06726b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -35,7 +35,10 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace('warning', 'warn')
- self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry)
+ self.tx_log.emit(
+ twisted.logger.LogLevel.levelWithName(log_level),
+ log_entry.replace("{", r"(").replace("}", r")"),
+ )
handler = ToTwistedHandler()
@@ -106,6 +109,17 @@ class TestCase(unittest.TestCase):
except AssertionError as e:
raise (type(e))(e.message + " for '.%s'" % key)
+ def assert_dict(self, required, actual):
+ """Does a partial assert of a dict.
+
+ Args:
+ required (dict): The keys and value which MUST be in 'actual'.
+ actual (dict): The test result. Extra keys will not be checked.
+ """
+ for key in required:
+ self.assertEquals(required[key], actual[key],
+ msg="%s mismatch. %s" % (key, actual))
+
def DEBUG(target):
"""A decorator to set the .loglevel attribute to logging.DEBUG.
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2516fe40f4..8176a7dabd 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -13,20 +13,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from functools import partial
import logging
+from functools import partial
import mock
+
+from twisted.internet import defer, reactor
+
from synapse.api.errors import SynapseError
-from synapse.util import async
from synapse.util import logcontext
-from twisted.internet import defer
from synapse.util.caches import descriptors
+
from tests import unittest
logger = logging.getLogger(__name__)
+def run_on_reactor():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, 0)
+ return logcontext.make_deferred_yieldable(d)
+
+
class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self):
cache = descriptors.Cache("testcache")
@@ -195,7 +203,8 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
- yield async.run_on_reactor()
+ # we want this to behave like an asynchronous function
+ yield run_on_reactor()
raise SynapseError(400, "blah")
return inner_fn()
@@ -205,7 +214,12 @@ class DescriptorTestCase(unittest.TestCase):
with logcontext.LoggingContext() as c1:
c1.name = "c1"
try:
- yield obj.fn(1)
+ d = obj.fn(1)
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(),
+ logcontext.LoggingContext.sentinel,
+ )
+ yield d
self.fail("No exception thrown")
except SynapseError:
pass
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index bc92f85fa6..26f2fa5800 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -14,10 +14,10 @@
# limitations under the License.
-from tests import unittest
-
from synapse.util.caches.dictionary_cache import DictionaryCache
+from tests import unittest
+
class DictCacheTestCase(unittest.TestCase):
@@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase):
seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
@@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_hit_partial"
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
@@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_miss_partial"
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
@@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase):
"test2": "test_simple_cache_hit_miss_partial2",
"test3": "test_simple_cache_hit_miss_partial3",
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
@@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase):
test_value_1 = {
"test": "test_simple_cache_hit_miss_partial",
}
- self.cache.update(seq, key, test_value_1, full=False)
+ self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
seq = self.cache.sequence
test_value_2 = {
"test2": "test_simple_cache_hit_miss_partial2",
}
- self.cache.update(seq, key, test_value_2, full=False)
+ self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
c = self.cache.get(key)
self.assertEqual(
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 31d24adb8b..d12b5e838b 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from .. import unittest
-
from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock
+from .. import unittest
+
class ExpiringCacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index d6e1082779..7ce5f8c258 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -14,15 +14,16 @@
# limitations under the License.
-from twisted.internet import defer, reactor
+import threading
+
from mock import NonCallableMock
+from six import StringIO
+
+from twisted.internet import defer, reactor
from synapse.util.file_consumer import BackgroundFileConsumer
from tests import unittest
-from six import StringIO
-
-import threading
class FileConsumerTests(unittest.TestCase):
@@ -30,7 +31,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_pull_consumer(self):
string_file = StringIO()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = DummyPullProducer()
@@ -54,7 +55,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_consumer(self):
string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=[])
@@ -80,7 +81,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_producer_feedback(self):
string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py
index 9c795d9fdb..a5a767b1ff 100644
--- a/tests/util/test_limiter.py
+++ b/tests/util/test_limiter.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from tests import unittest
-
from twisted.internet import defer
from synapse.util.async import Limiter
+from tests import unittest
+
class LimiterTestCase(unittest.TestCase):
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 4865eb4bc6..c95907b32c 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -12,13 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util import async, logcontext
-from tests import unittest
-from twisted.internet import defer
+from six.moves import range
+
+from twisted.internet import defer, reactor
+from synapse.util import Clock, logcontext
from synapse.util.async import Linearizer
-from six.moves import range
+
+from tests import unittest
class LinearizerTestCase(unittest.TestCase):
@@ -53,7 +55,7 @@ class LinearizerTestCase(unittest.TestCase):
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
if sleep:
- yield async.sleep(0)
+ yield Clock(reactor).sleep(0)
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index ad78d884e0..c54001f7a4 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,12 +1,11 @@
import twisted.python.failure
-from twisted.internet import defer
-from twisted.internet import reactor
-from .. import unittest
+from twisted.internet import defer, reactor
-from synapse.util.async import sleep
-from synapse.util import logcontext
+from synapse.util import Clock, logcontext
from synapse.util.logcontext import LoggingContext
+from .. import unittest
+
class LoggingContextTestCase(unittest.TestCase):
@@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_sleep(self):
+ clock = Clock(reactor)
+
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
competing_context.request = "competing"
- yield sleep(0)
+ yield clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
context_one.request = "one"
- yield sleep(0)
+ yield clock.sleep(0)
self._check_test_key("one")
def _test_run_in_background(self, function):
@@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_run_in_background_with_blocking_fn(self):
@defer.inlineCallbacks
def blocking_function():
- yield sleep(0)
+ yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)
diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py
index 1a1a8412f2..297aebbfbe 100644
--- a/tests/util/test_logformatter.py
+++ b/tests/util/test_logformatter.py
@@ -15,6 +15,7 @@
import sys
from synapse.util.logformatter import LogFormatter
+
from tests import unittest
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index dfb78cb8bd..9b36ef4482 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -14,12 +14,12 @@
# limitations under the License.
-from .. import unittest
+from mock import Mock
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
-from mock import Mock
+from .. import unittest
class LruCacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 1d745ae1a7..24194e3b25 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -14,10 +14,10 @@
# limitations under the License.
-from tests import unittest
-
from synapse.util.async import ReadWriteLock
+from tests import unittest
+
class ReadWriteLockTestCase(unittest.TestCase):
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
index d3a8630c2f..0f5b32fcc0 100644
--- a/tests/util/test_snapshot_cache.py
+++ b/tests/util/test_snapshot_cache.py
@@ -14,10 +14,11 @@
# limitations under the License.
-from .. import unittest
+from twisted.internet.defer import Deferred
from synapse.util.caches.snapshot_cache import SnapshotCache
-from twisted.internet.defer import Deferred
+
+from .. import unittest
class SnapshotCacheTestCase(unittest.TestCase):
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 67ece166c7..65b0f2e6fb 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -1,8 +1,9 @@
-from tests import unittest
from mock import patch
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from tests import unittest
+
class StreamChangeCacheTests(unittest.TestCase):
"""
@@ -140,8 +141,8 @@ class StreamChangeCacheTests(unittest.TestCase):
)
# Query all the entries mid-way through the stream, but include one
- # that doesn't exist in it. We should get back the one that doesn't
- # exist, too.
+ # that doesn't exist in it. We shouldn't get back the one that doesn't
+ # exist.
self.assertEqual(
cache.get_entities_changed(
[
@@ -152,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase):
],
stream_pos=2,
),
- set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
+ set(["bar@baz.net", "user@elsewhere.org"]),
)
# Query all the entries, but before the first known point. We will get
@@ -177,6 +178,22 @@ class StreamChangeCacheTests(unittest.TestCase):
),
)
+ # Query a subset of the entries mid-way through the stream. We should
+ # only get back the subset.
+ self.assertEqual(
+ cache.get_entities_changed(
+ [
+ "bar@baz.net",
+ ],
+ stream_pos=2,
+ ),
+ set(
+ [
+ "bar@baz.net",
+ ]
+ ),
+ )
+
def test_max_pos(self):
"""
StreamChangeCache.get_max_pos_of_last_change will return the most
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 7ab578a185..a5f2261208 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -14,10 +14,10 @@
# limitations under the License.
-from .. import unittest
-
from synapse.util.caches.treecache import TreeCache
+from .. import unittest
+
class TreeCacheTestCase(unittest.TestCase):
def test_get_set_onelevel(self):
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index fdb24a48b0..03201a4d9b 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .. import unittest
-
from synapse.util.wheel_timer import WheelTimer
+from .. import unittest
+
class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self):
diff --git a/tests/utils.py b/tests/utils.py
index 262c4a5714..e488238bb3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,9 +15,10 @@
import hashlib
from inspect import getcallargs
-from six.moves.urllib import parse as urlparse
from mock import Mock, patch
+from six.moves.urllib import parse as urlparse
+
from twisted.internet import defer, reactor
from synapse.api.errors import CodeMessageException, cs_error
@@ -37,11 +38,15 @@ USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks
-def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
+def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
+ **kargs):
"""Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor. If no datastore is supplied a
datastore backed by an in-memory sqlite db will be given to the HS.
"""
+ if reactor is None:
+ from twisted.internet import reactor
+
if config is None:
config = Mock()
config.signing_key = [MockKey()]
@@ -60,6 +65,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.federation_domain_whitelist = None
config.federation_rc_reject_limit = 10
config.federation_rc_sleep_limit = 10
+ config.federation_rc_sleep_delay = 100
config.federation_rc_concurrent = 10
config.filter_timeline_limit = 5000
config.user_directory_search_all_users = False
@@ -110,6 +116,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
+ reactor=reactor,
**kargs
)
db_conn = hs.get_db_conn()
|