diff options
author | Richard van der Hoff <richard@matrix.org> | 2018-02-03 22:40:28 +0000 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2018-02-03 22:40:28 +0000 |
commit | bb9f0f3cdb57882248b6426dd4ec5bd5781daaef (patch) | |
tree | 41a8ee0b215406119ec5ef44cf0a3326a8ae1636 /tests/utils.py | |
parent | oops (diff) | |
parent | Merge pull request #2837 from matrix-org/rav/fix_quarantine_media (diff) | |
download | synapse-bb9f0f3cdb57882248b6426dd4ec5bd5781daaef.tar.xz |
Merge branch 'develop' into matthew/gin_work_mem
Diffstat (limited to 'tests/utils.py')
-rw-r--r-- | tests/utils.py | 249 |
1 files changed, 63 insertions, 186 deletions
diff --git a/tests/utils.py b/tests/utils.py index ed8a7360f5..8efd3a3475 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,27 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import HttpServer -from synapse.api.errors import cs_error, CodeMessageException, StoreError -from synapse.api.constants import EventTypes -from synapse.storage.prepare_database import prepare_database -from synapse.storage.engines import create_engine -from synapse.server import HomeServer -from synapse.federation.transport import server -from synapse.util.ratelimitutils import FederationRateLimiter - -from synapse.util.logcontext import LoggingContext - -from twisted.internet import defer, reactor -from twisted.enterprise.adbapi import ConnectionPool - -from collections import namedtuple -from mock import patch, Mock import hashlib +from inspect import getcallargs import urllib import urlparse -from inspect import getcallargs +from mock import Mock, patch +from twisted.internet import defer, reactor + +from synapse.api.errors import CodeMessageException, cs_error +from synapse.federation.transport import server +from synapse.http.server import HttpServer +from synapse.server import HomeServer +from synapse.storage import PostgresEngine +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database +from synapse.util.logcontext import LoggingContext +from synapse.util.ratelimitutils import FederationRateLimiter + +# set this to True to run the tests against postgres instead of sqlite. +# It requires you to have a local postgres database called synapse_test, within +# which ALL TABLES WILL BE DROPPED +USE_POSTGRES_FOR_TESTS = False @defer.inlineCallbacks @@ -57,32 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.worker_app = None config.email_enable_notifs = False config.block_non_admin_invites = False + config.federation_domain_whitelist = None + config.user_directory_search_all_users = False + + # disable user directory updates, because they get done in the + # background, which upsets the test runner. + config.update_user_directory = False config.use_frozen_dicts = True - config.database_config = {"name": "sqlite3"} config.ldap_enabled = False if "clock" not in kargs: kargs["clock"] = MockClock() + if USE_POSTGRES_FOR_TESTS: + config.database_config = { + "name": "psycopg2", + "args": { + "database": "synapse_test", + "cp_min": 1, + "cp_max": 5, + }, + } + else: + config.database_config = { + "name": "sqlite3", + "args": { + "database": ":memory:", + "cp_min": 1, + "cp_max": 1, + }, + } + + db_engine = create_engine(config.database_config) + + # we need to configure the connection pool to run the on_new_connection + # function, so that we can test code that uses custom sqlite functions + # (like rank). + config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection + if datastore is None: - db_pool = SQLiteMemoryDbPool() - yield db_pool.prepare() hs = HomeServer( - name, db_pool=db_pool, config=config, + name, config=config, + db_config=config.database_config, version_string="Synapse/tests", - database_engine=create_engine(config.database_config), - get_db_conn=db_pool.get_db_conn, + database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), **kargs ) + db_conn = hs.get_db_conn() + # make sure that the database is empty + if isinstance(db_engine, PostgresEngine): + cur = db_conn.cursor() + cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") + rows = cur.fetchall() + for r in rows: + cur.execute("DROP TABLE %s CASCADE" % r[0]) + yield prepare_database(db_conn, db_engine, config) hs.setup() else: hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine(config.database_config), + database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), **kargs @@ -301,168 +340,6 @@ class MockClock(object): return d -class SQLiteMemoryDbPool(ConnectionPool, object): - def __init__(self): - super(SQLiteMemoryDbPool, self).__init__( - "sqlite3", ":memory:", - cp_min=1, - cp_max=1, - ) - - self.config = Mock() - self.config.password_providers = [] - self.config.database_config = {"name": "sqlite3"} - - def prepare(self): - engine = self.create_engine() - return self.runWithConnection( - lambda conn: prepare_database(conn, engine, self.config) - ) - - def get_db_conn(self): - conn = self.connect() - engine = self.create_engine() - prepare_database(conn, engine, self.config) - return conn - - def create_engine(self): - return create_engine(self.config.database_config) - - -class MemoryDataStore(object): - - Room = namedtuple( - "Room", - ["room_id", "is_public", "creator"] - ) - - def __init__(self): - self.tokens_to_users = {} - self.paths_to_content = {} - - self.members = {} - self.rooms = {} - - self.current_state = {} - self.events = [] - - class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")): - def fill_out_prev_events(self, event): - pass - - def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): - return self.Snapshot( - room_id, user_id, self.get_room_member(user_id, room_id) - ) - - def register(self, user_id, token, password_hash): - if user_id in self.tokens_to_users.values(): - raise StoreError(400, "User in use.") - self.tokens_to_users[token] = user_id - - def get_user_by_access_token(self, token): - try: - return { - "name": self.tokens_to_users[token], - } - except Exception: - raise StoreError(400, "User does not exist.") - - def get_room(self, room_id): - try: - return self.rooms[room_id] - except Exception: - return None - - def store_room(self, room_id, room_creator_user_id, is_public): - if room_id in self.rooms: - raise StoreError(409, "Conflicting room!") - - room = MemoryDataStore.Room( - room_id=room_id, - is_public=is_public, - creator=room_creator_user_id - ) - self.rooms[room_id] = room - - def get_room_member(self, user_id, room_id): - return self.members.get(room_id, {}).get(user_id) - - def get_room_members(self, room_id, membership=None): - if membership: - return [ - v for k, v in self.members.get(room_id, {}).items() - if v.membership == membership - ] - else: - return self.members.get(room_id, {}).values() - - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - return [ - m[user_id] for m in self.members.values() - if user_id in m and m[user_id].membership in membership_list - ] - - def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - limit=0, with_feedback=False): - return ([], from_key) # TODO - - def get_joined_hosts_for_room(self, room_id): - return defer.succeed([]) - - def persist_event(self, event): - if event.type == EventTypes.Member: - room_id = event.room_id - user = event.state_key - self.members.setdefault(room_id, {})[user] = event - - if hasattr(event, "state_key"): - key = (event.room_id, event.type, event.state_key) - self.current_state[key] = event - - self.events.append(event) - - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type: - key = (room_id, event_type, state_key) - if self.current_state.get(key): - return [self.current_state.get(key)] - return None - else: - return [ - e for e in self.current_state - if e[0] == room_id - ] - - def set_presence_state(self, user_localpart, state): - return defer.succeed({"state": 0}) - - def get_presence_list(self, user_localpart, accepted): - return [] - - def get_room_events_max_id(self): - return "s0" # TODO (erikj) - - def get_send_event_level(self, room_id): - return defer.succeed(0) - - def get_power_level(self, room_id, user_id): - return defer.succeed(0) - - def get_add_state_level(self, room_id): - return defer.succeed(0) - - def get_room_join_rule(self, room_id): - # TODO (erikj): This should be configurable - return defer.succeed("invite") - - def get_ops_levels(self, room_id): - return defer.succeed((5, 5, 5)) - - def insert_client_ip(self, user, access_token, ip, user_agent): - return defer.succeed(None) - - def _format_call(args, kwargs): return ", ".join( ["%r" % (a) for a in args] + |