diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py
index a9991e9c94..0cee196851 100644
--- a/synapse/api/events/__init__.py
+++ b/synapse/api/events/__init__.py
@@ -156,7 +156,8 @@ class SynapseEvent(JsonEncodedObject):
return "Missing %s key" % key
if type(content[key]) != type(template[key]):
- return "Key %s is of the wrong type." % key
+ return "Key %s is of the wrong type (got %s, want %s)" % (
+ key, type(content[key]), type(template[key]))
if type(content[key]) == dict:
# we must go deeper
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index d675d8c8f9..2f1b954902 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import read_schema
+from synapse.storage import prepare_database
from synapse.server import HomeServer
@@ -36,30 +36,14 @@ from daemonize import Daemonize
import twisted.manhole.telnet
import logging
-import sqlite3
import os
import re
import sys
+import sqlite3
logger = logging.getLogger(__name__)
-SCHEMAS = [
- "transactions",
- "pdu",
- "users",
- "profiles",
- "presence",
- "im",
- "room_aliases",
-]
-
-
-# Remember to update this number every time an incompatible change is made to
-# database schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 3
-
-
class SynapseHomeServer(HomeServer):
def build_http_client(self):
@@ -80,52 +64,12 @@ class SynapseHomeServer(HomeServer):
)
def build_db_pool(self):
- """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
- don't have to worry about overwriting existing content.
- """
- logging.info("Preparing database: %s...", self.db_name)
-
- with sqlite3.connect(self.db_name) as db_conn:
- c = db_conn.cursor()
- c.execute("PRAGMA user_version")
- row = c.fetchone()
-
- if row and row[0]:
- user_version = row[0]
-
- if user_version > SCHEMA_VERSION:
- raise ValueError("Cannot use this database as it is too " +
- "new for the server to understand"
- )
- elif user_version < SCHEMA_VERSION:
- logging.info("Upgrading database from version %d",
- user_version
- )
-
- # Run every version since after the current version.
- for v in range(user_version + 1, SCHEMA_VERSION + 1):
- sql_script = read_schema("delta/v%d" % (v))
- c.executescript(sql_script)
-
- db_conn.commit()
-
- else:
- for sql_loc in SCHEMAS:
- sql_script = read_schema(sql_loc)
-
- c.executescript(sql_script)
- db_conn.commit()
- c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
-
- c.close()
-
- logging.info("Database prepared in %s.", self.db_name)
-
- pool = adbapi.ConnectionPool(
- 'sqlite3', self.db_name, check_same_thread=False,
- cp_min=1, cp_max=1)
-
- return pool
+ return adbapi.ConnectionPool(
+ "sqlite3", self.get_db_name(),
+ check_same_thread=False,
+ cp_min=1,
+ cp_max=1
+ )
def create_resource_tree(self, web_client, redirect_root_to_web_client):
"""Create the resource tree for this Home Server.
@@ -230,10 +174,6 @@ class SynapseHomeServer(HomeServer):
logger.info("Synapse now listening on port %d", unsecure_port)
-def run():
- reactor.run()
-
-
def setup():
config = HomeServerConfig.load_config(
"Synapse Homeserver",
@@ -268,7 +208,15 @@ def setup():
web_client=config.webclient,
redirect_root_to_web_client=True,
)
- hs.start_listening(config.bind_port, config.unsecure_port)
+
+ db_name = hs.get_db_name()
+
+ logging.info("Preparing database: %s...", db_name)
+
+ with sqlite3.connect(db_name) as db_conn:
+ prepare_database(db_conn)
+
+ logging.info("Database prepared in %s.", db_name)
hs.get_db_pool()
@@ -279,12 +227,14 @@ def setup():
f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
+ hs.start_listening(config.bind_port, config.unsecure_port)
+
if config.daemonize:
print config.pid_file
daemon = Daemonize(
app="synapse-homeserver",
pid=config.pid_file,
- action=run,
+ action=reactor.run,
auto_close_fds=False,
verbose=True,
logger=logger,
@@ -292,7 +242,7 @@ def setup():
daemon.start()
else:
- run()
+ reactor.run()
if __name__ == '__main__':
diff --git a/synapse/server.py b/synapse/server.py
index 7c185537aa..cdea49e6ab 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -58,6 +58,7 @@ class BaseHomeServer(object):
DEPENDENCIES = [
'clock',
'http_client',
+ 'db_name',
'db_pool',
'persistence_service',
'replication_layer',
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 1cede2809d..66658f6721 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -47,6 +47,23 @@ import os
logger = logging.getLogger(__name__)
+
+SCHEMAS = [
+ "transactions",
+ "pdu",
+ "users",
+ "profiles",
+ "presence",
+ "im",
+ "room_aliases",
+]
+
+
+# Remember to update this number every time an incompatible change is made to
+# database schema files, so the users will be informed on server restarts.
+SCHEMA_VERSION = 3
+
+
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
something went wrong.
@@ -78,7 +95,7 @@ class DataStore(RoomMemberStore, RoomStore,
stream_ordering = self.min_token
try:
- yield self._db_pool.runInteraction(
+ yield self.runInteraction(
self._persist_pdu_event_txn,
pdu=pdu,
event=event,
@@ -291,7 +308,7 @@ class DataStore(RoomMemberStore, RoomStore,
prev_state_pdu=prev_state_pdu,
)
- return self._db_pool.runInteraction(_snapshot)
+ return self.runInteraction(_snapshot)
class Snapshot(object):
@@ -361,3 +378,42 @@ def read_schema(schema):
"""
with open(schema_path(schema)) as schema_file:
return schema_file.read()
+
+
+def prepare_database(db_conn):
+ """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
+ don't have to worry about overwriting existing content.
+ """
+ c = db_conn.cursor()
+ c.execute("PRAGMA user_version")
+ row = c.fetchone()
+
+ if row and row[0]:
+ user_version = row[0]
+
+ if user_version > SCHEMA_VERSION:
+ raise ValueError("Cannot use this database as it is too " +
+ "new for the server to understand"
+ )
+ elif user_version < SCHEMA_VERSION:
+ logging.info("Upgrading database from version %d",
+ user_version
+ )
+
+ # Run every version since after the current version.
+ for v in range(user_version + 1, SCHEMA_VERSION + 1):
+ sql_script = read_schema("delta/v%d" % (v))
+ c.executescript(sql_script)
+
+ db_conn.commit()
+
+ else:
+ for sql_loc in SCHEMAS:
+ sql_script = read_schema(sql_loc)
+
+ c.executescript(sql_script)
+ db_conn.commit()
+ c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
+
+ c.close()
+
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index cf88bfc22b..76ed7d06fb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -26,6 +26,44 @@ import json
logger = logging.getLogger(__name__)
+sql_logger = logging.getLogger("synapse.storage.SQL")
+
+
+class LoggingTransaction(object):
+ """An object that almost-transparently proxies for the 'txn' object
+ passed to the constructor. Adds logging to the .execute() method."""
+ __slots__ = ["txn"]
+
+ def __init__(self, txn):
+ object.__setattr__(self, "txn", txn)
+
+ def __getattribute__(self, name):
+ if name == "execute":
+ return object.__getattribute__(self, "execute")
+
+ return getattr(object.__getattribute__(self, "txn"), name)
+
+ def __setattr__(self, name, value):
+ setattr(object.__getattribute__(self, "txn"), name, value)
+
+ def execute(self, sql, *args, **kwargs):
+ # TODO(paul): Maybe use 'info' and 'debug' for values?
+ sql_logger.debug("[SQL] %s", sql)
+ try:
+ if args and args[0]:
+ values = args[0]
+ sql_logger.debug("[SQL values] " +
+ ", ".join(("<%s>",) * len(values)), *values)
+ except:
+ # Don't let logging failures stop SQL from working
+ pass
+
+ # TODO(paul): Here would be an excellent place to put some timing
+ # measurements, and log (warning?) slow queries.
+ return object.__getattribute__(self, "txn").execute(
+ sql, *args, **kwargs
+ )
+
class SQLBaseStore(object):
@@ -35,6 +73,13 @@ class SQLBaseStore(object):
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
+ def runInteraction(self, func, *args, **kwargs):
+ """Wraps the .runInteraction() method on the underlying db_pool."""
+ def inner_func(txn, *args, **kwargs):
+ return func(LoggingTransaction(txn), *args, **kwargs)
+
+ return self._db_pool.runInteraction(inner_func, *args, **kwargs)
+
def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts.
@@ -60,11 +105,6 @@ class SQLBaseStore(object):
Returns:
The result of decoder(results)
"""
- logger.debug(
- "[SQL] %s Args=%s Func=%s",
- query, args, decoder.__name__ if decoder else None
- )
-
def interaction(txn):
cursor = txn.execute(query, args)
if decoder:
@@ -72,7 +112,7 @@ class SQLBaseStore(object):
else:
return cursor.fetchall()
- return self._db_pool.runInteraction(interaction)
+ return self.runInteraction(interaction)
def _execute_and_decode(self, query, *args):
return self._execute(self.cursor_to_dict, query, *args)
@@ -88,7 +128,7 @@ class SQLBaseStore(object):
values : dict of new column names and values for them
or_replace : bool; if True performs an INSERT OR REPLACE
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._simple_insert_txn, table, values, or_replace=or_replace
)
@@ -172,7 +212,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values())
return txn.fetchall()
- res = yield self._db_pool.runInteraction(func)
+ res = yield self.runInteraction(func)
defer.returnValue([r[0] for r in res])
@@ -195,7 +235,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
- return self._db_pool.runInteraction(func)
+ return self.runInteraction(func)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
@@ -263,7 +303,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched")
return ret
- return self._db_pool.runInteraction(func)
+ return self.runInteraction(func)
def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
@@ -284,7 +324,7 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
- return self._db_pool.runInteraction(func)
+ return self.runInteraction(func)
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
@@ -302,7 +342,7 @@ class SQLBaseStore(object):
return 0
return max_id
- return self._db_pool.runInteraction(func)
+ return self.runInteraction(func)
def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items() if v})
@@ -325,7 +365,7 @@ class SQLBaseStore(object):
)
def _parse_events(self, rows):
- return self._db_pool.runInteraction(self._parse_events_txn, rows)
+ return self.runInteraction(self._parse_events_txn, rows)
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
index 3c859fdeac..d70467dcd6 100644
--- a/synapse/storage/pdu.py
+++ b/synapse/storage/pdu.py
@@ -43,7 +43,7 @@ class PduStore(SQLBaseStore):
PduTuple: If the pdu does not exist in the database, returns None
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_pdu_tuple, pdu_id, origin
)
@@ -95,7 +95,7 @@ class PduStore(SQLBaseStore):
list: A list of PduTuples
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_current_state_for_context,
context
)
@@ -143,7 +143,7 @@ class PduStore(SQLBaseStore):
pdu_origin (str)
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._mark_as_processed, pdu_id, pdu_origin
)
@@ -152,7 +152,7 @@ class PduStore(SQLBaseStore):
def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context."""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_all_pdus_from_context, context,
)
@@ -179,7 +179,7 @@ class PduStore(SQLBaseStore):
Return:
list: A list of PduTuples
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_backfill, context, pdu_list, limit
)
@@ -240,7 +240,7 @@ class PduStore(SQLBaseStore):
txn
context (str)
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_min_depth_for_context, context
)
@@ -346,7 +346,7 @@ class PduStore(SQLBaseStore):
bool
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._is_pdu_new,
pdu_id=pdu_id,
origin=origin,
@@ -499,7 +499,7 @@ class StatePduStore(SQLBaseStore):
)
def get_unresolved_state_tree(self, new_state_pdu):
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_unresolved_state_tree, new_state_pdu
)
@@ -538,7 +538,7 @@ class StatePduStore(SQLBaseStore):
def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key):
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._update_current_state,
pdu_id, origin, context, pdu_type, state_key
)
@@ -577,7 +577,7 @@ class StatePduStore(SQLBaseStore):
PduEntry
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_current_state_pdu, context, pdu_type, state_key
)
@@ -636,7 +636,7 @@ class StatePduStore(SQLBaseStore):
Returns:
bool: True if the new_pdu clobbered the current state, False if not
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._handle_new_state, new_pdu
)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index fd762bc643..db20b1daa0 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if the user_id could not be registered.
"""
- yield self._db_pool.runInteraction(self._register, user_id, token,
+ yield self.runInteraction(self._register, user_id, token,
password_hash)
def _register(self, txn, user_id, token, password_hash):
@@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if no user was found.
"""
- user_id = yield self._db_pool.runInteraction(self._query_for_auth,
+ user_id = yield self.runInteraction(self._query_for_auth,
token)
defer.returnValue(user_id)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 017169ce00..5adf8cdf1b 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -149,7 +149,7 @@ class RoomStore(SQLBaseStore):
defer.returnValue(None)
def get_power_level(self, room_id, user_id):
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_power_level,
room_id, user_id,
)
@@ -182,7 +182,7 @@ class RoomStore(SQLBaseStore):
return None
def get_ops_levels(self, room_id):
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_ops_levels,
room_id,
)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 676b2f2653..04b4067d03 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -149,7 +149,7 @@ class RoomMemberStore(SQLBaseStore):
membership_list (list): A list of synapse.api.constants.Membership
values which the user must be in.
Returns:
- A list of dicts with "room_id" and "membership" keys.
+ A list of RoomMemberEvent objects
"""
if not membership_list:
return defer.succeed(None)
@@ -198,10 +198,11 @@ class RoomMemberStore(SQLBaseStore):
return results
@defer.inlineCallbacks
- def user_rooms_intersect(self, user_list):
- """ Checks whether a list of users share a room.
+ def user_rooms_intersect(self, user_id_list):
+ """ Checks whether all the users whose IDs are given in a list share a
+ room.
"""
- user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_list))
+ user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list))
sql = (
"SELECT m.room_id FROM room_memberships as m "
"INNER JOIN current_state_events as c "
@@ -211,8 +212,8 @@ class RoomMemberStore(SQLBaseStore):
"GROUP BY m.room_id HAVING COUNT(m.room_id) = ?"
) % {"clause": user_list_clause}
- args = user_list
- args.append(len(user_list))
+ args = list(user_id_list)
+ args.append(len(user_id_list))
rows = yield self._execute(None, sql, *args)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index aff6dc9855..8c766b8a00 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -286,7 +286,7 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
def get_room_events_max_id(self):
- return self._db_pool.runInteraction(self._get_room_events_max_id_txn)
+ return self.runInteraction(self._get_room_events_max_id_txn)
def _get_room_events_max_id_txn(self, txn):
txn.execute(
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 7467e1035b..ab4599b468 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -41,7 +41,7 @@ class TransactionStore(SQLBaseStore):
this transaction or a 2-tuple of (int, dict)
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_received_txn_response, transaction_id, origin
)
@@ -72,7 +72,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._set_received_txn_response,
transaction_id, origin, code, response_dict
)
@@ -104,7 +104,7 @@ class TransactionStore(SQLBaseStore):
list: A list of previous transaction ids.
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._prep_send_transaction,
transaction_id, destination, ts, pdu_list
)
@@ -159,7 +159,7 @@ class TransactionStore(SQLBaseStore):
code (int)
response_json (str)
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._delivered_txn,
transaction_id, destination, code, response_dict
)
@@ -184,7 +184,7 @@ class TransactionStore(SQLBaseStore):
Returns:
list: A list of `ReceivedTransactionsTable.EntryType`
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_transactions_after, transaction_id, destination
)
@@ -214,7 +214,7 @@ class TransactionStore(SQLBaseStore):
Returns
list: A list of PduTuple
"""
- return self._db_pool.runInteraction(
+ return self.runInteraction(
self._get_pdus_after_transaction,
transaction_id, destination
)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 54d6e51f97..dd5d85dde6 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -24,6 +24,8 @@ from synapse.http.client import HttpClient
from synapse.handlers.directory import DirectoryHandler
from synapse.storage.directory import RoomAliasMapping
+from tests.utils import SQLiteMemoryDbPool
+
class DirectoryHandlers(object):
def __init__(self, hs):
@@ -33,6 +35,7 @@ class DirectoryHandlers(object):
class DirectoryTestCase(unittest.TestCase):
""" Tests the directory service. """
+ @defer.inlineCallbacks
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
@@ -43,11 +46,11 @@ class DirectoryTestCase(unittest.TestCase):
self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
hs = HomeServer("test",
- datastore=Mock(spec=[
- "get_association_from_room_alias",
- "get_joined_hosts_for_room",
- ]),
+ db_pool=db_pool,
http_client=None,
resource_for_federation=Mock(),
replication_layer=self.mock_federation,
@@ -56,20 +59,16 @@ class DirectoryTestCase(unittest.TestCase):
self.handler = hs.get_handlers().directory_handler
- self.datastore = hs.get_datastore()
-
- def hosts(room_id):
- return defer.succeed([])
- self.datastore.get_joined_hosts_for_room.side_effect = hosts
+ self.store = hs.get_datastore()
self.my_room = hs.parse_roomalias("#my-room:test")
+ self.your_room = hs.parse_roomalias("#your-room:test")
self.remote_room = hs.parse_roomalias("#another:remote")
@defer.inlineCallbacks
def test_get_local_association(self):
- mocked_get = self.datastore.get_association_from_room_alias
- mocked_get.return_value = defer.succeed(
- RoomAliasMapping("!8765qwer:test", "#my-room:test", ["test"])
+ yield self.store.create_room_alias_association(
+ self.my_room, "!8765qwer:test", ["test"]
)
result = yield self.handler.get_association(self.my_room)
@@ -102,9 +101,8 @@ class DirectoryTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- mocked_get = self.datastore.get_association_from_room_alias
- mocked_get.return_value = defer.succeed(
- RoomAliasMapping("!8765asdf:test", "#your-room:test", ["test"])
+ yield self.store.create_room_alias_association(
+ self.your_room, "!8765asdf:test", ["test"]
)
response = yield self.query_handlers["directory"](
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 0cb4dfba39..765929d204 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -20,7 +20,9 @@ from twisted.internet import defer, reactor
from mock import Mock, call, ANY
import json
-from ..utils import MockHttpResource, MockClock, DeferredMockCallable
+from tests.utils import (
+ MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool
+)
from synapse.server import HomeServer
from synapse.api.constants import PresenceState
@@ -60,30 +62,21 @@ class JustPresenceHandlers(object):
class PresenceStateTestCase(unittest.TestCase):
""" Tests presence management. """
+ @defer.inlineCallbacks
def setUp(self):
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
hs = HomeServer("test",
- clock=MockClock(),
- db_pool=None,
- datastore=Mock(spec=[
- "get_presence_state",
- "set_presence_state",
- "add_presence_list_pending",
- "set_presence_list_accepted",
- ]),
- handlers=None,
- resource_for_federation=Mock(),
- http_client=None,
- )
+ clock=MockClock(),
+ db_pool=db_pool,
+ handlers=None,
+ resource_for_federation=Mock(),
+ http_client=None,
+ )
hs.handlers = JustPresenceHandlers(hs)
- self.datastore = hs.get_datastore()
-
- def is_presence_visible(observed_localpart, observer_userid):
- allow = (observed_localpart == "apple" and
- observer_userid == "@banana:test"
- )
- return defer.succeed(allow)
- self.datastore.is_presence_visible = is_presence_visible
+ self.store = hs.get_datastore()
# Mock the RoomMemberHandler
room_member_handler = Mock(spec=[])
@@ -94,6 +87,11 @@ class PresenceStateTestCase(unittest.TestCase):
self.u_banana = hs.parse_userid("@banana:test")
self.u_clementine = hs.parse_userid("@clementine:test")
+ yield self.store.create_presence(self.u_apple.localpart)
+ yield self.store.set_presence_state(
+ self.u_apple.localpart, {"state": ONLINE, "status_msg": "Online"}
+ )
+
self.handler = hs.get_handlers().presence_handler
self.room_members = []
@@ -117,7 +115,7 @@ class PresenceStateTestCase(unittest.TestCase):
shared = all(map(lambda i: i in room_member_ids, userlist))
return defer.succeed(shared)
- self.datastore.user_rooms_intersect = user_rooms_intersect
+ self.store.user_rooms_intersect = user_rooms_intersect
self.mock_start = Mock()
self.mock_stop = Mock()
@@ -127,11 +125,6 @@ class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_state(self):
- mocked_get = self.datastore.get_presence_state
- mocked_get.return_value = defer.succeed(
- {"state": ONLINE, "status_msg": "Online"}
- )
-
state = yield self.handler.get_state(
target_user=self.u_apple, auth_user=self.u_apple
)
@@ -140,13 +133,12 @@ class PresenceStateTestCase(unittest.TestCase):
{"presence": ONLINE, "status_msg": "Online"},
state
)
- mocked_get.assert_called_with("apple")
@defer.inlineCallbacks
def test_get_allowed_state(self):
- mocked_get = self.datastore.get_presence_state
- mocked_get.return_value = defer.succeed(
- {"state": ONLINE, "status_msg": "Online"}
+ yield self.store.allow_presence_visible(
+ observed_localpart=self.u_apple.localpart,
+ observer_userid=self.u_banana.to_string(),
)
state = yield self.handler.get_state(
@@ -157,15 +149,9 @@ class PresenceStateTestCase(unittest.TestCase):
{"presence": ONLINE, "status_msg": "Online"},
state
)
- mocked_get.assert_called_with("apple")
@defer.inlineCallbacks
def test_get_same_room_state(self):
- mocked_get = self.datastore.get_presence_state
- mocked_get.return_value = defer.succeed(
- {"state": ONLINE, "status_msg": "Online"}
- )
-
self.room_members = [self.u_apple, self.u_clementine]
state = yield self.handler.get_state(
@@ -179,11 +165,6 @@ class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_disallowed_state(self):
- mocked_get = self.datastore.get_presence_state
- mocked_get.return_value = defer.succeed(
- {"state": ONLINE, "status_msg": "Online"}
- )
-
self.room_members = []
yield self.assertFailure(
@@ -195,16 +176,17 @@ class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_state(self):
- mocked_set = self.datastore.set_presence_state
- mocked_set.return_value = defer.succeed({"state": OFFLINE})
-
yield self.handler.set_state(
target_user=self.u_apple, auth_user=self.u_apple,
state={"presence": UNAVAILABLE, "status_msg": "Away"})
- mocked_set.assert_called_with("apple",
- {"state": UNAVAILABLE, "status_msg": "Away"}
+ self.assertEquals(
+ {"state": UNAVAILABLE,
+ "status_msg": "Away",
+ "mtime": 1000000},
+ (yield self.store.get_presence_state(self.u_apple.localpart))
)
+
self.mock_start.assert_called_with(self.u_apple,
state={
"presence": UNAVAILABLE,
@@ -222,50 +204,34 @@ class PresenceStateTestCase(unittest.TestCase):
class PresenceInvitesTestCase(unittest.TestCase):
""" Tests presence management. """
+ @defer.inlineCallbacks
def setUp(self):
self.mock_http_client = Mock(spec=[])
self.mock_http_client.put_json = DeferredMockCallable()
self.mock_federation_resource = MockHttpResource()
- hs = HomeServer("test",
- clock=MockClock(),
- db_pool=None,
- datastore=Mock(spec=[
- "has_presence_state",
- "allow_presence_visible",
- "add_presence_list_pending",
- "set_presence_list_accepted",
- "get_presence_list",
- "del_presence_list",
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
- # Bits that Federation needs
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- ]),
- handlers=None,
- resource_for_client=Mock(),
- resource_for_federation=self.mock_federation_resource,
- http_client=self.mock_http_client,
- )
+ hs = HomeServer("test",
+ clock=MockClock(),
+ db_pool=db_pool,
+ handlers=None,
+ resource_for_client=Mock(),
+ resource_for_federation=self.mock_federation_resource,
+ http_client=self.mock_http_client,
+ )
hs.handlers = JustPresenceHandlers(hs)
- self.datastore = hs.get_datastore()
-
- def has_presence_state(user_localpart):
- return defer.succeed(
- user_localpart in ("apple", "banana"))
- self.datastore.has_presence_state = has_presence_state
-
- def get_received_txn_response(*args):
- return defer.succeed(None)
- self.datastore.get_received_txn_response = get_received_txn_response
+ self.store = hs.get_datastore()
# Some local users to test with
self.u_apple = hs.parse_userid("@apple:test")
self.u_banana = hs.parse_userid("@banana:test")
+ yield self.store.create_presence(self.u_apple.localpart)
+ yield self.store.create_presence(self.u_banana.localpart)
+
# ID of a local user that does not exist
self.u_durian = hs.parse_userid("@durian:test")
@@ -288,12 +254,16 @@ class PresenceInvitesTestCase(unittest.TestCase):
yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_banana)
- self.datastore.add_presence_list_pending.assert_called_with(
- "apple", "@banana:test")
- self.datastore.allow_presence_visible.assert_called_with(
- "banana", "@apple:test")
- self.datastore.set_presence_list_accepted.assert_called_with(
- "apple", "@banana:test")
+ self.assertEquals(
+ [{"observed_user_id": "@banana:test", "accepted": 1}],
+ (yield self.store.get_presence_list(self.u_apple.localpart))
+ )
+ self.assertTrue(
+ (yield self.store.is_presence_visible(
+ observed_localpart=self.u_banana.localpart,
+ observer_userid=self.u_apple.to_string(),
+ ))
+ )
self.mock_start.assert_called_with(
self.u_apple, target_user=self.u_banana)
@@ -303,10 +273,10 @@ class PresenceInvitesTestCase(unittest.TestCase):
yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_durian)
- self.datastore.add_presence_list_pending.assert_called_with(
- "apple", "@durian:test")
- self.datastore.del_presence_list.assert_called_with(
- "apple", "@durian:test")
+ self.assertEquals(
+ [],
+ (yield self.store.get_presence_list(self.u_apple.localpart))
+ )
@defer.inlineCallbacks
def test_invite_remote(self):
@@ -328,8 +298,10 @@ class PresenceInvitesTestCase(unittest.TestCase):
yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_cabbage)
- self.datastore.add_presence_list_pending.assert_called_with(
- "apple", "@cabbage:elsewhere")
+ self.assertEquals(
+ [{"observed_user_id": "@cabbage:elsewhere", "accepted": 0}],
+ (yield self.store.get_presence_list(self.u_apple.localpart))
+ )
yield put_json.await_calls()
@@ -362,8 +334,12 @@ class PresenceInvitesTestCase(unittest.TestCase):
)
)
- self.datastore.allow_presence_visible.assert_called_with(
- "apple", "@cabbage:elsewhere")
+ self.assertTrue(
+ (yield self.store.is_presence_visible(
+ observed_localpart=self.u_apple.localpart,
+ observer_userid=self.u_cabbage.to_string(),
+ ))
+ )
yield put_json.await_calls()
@@ -398,6 +374,11 @@ class PresenceInvitesTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_accepted_remote(self):
+ yield self.store.add_presence_list_pending(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_cabbage.to_string(),
+ )
+
yield self.mock_federation_resource.trigger("PUT",
"/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_accept",
@@ -408,14 +389,21 @@ class PresenceInvitesTestCase(unittest.TestCase):
)
)
- self.datastore.set_presence_list_accepted.assert_called_with(
- "apple", "@cabbage:elsewhere")
+ self.assertEquals(
+ [{"observed_user_id": "@cabbage:elsewhere", "accepted": 1}],
+ (yield self.store.get_presence_list(self.u_apple.localpart))
+ )
self.mock_start.assert_called_with(
self.u_apple, target_user=self.u_cabbage)
@defer.inlineCallbacks
def test_denied_remote(self):
+ yield self.store.add_presence_list_pending(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid="@eggplant:elsewhere",
+ )
+
yield self.mock_federation_resource.trigger("PUT",
"/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_deny",
@@ -426,62 +414,76 @@ class PresenceInvitesTestCase(unittest.TestCase):
)
)
- self.datastore.del_presence_list.assert_called_with(
- "apple", "@eggplant:elsewhere")
+ self.assertEquals(
+ [],
+ (yield self.store.get_presence_list(self.u_apple.localpart))
+ )
@defer.inlineCallbacks
def test_drop_local(self):
+ yield self.store.add_presence_list_pending(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_banana.to_string(),
+ )
+ yield self.store.set_presence_list_accepted(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_banana.to_string(),
+ )
+
yield self.handler.drop(
- observer_user=self.u_apple, observed_user=self.u_banana)
+ observer_user=self.u_apple,
+ observed_user=self.u_banana,
+ )
- self.datastore.del_presence_list.assert_called_with(
- "apple", "@banana:test")
+ self.assertEquals(
+ [],
+ (yield self.store.get_presence_list(self.u_apple.localpart))
+ )
self.mock_stop.assert_called_with(
self.u_apple, target_user=self.u_banana)
@defer.inlineCallbacks
def test_drop_remote(self):
+ yield self.store.add_presence_list_pending(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_cabbage.to_string(),
+ )
+ yield self.store.set_presence_list_accepted(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_cabbage.to_string(),
+ )
+
yield self.handler.drop(
- observer_user=self.u_apple, observed_user=self.u_cabbage)
+ observer_user=self.u_apple,
+ observed_user=self.u_cabbage,
+ )
- self.datastore.del_presence_list.assert_called_with(
- "apple", "@cabbage:elsewhere")
+ self.assertEquals(
+ [],
+ (yield self.store.get_presence_list(self.u_apple.localpart))
+ )
@defer.inlineCallbacks
def test_get_presence_list(self):
- self.datastore.get_presence_list.return_value = defer.succeed(
- [{"observed_user_id": "@banana:test"}]
- )
-
- presence = yield self.handler.get_presence_list(
- observer_user=self.u_apple)
-
- self.assertEquals([
- {"observed_user": self.u_banana,
- "presence": OFFLINE},
- ], presence)
-
- self.datastore.get_presence_list.assert_called_with("apple",
- accepted=None
+ yield self.store.add_presence_list_pending(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_banana.to_string(),
)
-
- self.datastore.get_presence_list.return_value = defer.succeed(
- [{"observed_user_id": "@banana:test"}]
+ yield self.store.set_presence_list_accepted(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_banana.to_string(),
)
presence = yield self.handler.get_presence_list(
- observer_user=self.u_apple, accepted=True
- )
+ observer_user=self.u_apple)
self.assertEquals([
{"observed_user": self.u_banana,
- "presence": OFFLINE},
+ "presence": OFFLINE,
+ "accepted": 1},
], presence)
- self.datastore.get_presence_list.assert_called_with("apple",
- accepted=True)
-
class PresencePushTestCase(unittest.TestCase):
""" Tests steady-state presence status updates.
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index ee2be9b6d5..5dc9b456e1 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,6 +24,8 @@ from synapse.server import HomeServer
from synapse.handlers.profile import ProfileHandler
from synapse.api.constants import Membership
+from tests.utils import SQLiteMemoryDbPool
+
class ProfileHandlers(object):
def __init__(self, hs):
@@ -33,6 +35,7 @@ class ProfileHandlers(object):
class ProfileTestCase(unittest.TestCase):
""" Tests profile management. """
+ @defer.inlineCallbacks
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
@@ -43,63 +46,50 @@ class ProfileTestCase(unittest.TestCase):
self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
hs = HomeServer("test",
- db_pool=None,
+ db_pool=db_pool,
http_client=None,
- datastore=Mock(spec=[
- "get_profile_displayname",
- "set_profile_displayname",
- "get_profile_avatar_url",
- "set_profile_avatar_url",
- "get_rooms_for_user_where_membership_is",
- ]),
handlers=None,
resource_for_federation=Mock(),
replication_layer=self.mock_federation,
)
hs.handlers = ProfileHandlers(hs)
- self.datastore = hs.get_datastore()
+ self.store = hs.get_datastore()
self.frank = hs.parse_userid("@1234ABCD:test")
self.bob = hs.parse_userid("@4567:test")
self.alice = hs.parse_userid("@alice:remote")
- self.handler = hs.get_handlers().profile_handler
+ yield self.store.create_profile(self.frank.localpart)
- self.mock_get_joined = (
- self.datastore.get_rooms_for_user_where_membership_is
- )
+ self.handler = hs.get_handlers().profile_handler
# TODO(paul): Icky signal declarings.. booo
hs.get_distributor().declare("changed_presencelike_data")
@defer.inlineCallbacks
def test_get_my_name(self):
- mocked_get = self.datastore.get_profile_displayname
- mocked_get.return_value = defer.succeed("Frank")
+ yield self.store.set_profile_displayname(
+ self.frank.localpart, "Frank"
+ )
displayname = yield self.handler.get_displayname(self.frank)
self.assertEquals("Frank", displayname)
- mocked_get.assert_called_with("1234ABCD")
@defer.inlineCallbacks
def test_set_my_name(self):
- mocked_set = self.datastore.set_profile_displayname
- mocked_set.return_value = defer.succeed(())
-
- self.mock_get_joined.return_value = defer.succeed([])
-
yield self.handler.set_displayname(self.frank, self.frank, "Frank Jr.")
- self.mock_get_joined.assert_called_once_with(
- self.frank.to_string(),
- [Membership.JOIN]
+ self.assertEquals(
+ (yield self.store.get_profile_displayname(self.frank.localpart)),
+ "Frank Jr."
)
- mocked_set.assert_called_with("1234ABCD", "Frank Jr.")
-
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
d = self.handler.set_displayname(self.frank, self.bob, "Frank Jr.")
@@ -123,40 +113,31 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- mocked_get = self.datastore.get_profile_displayname
- mocked_get.return_value = defer.succeed("Caroline")
+ yield self.store.create_profile("caroline")
+ yield self.store.set_profile_displayname("caroline", "Caroline")
response = yield self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
)
self.assertEquals({"displayname": "Caroline"}, response)
- mocked_get.assert_called_with("caroline")
@defer.inlineCallbacks
def test_get_my_avatar(self):
- mocked_get = self.datastore.get_profile_avatar_url
- mocked_get.return_value = defer.succeed("http://my.server/me.png")
+ yield self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
avatar_url = yield self.handler.get_avatar_url(self.frank)
self.assertEquals("http://my.server/me.png", avatar_url)
- mocked_get.assert_called_with("1234ABCD")
@defer.inlineCallbacks
def test_set_my_avatar(self):
- mocked_set = self.datastore.set_profile_avatar_url
- mocked_set.return_value = defer.succeed(())
-
- self.mock_get_joined.return_value = defer.succeed([])
-
yield self.handler.set_avatar_url(self.frank, self.frank,
"http://my.server/pic.gif")
- self.mock_get_joined.assert_called_once_with(
- self.frank.to_string(),
- [Membership.JOIN]
+ self.assertEquals(
+ (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ "http://my.server/pic.gif"
)
-
-
- mocked_set.assert_called_with("1234ABCD", "http://my.server/pic.gif")
diff --git a/tests/storage/TESTS_NEEDED_FOR b/tests/storage/TESTS_NEEDED_FOR
new file mode 100644
index 0000000000..8e5d0cbdc4
--- /dev/null
+++ b/tests/storage/TESTS_NEEDED_FOR
@@ -0,0 +1,5 @@
+synapse/storage/feedback.py
+synapse/storage/keys.py
+synapse/storage/pdu.py
+synapse/storage/stream.py
+synapse/storage/transactions.py
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
new file mode 100644
index 0000000000..7e8e7e1e83
--- /dev/null
+++ b/tests/storage/test_directory.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+from twisted.internet import defer
+
+from synapse.server import HomeServer
+from synapse.storage.directory import DirectoryStore
+
+from tests.utils import SQLiteMemoryDbPool
+
+
+class DirectoryStoreTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
+ hs = HomeServer("test",
+ db_pool=db_pool,
+ )
+
+ self.store = DirectoryStore(hs)
+
+ self.room = hs.parse_roomid("!abcde:test")
+ self.alias = hs.parse_roomalias("#my-room:test")
+
+ @defer.inlineCallbacks
+ def test_room_to_alias(self):
+ yield self.store.create_room_alias_association(
+ room_alias=self.alias,
+ room_id=self.room.to_string(),
+ servers=["test"],
+ )
+
+ self.assertEquals(
+ ["#my-room:test"],
+ (yield self.store.get_aliases_for_room(self.room.to_string()))
+ )
+
+ @defer.inlineCallbacks
+ def test_alias_to_room(self):
+ yield self.store.create_room_alias_association(
+ room_alias=self.alias,
+ room_id=self.room.to_string(),
+ servers=["test"],
+ )
+
+
+ self.assertObjectHasAttributes(
+ {"room_id": self.room.to_string(),
+ "servers": ["test"]},
+ (yield self.store.get_association_from_room_alias(self.alias))
+ )
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
new file mode 100644
index 0000000000..9655d3cf42
--- /dev/null
+++ b/tests/storage/test_presence.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+from twisted.internet import defer
+
+from synapse.server import HomeServer
+from synapse.storage.presence import PresenceStore
+
+from tests.utils import SQLiteMemoryDbPool, MockClock
+
+
+class PresenceStoreTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
+ hs = HomeServer("test",
+ clock=MockClock(),
+ db_pool=db_pool,
+ )
+
+ self.store = PresenceStore(hs)
+
+ self.u_apple = hs.parse_userid("@apple:test")
+ self.u_banana = hs.parse_userid("@banana:test")
+
+ @defer.inlineCallbacks
+ def test_state(self):
+ yield self.store.create_presence(
+ self.u_apple.localpart
+ )
+
+ state = yield self.store.get_presence_state(
+ self.u_apple.localpart
+ )
+
+ self.assertEquals(
+ {"state": None, "status_msg": None, "mtime": None}, state
+ )
+
+ yield self.store.set_presence_state(
+ self.u_apple.localpart, {"state": "online", "status_msg": "Here"}
+ )
+
+ state = yield self.store.get_presence_state(
+ self.u_apple.localpart
+ )
+
+ self.assertEquals(
+ {"state": "online", "status_msg": "Here", "mtime": 1000000}, state
+ )
+
+ @defer.inlineCallbacks
+ def test_visibility(self):
+ self.assertFalse((yield self.store.is_presence_visible(
+ observed_localpart=self.u_apple.localpart,
+ observer_userid=self.u_banana.to_string(),
+ )))
+
+ yield self.store.allow_presence_visible(
+ observed_localpart=self.u_apple.localpart,
+ observer_userid=self.u_banana.to_string(),
+ )
+
+ self.assertTrue((yield self.store.is_presence_visible(
+ observed_localpart=self.u_apple.localpart,
+ observer_userid=self.u_banana.to_string(),
+ )))
+
+ yield self.store.disallow_presence_visible(
+ observed_localpart=self.u_apple.localpart,
+ observer_userid=self.u_banana.to_string(),
+ )
+
+ self.assertFalse((yield self.store.is_presence_visible(
+ observed_localpart=self.u_apple.localpart,
+ observer_userid=self.u_banana.to_string(),
+ )))
+
+ @defer.inlineCallbacks
+ def test_presence_list(self):
+ self.assertEquals(
+ [],
+ (yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart,
+ ))
+ )
+ self.assertEquals(
+ [],
+ (yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart,
+ accepted=True,
+ ))
+ )
+
+ yield self.store.add_presence_list_pending(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_banana.to_string(),
+ )
+
+ self.assertEquals(
+ [{"observed_user_id": "@banana:test", "accepted": 0}],
+ (yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart,
+ ))
+ )
+ self.assertEquals(
+ [],
+ (yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart,
+ accepted=True,
+ ))
+ )
+
+ yield self.store.set_presence_list_accepted(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_banana.to_string(),
+ )
+
+ self.assertEquals(
+ [{"observed_user_id": "@banana:test", "accepted": 1}],
+ (yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart,
+ ))
+ )
+ self.assertEquals(
+ [{"observed_user_id": "@banana:test", "accepted": 1}],
+ (yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart,
+ accepted=True,
+ ))
+ )
+
+ yield self.store.del_presence_list(
+ observer_localpart=self.u_apple.localpart,
+ observed_userid=self.u_banana.to_string(),
+ )
+
+ self.assertEquals(
+ [],
+ (yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart,
+ ))
+ )
+ self.assertEquals(
+ [],
+ (yield self.store.get_presence_list(
+ observer_localpart=self.u_apple.localpart,
+ accepted=True,
+ ))
+ )
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
new file mode 100644
index 0000000000..5d36723c28
--- /dev/null
+++ b/tests/storage/test_profile.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+from twisted.internet import defer
+
+from synapse.server import HomeServer
+from synapse.storage.profile import ProfileStore
+
+from tests.utils import SQLiteMemoryDbPool
+
+
+class ProfileStoreTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
+ hs = HomeServer("test",
+ db_pool=db_pool,
+ )
+
+ self.store = ProfileStore(hs)
+
+ self.u_frank = hs.parse_userid("@frank:test")
+
+ @defer.inlineCallbacks
+ def test_displayname(self):
+ yield self.store.create_profile(
+ self.u_frank.localpart
+ )
+
+ yield self.store.set_profile_displayname(
+ self.u_frank.localpart, "Frank"
+ )
+
+ self.assertEquals(
+ "Frank",
+ (yield self.store.get_profile_displayname(self.u_frank.localpart))
+ )
+
+ @defer.inlineCallbacks
+ def test_avatar_url(self):
+ yield self.store.create_profile(
+ self.u_frank.localpart
+ )
+
+ yield self.store.set_profile_avatar_url(
+ self.u_frank.localpart, "http://my.site/here"
+ )
+
+ self.assertEquals(
+ "http://my.site/here",
+ (yield self.store.get_profile_avatar_url(self.u_frank.localpart))
+ )
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
new file mode 100644
index 0000000000..91e221d53e
--- /dev/null
+++ b/tests/storage/test_registration.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+from twisted.internet import defer
+
+from synapse.server import HomeServer
+from synapse.storage.registration import RegistrationStore
+
+from tests.utils import SQLiteMemoryDbPool
+
+
+class RegistrationStoreTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
+ hs = HomeServer("test",
+ db_pool=db_pool,
+ )
+
+ self.store = RegistrationStore(hs)
+
+ self.user_id = "@my-user:test"
+ self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
+ "BcDeFgHiJkLmNoPqRsTuVwXyZa"]
+ self.pwhash = "{xx1}123456789"
+
+ @defer.inlineCallbacks
+ def test_register(self):
+ yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
+
+ self.assertEquals(
+ # TODO(paul): Surely this field should be 'user_id', not 'name'
+ # Additionally surely it shouldn't come in a 1-element list
+ [{"name": self.user_id, "password_hash": self.pwhash}],
+ (yield self.store.get_user_by_id(self.user_id))
+ )
+
+ self.assertEquals(
+ self.user_id,
+ (yield self.store.get_user_by_token(self.tokens[0]))
+ )
+
+ @defer.inlineCallbacks
+ def test_add_tokens(self):
+ yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
+ yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
+
+ self.assertEquals(
+ self.user_id,
+ (yield self.store.get_user_by_token(self.tokens[1]))
+ )
+
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
new file mode 100644
index 0000000000..369a73d917
--- /dev/null
+++ b/tests/storage/test_room.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+from twisted.internet import defer
+
+from synapse.server import HomeServer
+from synapse.api.events.room import (
+ RoomNameEvent, RoomTopicEvent
+)
+
+from tests.utils import SQLiteMemoryDbPool
+
+
+class RoomStoreTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
+ hs = HomeServer("test",
+ db_pool=db_pool,
+ )
+
+ # We can't test RoomStore on its own without the DirectoryStore, for
+ # management of the 'room_aliases' table
+ self.store = hs.get_datastore()
+
+ self.room = hs.parse_roomid("!abcde:test")
+ self.alias = hs.parse_roomalias("#a-room-name:test")
+ self.u_creator = hs.parse_userid("@creator:test")
+
+ yield self.store.store_room(self.room.to_string(),
+ room_creator_user_id=self.u_creator.to_string(),
+ is_public=True
+ )
+
+ @defer.inlineCallbacks
+ def test_get_room(self):
+ self.assertObjectHasAttributes(
+ {"room_id": self.room.to_string(),
+ "creator": self.u_creator.to_string(),
+ "is_public": True},
+ (yield self.store.get_room(self.room.to_string()))
+ )
+
+ @defer.inlineCallbacks
+ def test_store_room_config(self):
+ yield self.store.store_room_config(self.room.to_string(),
+ visibility=False
+ )
+
+ self.assertObjectHasAttributes(
+ {"is_public": False},
+ (yield self.store.get_room(self.room.to_string()))
+ )
+
+ @defer.inlineCallbacks
+ def test_get_rooms(self):
+ # get_rooms does an INNER JOIN on the room_aliases table :(
+
+ rooms = yield self.store.get_rooms(is_public=True)
+ # Should be empty before we add the alias
+ self.assertEquals([], rooms)
+
+ yield self.store.create_room_alias_association(
+ room_alias=self.alias,
+ room_id=self.room.to_string(),
+ servers=["test"]
+ )
+
+ rooms = yield self.store.get_rooms(is_public=True)
+
+ self.assertEquals(1, len(rooms))
+ self.assertEquals({
+ "name": None,
+ "room_id": self.room.to_string(),
+ "topic": None,
+ "aliases": [self.alias.to_string()],
+ }, rooms[0])
+
+
+class RoomEventsStoreTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
+ hs = HomeServer("test",
+ db_pool=db_pool,
+ )
+
+ # Room events need the full datastore, for persist_event() and
+ # get_room_state()
+ self.store = hs.get_datastore()
+ self.event_factory = hs.get_event_factory();
+
+ self.room = hs.parse_roomid("!abcde:test")
+
+ yield self.store.store_room(self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True
+ )
+
+ @defer.inlineCallbacks
+ def inject_room_event(self, **kwargs):
+ yield self.store.persist_event(
+ self.event_factory.create_event(
+ room_id=self.room.to_string(),
+ **kwargs
+ )
+ )
+
+ @defer.inlineCallbacks
+ def test_room_name(self):
+ name = u"A-Room-Name"
+
+ yield self.inject_room_event(
+ etype=RoomNameEvent.TYPE,
+ name=name,
+ content={"name": name},
+ depth=1,
+ )
+
+ state = yield self.store.get_current_state(
+ room_id=self.room.to_string()
+ )
+
+ self.assertEquals(1, len(state))
+ self.assertObjectHasAttributes(
+ {"type": "m.room.name",
+ "room_id": self.room.to_string(),
+ "name": name},
+ state[0]
+ )
+
+ @defer.inlineCallbacks
+ def test_room_name(self):
+ topic = u"A place for things"
+
+ yield self.inject_room_event(
+ etype=RoomTopicEvent.TYPE,
+ topic=topic,
+ content={"topic": topic},
+ depth=1,
+ )
+
+ state = yield self.store.get_current_state(
+ room_id=self.room.to_string()
+ )
+
+ self.assertEquals(1, len(state))
+ self.assertObjectHasAttributes(
+ {"type": "m.room.topic",
+ "room_id": self.room.to_string(),
+ "topic": topic},
+ state[0]
+ )
+
+ # Not testing the various 'level' methods for now because there's lots
+ # of them and need coalescing; see JIRA SPEC-11
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
new file mode 100644
index 0000000000..eae278ee8d
--- /dev/null
+++ b/tests/storage/test_roommember.py
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+from twisted.internet import defer
+
+from synapse.server import HomeServer
+from synapse.api.constants import Membership
+from synapse.api.events.room import RoomMemberEvent
+
+from tests.utils import SQLiteMemoryDbPool
+
+
+class RoomMemberStoreTestCase(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ db_pool = SQLiteMemoryDbPool()
+ yield db_pool.prepare()
+
+ hs = HomeServer("test",
+ db_pool=db_pool,
+ )
+
+ # We can't test the RoomMemberStore on its own without the other event
+ # storage logic
+ self.store = hs.get_datastore()
+ self.event_factory = hs.get_event_factory()
+
+ self.u_alice = hs.parse_userid("@alice:test")
+ self.u_bob = hs.parse_userid("@bob:test")
+
+ # User elsewhere on another host
+ self.u_charlie = hs.parse_userid("@charlie:elsewhere")
+
+ self.room = hs.parse_roomid("!abc123:test")
+
+ @defer.inlineCallbacks
+ def inject_room_member(self, room, user, membership):
+ # Have to create a join event using the eventfactory
+ yield self.store.persist_event(
+ self.event_factory.create_event(
+ etype=RoomMemberEvent.TYPE,
+ user_id=user.to_string(),
+ state_key=user.to_string(),
+ room_id=room.to_string(),
+ membership=membership,
+ content={"membership": membership},
+ depth=1,
+ )
+ )
+
+ @defer.inlineCallbacks
+ def test_one_member(self):
+ yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
+
+ self.assertEquals(
+ Membership.JOIN,
+ (yield self.store.get_room_member(
+ user_id=self.u_alice.to_string(),
+ room_id=self.room.to_string(),
+ )).membership
+ )
+ self.assertEquals(
+ [self.u_alice.to_string()],
+ [m.user_id for m in (
+ yield self.store.get_room_members(self.room.to_string())
+ )]
+ )
+ self.assertEquals(
+ [self.room.to_string()],
+ [m.room_id for m in (
+ yield self.store.get_rooms_for_user_where_membership_is(
+ self.u_alice.to_string(), [Membership.JOIN]
+ ))
+ ]
+ )
+ self.assertFalse(
+ (yield self.store.user_rooms_intersect(
+ [self.u_alice.to_string(), self.u_bob.to_string()]
+ ))
+ )
+
+ @defer.inlineCallbacks
+ def test_two_members(self):
+ yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
+ yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
+
+ self.assertEquals(
+ {self.u_alice.to_string(), self.u_bob.to_string()},
+ {m.user_id for m in (
+ yield self.store.get_room_members(self.room.to_string())
+ )}
+ )
+ self.assertTrue(
+ (yield self.store.user_rooms_intersect(
+ [self.u_alice.to_string(), self.u_bob.to_string()]
+ ))
+ )
+
+ @defer.inlineCallbacks
+ def test_room_hosts(self):
+ yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
+
+ self.assertEquals(
+ ["test"],
+ (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
+ )
+
+ # Should still have just one host after second join from it
+ yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
+
+ self.assertEquals(
+ ["test"],
+ (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
+ )
+
+ # Should now have two hosts after join from other host
+ yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)
+
+ self.assertEquals(
+ {"test", "elsewhere"},
+ set((yield
+ self.store.get_joined_hosts_for_room(self.room.to_string())
+ ))
+ )
+
+ # Should still have both hosts
+ yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)
+
+ self.assertEquals(
+ {"test", "elsewhere"},
+ set((yield
+ self.store.get_joined_hosts_for_room(self.room.to_string())
+ ))
+ )
+
+ # Should have only one host after other leaves
+ yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
+
+ self.assertEquals(
+ ["test"],
+ (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
+ )
diff --git a/tests/unittest.py b/tests/unittest.py
index fb97fb1148..a9c0e05541 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -71,6 +71,17 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
return orig()
+ def assertObjectHasAttributes(self, attrs, obj):
+ """Asserts that the given object has each of the attributes given, and
+ that the value of each matches according to assertEquals."""
+ for (key, value) in attrs.items():
+ if not hasattr(obj, key):
+ raise AssertionError("Expected obj to have a '.%s'" % key)
+ try:
+ self.assertEquals(attrs[key], getattr(obj, key))
+ except AssertionError as e:
+ raise (type(e))(e.message + " for '.%s'" % key)
+
def DEBUG(target):
"""A decorator to set the .loglevel attribute to logging.DEBUG.
diff --git a/tests/utils.py b/tests/utils.py
index d90214e418..bc5d35e56b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -16,12 +16,14 @@
from synapse.http.server import HttpServer
from synapse.api.errors import cs_error, CodeMessageException, StoreError
from synapse.api.constants import Membership
+from synapse.storage import prepare_database
from synapse.api.events.room import (
RoomMemberEvent, MessageEvent
)
from twisted.internet import defer, reactor
+from twisted.enterprise.adbapi import ConnectionPool
from collections import namedtuple
from mock import patch, Mock
@@ -120,6 +122,18 @@ class MockClock(object):
self.now += secs
+class SQLiteMemoryDbPool(ConnectionPool, object):
+ def __init__(self):
+ super(SQLiteMemoryDbPool, self).__init__(
+ "sqlite3", ":memory:",
+ cp_min=1,
+ cp_max=1,
+ )
+
+ def prepare(self):
+ return self.runWithConnection(prepare_database)
+
+
class MemoryDataStore(object):
Room = namedtuple(
|