summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2016-03-03 19:05:54 +0000
committerRichard van der Hoff <richard@matrix.org>2016-03-03 19:05:54 +0000
commita85179aff3bf2bc1b132e9918cd8222a61a8bcc2 (patch)
treeaac1ff9e5bdca7ddc0fdc9d2cae5a5dcf13e8733 /tests
parentEmpty commit (diff)
parentMerge pull request #621 from matrix-org/daniel/ratelimiting (diff)
downloadsynapse-a85179aff3bf2bc1b132e9918cd8222a61a8bcc2.tar.xz
Merge remote-tracking branch 'origin/develop' into rav/SYN-642
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_profile.py16
-rw-r--r--tests/replication/test_resource.py17
-rw-r--r--tests/rest/client/v1/test_profile.py4
-rw-r--r--tests/storage/test_base.py3
-rw-r--r--tests/utils.py25
5 files changed, 45 insertions, 20 deletions
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a87703bbfd..4f2c14e4ff 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -23,7 +23,7 @@ from synapse.api.errors import AuthError
 from synapse.handlers.profile import ProfileHandler
 from synapse.types import UserID
 
-from tests.utils import setup_test_homeserver
+from tests.utils import setup_test_homeserver, requester_for_user
 
 
 class ProfileHandlers(object):
@@ -84,7 +84,11 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_set_my_name(self):
-        yield self.handler.set_displayname(self.frank, self.frank, "Frank Jr.")
+        yield self.handler.set_displayname(
+            self.frank,
+            requester_for_user(self.frank),
+            "Frank Jr."
+        )
 
         self.assertEquals(
             (yield self.store.get_profile_displayname(self.frank.localpart)),
@@ -93,7 +97,11 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_set_my_name_noauth(self):
-        d = self.handler.set_displayname(self.frank, self.bob, "Frank Jr.")
+        d = self.handler.set_displayname(
+            self.frank,
+            requester_for_user(self.bob),
+            "Frank Jr."
+        )
 
         yield self.assertFailure(d, AuthError)
 
@@ -136,7 +144,7 @@ class ProfileTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_set_my_avatar(self):
         yield self.handler.set_avatar_url(
-            self.frank, self.frank, "http://my.server/pic.gif"
+            self.frank, requester_for_user(self.frank), "http://my.server/pic.gif"
         )
 
         self.assertEquals(
diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py
index 38daaf87e2..daabc563b4 100644
--- a/tests/replication/test_resource.py
+++ b/tests/replication/test_resource.py
@@ -18,7 +18,7 @@ from synapse.types import Requester, UserID
 
 from twisted.internet import defer
 from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import setup_test_homeserver, requester_for_user
 from mock import Mock, NonCallableMock
 import json
 import contextlib
@@ -133,12 +133,15 @@ class ReplicationResourceCase(unittest.TestCase):
     @defer.inlineCallbacks
     def send_text_message(self, room_id, message):
         handler = self.hs.get_handlers().message_handler
-        event = yield handler.create_and_send_nonmember_event({
-            "type": "m.room.message",
-            "content": {"body": "message", "msgtype": "m.text"},
-            "room_id": room_id,
-            "sender": self.user.to_string(),
-        })
+        event = yield handler.create_and_send_nonmember_event(
+            requester_for_user(self.user),
+            {
+                "type": "m.room.message",
+                "content": {"body": "message", "msgtype": "m.text"},
+                "room_id": room_id,
+                "sender": self.user.to_string(),
+            }
+        )
         defer.returnValue(event.event_id)
 
     @defer.inlineCallbacks
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 0785965de2..1d210f9bf8 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -86,7 +86,7 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals(200, code)
         self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
-        self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD")
+        self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
         self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
 
     @defer.inlineCallbacks
@@ -155,5 +155,5 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals(200, code)
         self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
-        self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD")
+        self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
         self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index c76545be65..2e33beb07c 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -48,11 +48,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
 
         config = Mock()
         config.event_cache_size = 1
+        config.database_config = {"name": "sqlite3"}
         hs = HomeServer(
             "test",
             db_pool=self.db_pool,
             config=config,
-            database_engine=create_engine("sqlite3"),
+            database_engine=create_engine(config),
         )
 
         self.datastore = SQLBaseStore(hs)
diff --git a/tests/utils.py b/tests/utils.py
index dfbee5c23a..291b549053 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,6 +20,7 @@ from synapse.storage.prepare_database import prepare_database
 from synapse.storage.engines import create_engine
 from synapse.server import HomeServer
 from synapse.federation.transport import server
+from synapse.types import Requester
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 from synapse.util.logcontext import LoggingContext
@@ -51,6 +52,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.server_name = "server.under.test"
         config.trusted_third_party_id_servers = []
 
+    config.database_config = {"name": "sqlite3"}
+
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
@@ -60,7 +63,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         hs = HomeServer(
             name, db_pool=db_pool, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine("sqlite3"),
+            database_engine=create_engine(config),
             get_db_conn=db_pool.get_db_conn,
             **kargs
         )
@@ -69,7 +72,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine("sqlite3"),
+            database_engine=create_engine(config),
             **kargs
         )
 
@@ -278,18 +281,24 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
             cp_max=1,
         )
 
+        self.config = Mock()
+        self.config.database_config = {"name": "sqlite3"}
+
     def prepare(self):
-        engine = create_engine("sqlite3")
+        engine = self.create_engine()
         return self.runWithConnection(
-            lambda conn: prepare_database(conn, engine)
+            lambda conn: prepare_database(conn, engine, self.config)
         )
 
     def get_db_conn(self):
         conn = self.connect()
-        engine = create_engine("sqlite3")
-        prepare_database(conn, engine)
+        engine = self.create_engine()
+        prepare_database(conn, engine, self.config)
         return conn
 
+    def create_engine(self):
+        return create_engine(self.config)
+
 
 class MemoryDataStore(object):
 
@@ -502,3 +511,7 @@ class DeferredMockCallable(object):
                     "call(%s)" % _format_call(c[0], c[1]) for c in calls
                 ])
             )
+
+
+def requester_for_user(user):
+    return Requester(user, None, False)