diff --git a/jenkins-dendron-postgres.sh b/jenkins-dendron-postgres.sh
index 8e3a4c51a9..d15836e6bf 100755
--- a/jenkins-dendron-postgres.sh
+++ b/jenkins-dendron-postgres.sh
@@ -73,11 +73,12 @@ git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling b
./jenkins/prep_sytest_for_postgres.sh
+mkdir -p var
+
echo >&2 "Running sytest with PostgreSQL";
./jenkins/install_and_run.sh --python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--dendron $WORKSPACE/dendron/bin/dendron \
- --synchrotron \
--pusher \
--port-base $PORT_BASE
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 22e1721fc4..40ffd9bf0d 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -266,10 +266,9 @@ def setup(config_options):
HomeServer
"""
try:
- config = HomeServerConfig.load_config(
+ config = HomeServerConfig.load_or_generate_config(
"Synapse Homeserver",
config_options,
- generate_section="Homeserver"
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 7449f36491..af9f17bf7b 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -157,9 +157,40 @@ class Config(object):
return default_config, config
@classmethod
- def load_config(cls, description, argv, generate_section=None):
+ def load_config(cls, description, argv):
+ config_parser = argparse.ArgumentParser(
+ description=description,
+ )
+ config_parser.add_argument(
+ "-c", "--config-path",
+ action="append",
+ metavar="CONFIG_FILE",
+ help="Specify config file. Can be given multiple times and"
+ " may specify directories containing *.yaml files."
+ )
+
+ config_parser.add_argument(
+ "--keys-directory",
+ metavar="DIRECTORY",
+ help="Where files such as certs and signing keys are stored when"
+ " their location is given explicitly in the config."
+ " Defaults to the directory containing the last config file",
+ )
+
+ config_args = config_parser.parse_args(argv)
+
+ config_files = find_config_files(search_paths=config_args.config_path)
+
obj = cls()
+ obj.read_config_files(
+ config_files,
+ keys_directory=config_args.keys_directory,
+ generate_keys=False,
+ )
+ return obj
+ @classmethod
+ def load_or_generate_config(cls, description, argv):
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
"-c", "--config-path",
@@ -176,7 +207,7 @@ class Config(object):
config_parser.add_argument(
"--report-stats",
action="store",
- help="Stuff",
+ help="Whether the generated config reports anonymized usage statistics",
choices=["yes", "no"]
)
config_parser.add_argument(
@@ -197,36 +228,11 @@ class Config(object):
)
config_args, remaining_args = config_parser.parse_known_args(argv)
+ config_files = find_config_files(search_paths=config_args.config_path)
+
generate_keys = config_args.generate_keys
- config_files = []
- if config_args.config_path:
- for config_path in config_args.config_path:
- if os.path.isdir(config_path):
- # We accept specifying directories as config paths, we search
- # inside that directory for all files matching *.yaml, and then
- # we apply them in *sorted* order.
- files = []
- for entry in os.listdir(config_path):
- entry_path = os.path.join(config_path, entry)
- if not os.path.isfile(entry_path):
- print (
- "Found subdirectory in config directory: %r. IGNORING."
- ) % (entry_path, )
- continue
-
- if not entry.endswith(".yaml"):
- print (
- "Found file in config directory that does not"
- " end in '.yaml': %r. IGNORING."
- ) % (entry_path, )
- continue
-
- files.append(entry_path)
-
- config_files.extend(sorted(files))
- else:
- config_files.append(config_path)
+ obj = cls()
if config_args.generate_config:
if config_args.report_stats is None:
@@ -299,28 +305,43 @@ class Config(object):
" -c CONFIG-FILE\""
)
- if config_args.keys_directory:
- config_dir_path = config_args.keys_directory
- else:
- config_dir_path = os.path.dirname(config_args.config_path[-1])
- config_dir_path = os.path.abspath(config_dir_path)
+ obj.read_config_files(
+ config_files,
+ keys_directory=config_args.keys_directory,
+ generate_keys=generate_keys,
+ )
+
+ if generate_keys:
+ return None
+
+ obj.invoke_all("read_arguments", args)
+
+ return obj
+
+ def read_config_files(self, config_files, keys_directory=None,
+ generate_keys=False):
+ if not keys_directory:
+ keys_directory = os.path.dirname(config_files[-1])
+
+ config_dir_path = os.path.abspath(keys_directory)
specified_config = {}
for config_file in config_files:
- yaml_config = cls.read_config_file(config_file)
+ yaml_config = self.read_config_file(config_file)
specified_config.update(yaml_config)
if "server_name" not in specified_config:
raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"]
- _, config = obj.generate_config(
+ _, config = self.generate_config(
config_dir_path=config_dir_path,
server_name=server_name,
is_generating_file=False,
)
config.pop("log_config")
config.update(specified_config)
+
if "report_stats" not in config:
raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
@@ -328,11 +349,51 @@ class Config(object):
)
if generate_keys:
- obj.invoke_all("generate_files", config)
+ self.invoke_all("generate_files", config)
return
- obj.invoke_all("read_config", config)
-
- obj.invoke_all("read_arguments", args)
-
- return obj
+ self.invoke_all("read_config", config)
+
+
+def find_config_files(search_paths):
+ """Finds config files using a list of search paths. If a path is a file
+ then that file path is added to the list. If a search path is a directory
+ then all the "*.yaml" files in that directory are added to the list in
+ sorted order.
+
+ Args:
+ search_paths(list(str)): A list of paths to search.
+
+ Returns:
+ list(str): A list of file paths.
+ """
+
+ config_files = []
+ if search_paths:
+ for config_path in search_paths:
+ if os.path.isdir(config_path):
+ # We accept specifying directories as config paths, we search
+ # inside that directory for all files matching *.yaml, and then
+ # we apply them in *sorted* order.
+ files = []
+ for entry in os.listdir(config_path):
+ entry_path = os.path.join(config_path, entry)
+ if not os.path.isfile(entry_path):
+ print (
+ "Found subdirectory in config directory: %r. IGNORING."
+ ) % (entry_path, )
+ continue
+
+ if not entry.endswith(".yaml"):
+ print (
+ "Found file in config directory that does not"
+ " end in '.yaml': %r. IGNORING."
+ ) % (entry_path, )
+ continue
+
+ files.append(entry_path)
+
+ config_files.extend(sorted(files))
+ else:
+ config_files.append(config_path)
+ return config_files
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index f1d231b9d8..9f2a64dede 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -377,10 +377,20 @@ class FederationServer(FederationBase):
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
+ logger.info(
+ "on_get_missing_events: earliest_events: %r, latest_events: %r,"
+ " limit: %d, min_depth: %d",
+ earliest_events, latest_events, limit, min_depth
+ )
missing_events = yield self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit, min_depth
)
+ if len(missing_events) < 5:
+ logger.info("Returning %d events: %r", len(missing_events), missing_events)
+ else:
+ logger.info("Returning %d events", len(missing_events))
+
time_now = self._clock.time_msec()
defer.returnValue({
@@ -490,6 +500,11 @@ class FederationServer(FederationBase):
latest = set(latest)
latest |= seen
+ logger.info(
+ "Missing %d events for room %r: %r...",
+ len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ )
+
missing_events = yield self.get_missing_events(
origin,
pdu.room_id,
@@ -517,6 +532,10 @@ class FederationServer(FederationBase):
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if prevs - seen:
+ logger.info(
+ "Still missing %d events for room %r: %r...",
+ len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+ )
fetch_state = True
if fetch_state:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a1a334955f..6fc3e2207c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -528,15 +528,10 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
@defer.inlineCallbacks
- def on_GET(self, request):
+ def on_GET(self, origin, content, query):
data = yield self.room_list_handler.get_local_public_room_list()
defer.returnValue((200, data))
- # Avoid doing remote HS authorization checks which are done by default by
- # BaseFederationServlet.
- def _wrap(self, code):
- return code
-
SERVLET_CLASSES = (
FederationSendServlet,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ff83c608e7..c2df43e2f6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -345,19 +345,21 @@ class FederationHandler(BaseHandler):
)
missing_auth = required_auth - set(auth_events)
- results = yield defer.gatherResults(
- [
- self.replication_layer.get_pdu(
- [dest],
- event_id,
- outlier=True,
- timeout=10000,
- )
- for event_id in missing_auth
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError)
- auth_events.update({a.event_id: a for a in results})
+ if missing_auth:
+ logger.info("Missing auth for backfill: %r", missing_auth)
+ results = yield defer.gatherResults(
+ [
+ self.replication_layer.get_pdu(
+ [dest],
+ event_id,
+ outlier=True,
+ timeout=10000,
+ )
+ for event_id in missing_auth
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
+ auth_events.update({a.event_id: a for a in results})
ev_infos = []
for a in auth_events.values():
@@ -399,7 +401,7 @@ class FederationHandler(BaseHandler):
# previous to work out the state.
# TODO: We can probably do something more clever here.
yield self._handle_new_event(
- dest, event
+ dest, event, backfilled=True,
)
defer.returnValue(events)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index bbc07b045e..e0aaefe7be 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -388,8 +388,8 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
- auth_handler = self.hs.get_handlers().auth_handler
- token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
+ token = self.auth_handler().generate_short_term_login_token(
+ user_id, duration_seconds)
if need_register:
yield self.store.register(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9fd34588dd..ae44c7a556 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,7 +20,7 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
- EventTypes, JoinRules, RoomCreationPreset,
+ EventTypes, JoinRules, RoomCreationPreset, Membership,
)
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils
@@ -367,14 +367,10 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def handle_room(room_id):
- # We pull each bit of state out indvidually to avoid pulling the
- # full state into memory. Due to how the caching works this should
- # be fairly quick, even if not originally in the cache.
- def get_state(etype, state_key):
- return self.state_handler.get_current_state(room_id, etype, state_key)
+ current_state = yield self.state_handler.get_current_state(room_id)
# Double check that this is actually a public room.
- join_rules_event = yield get_state(EventTypes.JoinRules, "")
+ join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC:
@@ -382,47 +378,51 @@ class RoomListHandler(BaseHandler):
result = {"room_id": room_id}
- joined_users = yield self.store.get_users_in_room(room_id)
- if len(joined_users) == 0:
+ num_joined_users = len([
+ 1 for _, event in current_state.items()
+ if event.type == EventTypes.Member
+ and event.membership == Membership.JOIN
+ ])
+ if num_joined_users == 0:
return
- result["num_joined_members"] = len(joined_users)
+ result["num_joined_members"] = num_joined_users
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
- name_event = yield get_state(EventTypes.Name, "")
+ name_event = yield current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
- topic_event = yield get_state(EventTypes.Topic, "")
+ topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
- canonical_event = yield get_state(EventTypes.CanonicalAlias, "")
+ canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
- visibility_event = yield get_state(EventTypes.RoomHistoryVisibility, "")
+ visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
- guest_event = yield get_state(EventTypes.GuestAccess, "")
+ guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
- avatar_event = yield get_state("m.room.avatar", "")
+ avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 861b8f7989..5589296c09 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -221,6 +221,9 @@ class TypingHandler(object):
def get_all_typing_updates(self, last_id, current_id):
# TODO: Work out a way to do this without scanning the entire state.
+ if last_id == current_id:
+ return []
+
rows = []
for room_id, serial in self._room_serials.items():
if last_id < serial and serial <= current_id:
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index d1afa0f0d5..498bb9e18a 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -45,30 +45,27 @@ class EventStreamRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Guest users must specify room_id param")
if "room_id" in request.args:
room_id = request.args["room_id"][0]
- try:
- handler = self.handlers.event_stream_handler
- pagin_config = PaginationConfig.from_request(request)
- timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
- if "timeout" in request.args:
- try:
- timeout = int(request.args["timeout"][0])
- except ValueError:
- raise SynapseError(400, "timeout must be in milliseconds.")
-
- as_client_event = "raw" not in request.args
-
- chunk = yield handler.get_stream(
- requester.user.to_string(),
- pagin_config,
- timeout=timeout,
- as_client_event=as_client_event,
- affect_presence=(not is_guest),
- room_id=room_id,
- is_guest=is_guest,
- )
- except:
- logger.exception("Event stream failed")
- raise
+
+ handler = self.handlers.event_stream_handler
+ pagin_config = PaginationConfig.from_request(request)
+ timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
+ if "timeout" in request.args:
+ try:
+ timeout = int(request.args["timeout"][0])
+ except ValueError:
+ raise SynapseError(400, "timeout must be in milliseconds.")
+
+ as_client_event = "raw" not in request.args
+
+ chunk = yield handler.get_stream(
+ requester.user.to_string(),
+ pagin_config,
+ timeout=timeout,
+ as_client_event=as_client_event,
+ affect_presence=(not is_guest),
+ room_id=room_id,
+ is_guest=is_guest,
+ )
defer.returnValue((200, chunk))
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index db52a1fc39..86fbe2747d 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -72,8 +72,6 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def get_room_config(self, request):
user_supplied_config = parse_json_object_from_request(request)
- # default visibility
- user_supplied_config.setdefault("visibility", "public")
return user_supplied_config
def on_OPTIONS(self, request):
@@ -279,6 +277,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
+ try:
+ yield self.auth.get_user_by_req(request)
+ except AuthError:
+ # This endpoint isn't authed, but its useful to know who's hitting
+ # it if they *do* supply an access token
+ pass
+
handler = self.hs.get_room_list_handler()
data = yield handler.get_aggregated_public_room_list()
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index d96bf9afe2..2468c3ac42 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -26,6 +26,7 @@ from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
+from synapse.api.errors import SynapseError
from twisted.internet import defer, threads
@@ -134,10 +135,15 @@ class MediaRepository(object):
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
- length, headers = yield self.client.get_file(
- server_name, request_path, output_stream=f,
- max_size=self.max_upload_size,
- )
+ try:
+ length, headers = yield self.client.get_file(
+ server_name, request_path, output_stream=f,
+ max_size=self.max_upload_size,
+ )
+ except Exception as e:
+ logger.warn("Failed to fetch remoted media %r", e)
+ raise SynapseError(502, "Failed to fetch remoted media")
+
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index ec7e8d40d2..3fa226e92d 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -138,6 +138,9 @@ class AccountDataStore(SQLBaseStore):
A deferred pair of lists of tuples of stream_id int, user_id string,
room_id string, type string, and content string.
"""
+ if last_room_id == current_id and last_global_id == current_id:
+ return defer.succeed(([], []))
+
def get_updated_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type, content"
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 3fab57a7e8..d03f7c541e 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -118,6 +118,9 @@ class PresenceStore(SQLBaseStore):
)
def get_all_presence_updates(self, last_id, current_id):
+ if last_id == current_id:
+ return defer.succeed([])
+
def get_all_presence_updates_txn(txn):
sql = (
"SELECT stream_id, user_id, state, last_active_ts,"
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 786d6f6d67..8183b7f1b0 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -421,6 +421,9 @@ class PushRuleStore(SQLBaseStore):
def get_all_push_rule_updates(self, last_id, current_id, limit):
"""Get all the push rules changes that have happend on the server"""
+ if last_id == current_id:
+ return defer.succeed([])
+
def get_all_push_rule_updates_txn(txn):
sql = (
"SELECT stream_id, event_stream_ordering, user_id, rule_id,"
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 9da23f34cb..5a2c1aa59b 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -68,6 +68,9 @@ class TagsStore(SQLBaseStore):
A deferred list of tuples of stream_id int, user_id string,
room_id string, tag string and content string.
"""
+ if last_id == current_id:
+ defer.returnValue([])
+
def get_all_updated_tags_txn(txn):
sql = (
"SELECT stream_id, user_id, room_id"
diff --git a/synapse/types.py b/synapse/types.py
index 7b6ae44bdd..f639651a73 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -22,7 +22,10 @@ Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
def get_domain_from_id(string):
- return string.split(":", 1)[1]
+ try:
+ return string.split(":", 1)[1]
+ except IndexError:
+ raise SynapseError(400, "Invalid ID: %r", string)
class DomainSpecificString(
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 4329d73974..8f57fbeb23 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -30,7 +30,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
shutil.rmtree(self.dir)
def test_generate_config_generates_files(self):
- HomeServerConfig.load_config("", [
+ HomeServerConfig.load_or_generate_config("", [
"--generate-config",
"-c", self.file,
"--report-stats=yes",
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index bf46233c5c..161a87d7e3 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -34,6 +34,8 @@ class ConfigLoadingTestCase(unittest.TestCase):
self.generate_config_and_remove_lines_containing("server_name")
with self.assertRaises(Exception):
HomeServerConfig.load_config("", ["-c", self.file])
+ with self.assertRaises(Exception):
+ HomeServerConfig.load_or_generate_config("", ["-c", self.file])
def test_generates_and_loads_macaroon_secret_key(self):
self.generate_config()
@@ -54,11 +56,24 @@ class ConfigLoadingTestCase(unittest.TestCase):
"was: %r" % (config.macaroon_secret_key,)
)
+ config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
+ self.assertTrue(
+ hasattr(config, "macaroon_secret_key"),
+ "Want config to have attr macaroon_secret_key"
+ )
+ if len(config.macaroon_secret_key) < 5:
+ self.fail(
+ "Want macaroon secret key to be string of at least length 5,"
+ "was: %r" % (config.macaroon_secret_key,)
+ )
+
def test_load_succeeds_if_macaroon_secret_key_missing(self):
self.generate_config_and_remove_lines_containing("macaroon")
config1 = HomeServerConfig.load_config("", ["-c", self.file])
config2 = HomeServerConfig.load_config("", ["-c", self.file])
+ config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
+ self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key)
def test_disable_registration(self):
self.generate_config()
@@ -70,14 +85,17 @@ class ConfigLoadingTestCase(unittest.TestCase):
config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertFalse(config.enable_registration)
+ config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
+ self.assertFalse(config.enable_registration)
+
# Check that either config value is clobbered by the command line.
- config = HomeServerConfig.load_config("", [
+ config = HomeServerConfig.load_or_generate_config("", [
"-c", self.file, "--enable-registration"
])
self.assertTrue(config.enable_registration)
def generate_config(self):
- HomeServerConfig.load_config("", [
+ HomeServerConfig.load_or_generate_config("", [
"--generate-config",
"-c", self.file,
"--report-stats=yes",
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 9d5c653b45..69a5e5b1d4 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -41,14 +41,15 @@ class RegistrationTestCase(unittest.TestCase):
handlers=None,
http_client=None,
expire_access_token=True)
+ self.auth_handler = Mock(
+ generate_short_term_login_token=Mock(return_value='secret'))
self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock()
self.mock_handler = Mock(spec=[
"generate_short_term_login_token",
])
-
- self.hs.get_handlers().auth_handler = self.mock_handler
+ self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
@@ -56,8 +57,6 @@ class RegistrationTestCase(unittest.TestCase):
local_part = "someone"
display_name = "someone"
user_id = "@someone:test"
- mock_token = self.mock_handler.generate_short_term_login_token
- mock_token.return_value = 'secret'
result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id)
@@ -75,8 +74,6 @@ class RegistrationTestCase(unittest.TestCase):
local_part = "frank"
display_name = "Frank"
user_id = "@frank:test"
- mock_token = self.mock_handler.generate_short_term_login_token
- mock_token.return_value = 'secret'
result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id)
diff --git a/tox.ini b/tox.ini
index 757b7189c3..52d93c65e5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -11,7 +11,7 @@ deps =
setenv =
PYTHONDONTWRITEBYTECODE = no_byte_code
commands =
- /bin/bash -c "find {toxinidir} -name '*.pyc' -delete ; coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
+ /bin/sh -c "find {toxinidir} -name '*.pyc' -delete ; coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
{envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}"
{env:DUMP_COVERAGE_COMMAND:coverage report -m}
@@ -26,4 +26,4 @@ skip_install = True
basepython = python2.7
deps =
flake8
-commands = /bin/bash -c "flake8 synapse tests {env:PEP8SUFFIX:}"
+commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}"
|