diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index c0cb8ef296..6121efcfa9 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
import synapse.handlers.auth
from synapse.api.auth import Auth
+from synapse.api.constants import UserTypes
from synapse.api.errors import (
AuthError,
Codes,
@@ -336,6 +337,23 @@ class AuthTestCase(unittest.TestCase):
yield self.auth.check_auth_blocking()
@defer.inlineCallbacks
+ def test_blocking_mau__depending_on_user_type(self):
+ self.hs.config.max_mau_value = 50
+ self.hs.config.limit_usage_by_mau = True
+
+ self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+ # Support users allowed
+ yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+ # Bots not allowed
+ with self.assertRaises(ResourceLimitError):
+ yield self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+ self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
+ # Real users not allowed
+ with self.assertRaises(ResourceLimitError):
+ yield self.auth.check_auth_blocking()
+
+ @defer.inlineCallbacks
def test_reserved_threepid(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 04b8c2c07c..52f89d3f83 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -37,11 +37,9 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.recoverer = Mock()
self.recoverer_fn = Mock(return_value=self.recoverer)
self.txnctrl = _TransactionController(
- clock=self.clock,
- store=self.store,
- as_api=self.as_api,
- recoverer_fn=self.recoverer_fn,
+ clock=self.clock, store=self.store, as_api=self.as_api
)
+ self.txnctrl.RECOVERER_CLASS = self.recoverer_fn
def test_single_service_up_txn_sent(self):
# Test: The AS is up and the txn is successfully sent.
diff --git a/tests/config/test_database.py b/tests/config/test_database.py
new file mode 100644
index 0000000000..151d3006ac
--- /dev/null
+++ b/tests/config/test_database.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import yaml
+
+from synapse.config.database import DatabaseConfig
+
+from tests import unittest
+
+
+class DatabaseConfigTestCase(unittest.TestCase):
+ def test_database_configured_correctly_no_database_conf_param(self):
+ conf = yaml.safe_load(
+ DatabaseConfig().generate_config_section("/data_dir_path", None)
+ )
+
+ expected_database_conf = {
+ "name": "sqlite3",
+ "args": {"database": "/data_dir_path/homeserver.db"},
+ }
+
+ self.assertEqual(conf["database"], expected_database_conf)
+
+ def test_database_configured_correctly_database_conf_param(self):
+
+ database_conf = {
+ "name": "my super fast datastore",
+ "args": {
+ "user": "matrix",
+ "password": "synapse_database_password",
+ "host": "synapse_database_host",
+ "database": "matrix",
+ },
+ }
+
+ conf = yaml.safe_load(
+ DatabaseConfig().generate_config_section("/data_dir_path", database_conf)
+ )
+
+ self.assertEqual(conf["database"], database_conf)
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 5017cbce85..2684e662de 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -17,6 +17,8 @@ import os.path
import re
import shutil
import tempfile
+from contextlib import redirect_stdout
+from io import StringIO
from synapse.config.homeserver import HomeServerConfig
@@ -32,17 +34,18 @@ class ConfigGenerationTestCase(unittest.TestCase):
shutil.rmtree(self.dir)
def test_generate_config_generates_files(self):
- HomeServerConfig.load_or_generate_config(
- "",
- [
- "--generate-config",
- "-c",
- self.file,
- "--report-stats=yes",
- "-H",
- "lemurs.win",
- ],
- )
+ with redirect_stdout(StringIO()):
+ HomeServerConfig.load_or_generate_config(
+ "",
+ [
+ "--generate-config",
+ "-c",
+ self.file,
+ "--report-stats=yes",
+ "-H",
+ "lemurs.win",
+ ],
+ )
self.assertSetEqual(
set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]),
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 6bfc1970ad..b3e557bd6a 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -15,6 +15,8 @@
import os.path
import shutil
import tempfile
+from contextlib import redirect_stdout
+from io import StringIO
import yaml
@@ -26,7 +28,6 @@ from tests import unittest
class ConfigLoadingTestCase(unittest.TestCase):
def setUp(self):
self.dir = tempfile.mkdtemp()
- print(self.dir)
self.file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self):
@@ -94,18 +95,27 @@ class ConfigLoadingTestCase(unittest.TestCase):
)
self.assertTrue(config.enable_registration)
+ def test_stats_enabled(self):
+ self.generate_config_and_remove_lines_containing("enable_metrics")
+ self.add_lines_to_config(["enable_metrics: true"])
+
+ # The default Metrics Flags are off by default.
+ config = HomeServerConfig.load_config("", ["-c", self.file])
+ self.assertFalse(config.metrics_flags.known_servers)
+
def generate_config(self):
- HomeServerConfig.load_or_generate_config(
- "",
- [
- "--generate-config",
- "-c",
- self.file,
- "--report-stats=yes",
- "-H",
- "lemurs.win",
- ],
- )
+ with redirect_stdout(StringIO()):
+ HomeServerConfig.load_or_generate_config(
+ "",
+ [
+ "--generate-config",
+ "-c",
+ self.file,
+ "--report-stats=yes",
+ "-H",
+ "lemurs.win",
+ ],
+ )
def generate_config_and_remove_lines_containing(self, needle):
self.generate_config()
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index 1ca5ea54ca..a10d017120 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.config.server import is_threepid_reserved
+import yaml
+
+from synapse.config.server import ServerConfig, is_threepid_reserved
from tests import unittest
@@ -29,3 +31,100 @@ class ServerConfigTestCase(unittest.TestCase):
self.assertTrue(is_threepid_reserved(config, user1))
self.assertFalse(is_threepid_reserved(config, user3))
self.assertFalse(is_threepid_reserved(config, user1_msisdn))
+
+ def test_unsecure_listener_no_listeners_open_private_ports_false(self):
+ conf = yaml.safe_load(
+ ServerConfig().generate_config_section(
+ "che.org", "/data_dir_path", False, None
+ )
+ )
+
+ expected_listeners = [
+ {
+ "port": 8008,
+ "tls": False,
+ "type": "http",
+ "x_forwarded": True,
+ "bind_addresses": ["::1", "127.0.0.1"],
+ "resources": [{"names": ["client", "federation"], "compress": False}],
+ }
+ ]
+
+ self.assertEqual(conf["listeners"], expected_listeners)
+
+ def test_unsecure_listener_no_listeners_open_private_ports_true(self):
+ conf = yaml.safe_load(
+ ServerConfig().generate_config_section(
+ "che.org", "/data_dir_path", True, None
+ )
+ )
+
+ expected_listeners = [
+ {
+ "port": 8008,
+ "tls": False,
+ "type": "http",
+ "x_forwarded": True,
+ "resources": [{"names": ["client", "federation"], "compress": False}],
+ }
+ ]
+
+ self.assertEqual(conf["listeners"], expected_listeners)
+
+ def test_listeners_set_correctly_open_private_ports_false(self):
+ listeners = [
+ {
+ "port": 8448,
+ "resources": [{"names": ["federation"]}],
+ "tls": True,
+ "type": "http",
+ },
+ {
+ "port": 443,
+ "resources": [{"names": ["client"]}],
+ "tls": False,
+ "type": "http",
+ },
+ ]
+
+ conf = yaml.safe_load(
+ ServerConfig().generate_config_section(
+ "this.one.listens", "/data_dir_path", True, listeners
+ )
+ )
+
+ self.assertEqual(conf["listeners"], listeners)
+
+ def test_listeners_set_correctly_open_private_ports_true(self):
+ listeners = [
+ {
+ "port": 8448,
+ "resources": [{"names": ["federation"]}],
+ "tls": True,
+ "type": "http",
+ },
+ {
+ "port": 443,
+ "resources": [{"names": ["client"]}],
+ "tls": False,
+ "type": "http",
+ },
+ {
+ "port": 1243,
+ "resources": [{"names": ["client"]}],
+ "tls": False,
+ "type": "http",
+ "bind_addresses": ["this_one_is_bound"],
+ },
+ ]
+
+ expected_listeners = listeners.copy()
+ expected_listeners[1]["bind_addresses"] = ["::1", "127.0.0.1"]
+
+ conf = yaml.safe_load(
+ ServerConfig().generate_config_section(
+ "this.one.listens", "/data_dir_path", True, listeners
+ )
+ )
+
+ self.assertEqual(conf["listeners"], expected_listeners)
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index 4f8a87a3df..b02780772a 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -16,6 +16,9 @@
import os
+import idna
+import yaml
+
from OpenSSL import SSL
from synapse.config.tls import ConfigError, TlsConfig
@@ -191,3 +194,84 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0)
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0)
+
+ def test_acme_disabled_in_generated_config_no_acme_domain_provied(self):
+ """
+ Checks acme is disabled by default.
+ """
+ conf = TestConfig()
+ conf.read_config(
+ yaml.safe_load(
+ TestConfig().generate_config_section(
+ "/config_dir_path",
+ "my_super_secure_server",
+ "/data_dir_path",
+ "/tls_cert_path",
+ "tls_private_key",
+ None, # This is the acme_domain
+ )
+ ),
+ "/config_dir_path",
+ )
+
+ self.assertFalse(conf.acme_enabled)
+
+ def test_acme_enabled_in_generated_config_domain_provided(self):
+ """
+ Checks acme is enabled if the acme_domain arg is set to some string.
+ """
+ conf = TestConfig()
+ conf.read_config(
+ yaml.safe_load(
+ TestConfig().generate_config_section(
+ "/config_dir_path",
+ "my_super_secure_server",
+ "/data_dir_path",
+ "/tls_cert_path",
+ "tls_private_key",
+ "my_supe_secure_server", # This is the acme_domain
+ )
+ ),
+ "/config_dir_path",
+ )
+
+ self.assertTrue(conf.acme_enabled)
+
+ def test_whitelist_idna_failure(self):
+ """
+ The federation certificate whitelist will not allow IDNA domain names.
+ """
+ config = {
+ "federation_certificate_verification_whitelist": [
+ "example.com",
+ "*.ドメイン.テスト",
+ ]
+ }
+ t = TestConfig()
+ e = self.assertRaises(
+ ConfigError, t.read_config, config, config_dir_path="", data_dir_path=""
+ )
+ self.assertIn("IDNA domain names", str(e))
+
+ def test_whitelist_idna_result(self):
+ """
+ The federation certificate whitelist will match on IDNA encoded names.
+ """
+ config = {
+ "federation_certificate_verification_whitelist": [
+ "example.com",
+ "*.xn--eckwd4c7c.xn--zckzah",
+ ]
+ }
+ t = TestConfig()
+ t.read_config(config, config_dir_path="", data_dir_path="")
+
+ cf = ClientTLSOptionsFactory(t)
+
+ # Not in the whitelist
+ opts = cf.get_options(b"notexample.com")
+ self.assertTrue(opts._verifier._verify_certs)
+
+ # Caught by the wildcard
+ opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
+ self.assertFalse(opts._verifier._verify_certs)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index af15f4cc5a..b08be451aa 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -20,7 +20,6 @@ from synapse.federation.federation_server import server_matches_acl_event
from tests import unittest
-@unittest.DEBUG
class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self):
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 0ad0a88165..1e9ba3a201 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -171,11 +171,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
- def test_auto_create_auto_join_rooms_when_support_user_exists(self):
+ def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.is_support_user = Mock(return_value=True)
+ self.store.is_real_user = Mock(return_value=False)
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -183,6 +183,31 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias = RoomAlias.from_string(room_alias_str)
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
+ def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+
+ self.store.count_real_users = Mock(return_value=1)
+ self.store.is_real_user = Mock(return_value=True)
+ user_id = self.get_success(self.handler.register_user(localpart="real"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ self.assertTrue(room_id["room_id"] in rooms)
+ self.assertEqual(len(rooms), 1)
+
+ def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+
+ self.store.count_real_users = Mock(return_value=2)
+ self.store.is_real_user = Mock(return_value=True)
+ user_id = self.get_success(self.handler.register_user(localpart="real"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 0)
+
def test_auto_create_auto_join_where_no_consent(self):
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
@@ -283,4 +308,4 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user, requester, displayname, by_admin=True
)
- return (user_id, token)
+ return user_id, token
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index a8b858eb4f..7569b6fab5 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes, Membership
+from synapse import storage
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
+# The expected number of state events in a fresh public room.
+EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
+# The expected number of state events in a fresh private room.
+EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
+
class StatsRoomTests(unittest.HomeserverTestCase):
@@ -33,7 +34,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
-
self.store = hs.get_datastore()
self.handler = self.hs.get_stats_handler()
@@ -47,7 +47,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.get_success(
self.store._simple_insert(
"background_updates",
- {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ {"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
self.get_success(
@@ -56,7 +56,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
{
"update_name": "populate_stats_process_rooms",
"progress_json": "{}",
- "depends_on": "populate_stats_createtables",
+ "depends_on": "populate_stats_prepare",
},
)
)
@@ -64,18 +64,58 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.store._simple_insert(
"background_updates",
{
- "update_name": "populate_stats_cleanup",
+ "update_name": "populate_stats_process_users",
"progress_json": "{}",
"depends_on": "populate_stats_process_rooms",
},
)
)
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_users",
+ },
+ )
+ )
+
+ def get_all_room_state(self):
+ return self.store._simple_select_list(
+ "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
+ )
+
+ def _get_current_stats(self, stats_type, stat_id):
+ table, id_col = storage.stats.TYPE_TO_TABLE[stats_type]
+
+ cols = list(storage.stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list(
+ storage.stats.PER_SLICE_FIELDS[stats_type]
+ )
+
+ end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
+
+ return self.get_success(
+ self.store._simple_select_one(
+ table + "_historical",
+ {id_col: stat_id, end_ts: end_ts},
+ cols,
+ allow_none=True,
+ )
+ )
+
+ def _perform_background_initial_update(self):
+ # Do the initial population of the stats via the background update
+ self._add_background_updates()
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
def test_initial_room(self):
"""
The background updates will build the table from scratch.
"""
- r = self.get_success(self.store.get_all_room_state())
+ r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 0)
# Disable stats
@@ -91,7 +131,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
# Stats disabled, shouldn't have done anything
- r = self.get_success(self.store.get_all_room_state())
+ r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 0)
# Enable stats
@@ -104,7 +144,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
- r = self.get_success(self.store.get_all_room_state())
+ r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo")
@@ -114,6 +154,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Ingestion via notify_new_event will ignore tokens that the background
update have already processed.
"""
+
self.reactor.advance(86401)
self.hs.config.stats_enabled = False
@@ -138,12 +179,18 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
self.store._all_done = False
- self.get_success(self.store.update_stats_stream_pos(None))
+ self.get_success(
+ self.store._simple_update_one(
+ table="stats_incremental_position",
+ keyvalues={},
+ updatevalues={"stream_id": 0},
+ )
+ )
self.get_success(
self.store._simple_insert(
"background_updates",
- {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ {"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
@@ -154,6 +201,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
self.helper.join(room=room_1, user=u2, tok=u2_token)
+ # orig_delta_processor = self.store.
+
# Now do the initial ingestion.
self.get_success(
self.store._simple_insert(
@@ -185,8 +234,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token)
self.helper.join(room=room_1, user=u3, tok=u3_token)
- # Get the deltas! There should be two -- day 1, and day 2.
- r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
+ # self.handler.notify_new_event()
+
+ # We need to let the delta processor advance…
+ self.pump(10 * 60)
+
+ # Get the slices! There should be two -- day 1, and day 2.
+ r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
+
+ self.assertEqual(len(r), 2)
# The oldest has 2 joined members
self.assertEqual(r[-1]["joined_members"], 2)
@@ -194,111 +250,476 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# The newest has 3
self.assertEqual(r[0]["joined_members"], 3)
- def test_incorrect_state_transition(self):
- """
- If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to
- (JOIN, INVITE, LEAVE, BAN), an error is raised.
- """
- events = {
- "a1": {"membership": Membership.LEAVE},
- "a2": {"membership": "not a real thing"},
- }
-
- def get_event(event_id, allow_none=True):
- m = Mock()
- m.content = events[event_id]
- d = defer.Deferred()
- self.reactor.callLater(0.0, d.callback, m)
- return d
-
- def get_received_ts(event_id):
- return defer.succeed(1)
-
- self.store.get_received_ts = get_received_ts
- self.store.get_event = get_event
-
- deltas = [
- {
- "type": EventTypes.Member,
- "state_key": "some_user",
- "room_id": "room",
- "event_id": "a1",
- "prev_event_id": "a2",
- "stream_id": 60,
- }
- ]
-
- f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ def test_create_user(self):
+ """
+ When we create a user, it should have statistics already ready.
+ """
+
+ u1 = self.register_user("u1", "pass")
+
+ u1stats = self._get_current_stats("user", u1)
+
+ self.assertIsNotNone(u1stats)
+
+ # not in any rooms by default
+ self.assertEqual(u1stats["joined_rooms"], 0)
+
+ def test_create_room(self):
+ """
+ When we create a room, it should have statistics already ready.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+ r1stats = self._get_current_stats("room", r1)
+ r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
+ r2stats = self._get_current_stats("room", r2)
+
+ self.assertIsNotNone(r1stats)
+ self.assertIsNotNone(r2stats)
+
+ # contains the default things you'd expect in a fresh room
self.assertEqual(
- f.value.args[0], "'not a real thing' is not a valid prev_membership"
- )
-
- # And the other way...
- deltas = [
- {
- "type": EventTypes.Member,
- "state_key": "some_user",
- "room_id": "room",
- "event_id": "a2",
- "prev_event_id": "a1",
- "stream_id": 100,
- }
- ]
-
- f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ r1stats["total_events"],
+ EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM,
+ "Wrong number of total_events in new room's stats!"
+ " You may need to update this if more state events are added to"
+ " the room creation process.",
+ )
self.assertEqual(
- f.value.args[0], "'not a real thing' is not a valid membership"
+ r2stats["total_events"],
+ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM,
+ "Wrong number of total_events in new room's stats!"
+ " You may need to update this if more state events are added to"
+ " the room creation process.",
)
- def test_redacted_prev_event(self):
+ self.assertEqual(
+ r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
+ )
+ self.assertEqual(
+ r2stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM
+ )
+
+ self.assertEqual(r1stats["joined_members"], 1)
+ self.assertEqual(r1stats["invited_members"], 0)
+ self.assertEqual(r1stats["banned_members"], 0)
+
+ self.assertEqual(r2stats["joined_members"], 1)
+ self.assertEqual(r2stats["invited_members"], 0)
+ self.assertEqual(r2stats["banned_members"], 0)
+
+ def test_send_message_increments_total_events(self):
"""
- If the prev_event does not exist, then it is assumed to be a LEAVE.
+ When we send a message, it increments total_events.
"""
+
+ self._perform_background_initial_update()
+
u1 = self.register_user("u1", "pass")
- u1_token = self.login("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+ r1stats_ante = self._get_current_stats("room", r1)
- room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send(r1, "hiss", tok=u1token)
- # Do the initial population of the user directory via the background update
- self._add_background_updates()
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+
+ def test_send_state_event_nonoverwriting(self):
+ """
+ When we send a non-overwriting state event, it increments total_events AND current_state_events
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ self.helper.send_state(
+ r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby"
+ )
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.send_state(
+ r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy"
+ )
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 1,
+ )
+
+ def test_send_state_event_overwriting(self):
+ """
+ When we send an overwriting state event, it increments total_events ONLY
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ self.helper.send_state(
+ r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby"
+ )
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.send_state(
+ r1, "cat.hissing", {"value": False}, tok=u1token, state_key="tabby"
+ )
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+
+ def test_join_first_time(self):
+ """
+ When a user joins a room for the first time, total_events, current_state_events and
+ joined_members should increase by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 1,
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1
+ )
+
+ def test_join_after_leave(self):
+ """
+ When a user joins a room after being previously left, total_events and
+ joined_members should increase by exactly 1.
+ current_state_events should not increase.
+ left_members should decrease by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ self.helper.join(r1, u2, tok=u2token)
+ self.helper.leave(r1, u2, tok=u2token)
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], +1
+ )
+ self.assertEqual(
+ r1stats_post["left_members"] - r1stats_ante["left_members"], -1
+ )
+
+ def test_invited(self):
+ """
+ When a user invites another user, current_state_events, total_events and
+ invited_members should increase by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.invite(r1, u1, u2, tok=u1token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 1,
+ )
+ self.assertEqual(
+ r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1
+ )
+
+ def test_join_after_invite(self):
+ """
+ When a user joins a room after being invited, total_events and
+ joined_members should increase by exactly 1.
+ current_state_events should not increase.
+ invited_members should decrease by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ self.helper.invite(r1, u1, u2, tok=u1token)
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], +1
+ )
+ self.assertEqual(
+ r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1
+ )
+
+ def test_left(self):
+ """
+ When a user leaves a room after joining, total_events and
+ left_members should increase by exactly 1.
+ current_state_events should not increase.
+ joined_members should decrease by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.leave(r1, u2, tok=u2token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+ self.assertEqual(
+ r1stats_post["left_members"] - r1stats_ante["left_members"], +1
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
+ )
+
+ def test_banned(self):
+ """
+ When a user is banned from a room after joining, total_events and
+ left_members should increase by exactly 1.
+ current_state_events should not increase.
+ banned_members should decrease by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+ self.assertEqual(
+ r1stats_post["banned_members"] - r1stats_ante["banned_members"], +1
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
+ )
+
+ def test_initial_background_update(self):
+ """
+ Test that statistics can be generated by the initial background update
+ handler.
+
+ This test also tests that stats rows are not created for new subjects
+ when stats are disabled. However, it may be desirable to change this
+ behaviour eventually to still keep current rows.
+ """
+
+ self.hs.config.stats_enabled = False
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ # test that these subjects, which were created during a time of disabled
+ # stats, do not have stats.
+ self.assertIsNone(self._get_current_stats("room", r1))
+ self.assertIsNone(self._get_current_stats("user", u1))
+
+ self.hs.config.stats_enabled = True
+
+ self._perform_background_initial_update()
+
+ r1stats = self._get_current_stats("room", r1)
+ u1stats = self._get_current_stats("user", u1)
+
+ self.assertEqual(r1stats["joined_members"], 1)
+ self.assertEqual(
+ r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
+ )
+
+ self.assertEqual(u1stats["joined_rooms"], 1)
+
+ def test_incomplete_stats(self):
+ """
+ This tests that we track incomplete statistics.
+
+ We first test that incomplete stats are incrementally generated,
+ following the preparation of a background regen.
+
+ We then test that these incomplete rows are completed by the background
+ regen.
+ """
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+ u3 = self.register_user("u3", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
+
+ # preparation stage of the initial background update
+ # Ugh, have to reset this flag
+ self.store._all_done = False
+
+ self.get_success(
+ self.store._simple_delete(
+ "room_stats_current", {"1": 1}, "test_delete_stats"
+ )
+ )
+ self.get_success(
+ self.store._simple_delete(
+ "user_stats_current", {"1": 1}, "test_delete_stats"
+ )
+ )
+
+ self.helper.invite(r1, u1, u2, tok=u1token)
+ self.helper.join(r1, u2, tok=u2token)
+ self.helper.invite(r1, u1, u3, tok=u1token)
+ self.helper.send(r1, "thou shalt yield", tok=u1token)
+
+ # now do the background updates
+
+ self.store._all_done = False
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_process_rooms",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_prepare",
+ },
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_process_users",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_rooms",
+ },
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_users",
+ },
+ )
+ )
while not self.get_success(self.store.has_completed_background_updates()):
self.get_success(self.store.do_next_background_update(100), by=0.1)
- events = {"a1": None, "a2": {"membership": Membership.JOIN}}
-
- def get_event(event_id, allow_none=True):
- if events.get(event_id):
- m = Mock()
- m.content = events[event_id]
- else:
- m = None
- d = defer.Deferred()
- self.reactor.callLater(0.0, d.callback, m)
- return d
-
- def get_received_ts(event_id):
- return defer.succeed(1)
-
- self.store.get_received_ts = get_received_ts
- self.store.get_event = get_event
-
- deltas = [
- {
- "type": EventTypes.Member,
- "state_key": "some_user:test",
- "room_id": room_1,
- "event_id": "a2",
- "prev_event_id": "a1",
- "stream_id": 100,
- }
- ]
-
- # Handle our fake deltas, which has a user going from LEAVE -> JOIN.
- self.get_success(self.handler._handle_deltas(deltas))
-
- # One delta, with two joined members -- the room creator, and our fake
- # user.
- r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
- self.assertEqual(len(r), 1)
- self.assertEqual(r[0]["joined_members"], 2)
+ r1stats_complete = self._get_current_stats("room", r1)
+ u1stats_complete = self._get_current_stats("user", u1)
+ u2stats_complete = self._get_current_stats("user", u2)
+
+ # now we make our assertions
+
+ # check that _complete rows are complete and correct
+ self.assertEqual(r1stats_complete["joined_members"], 2)
+ self.assertEqual(r1stats_complete["invited_members"], 1)
+
+ self.assertEqual(
+ r1stats_complete["current_state_events"],
+ 2 + EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM,
+ )
+
+ self.assertEqual(u1stats_complete["joined_rooms"], 1)
+ self.assertEqual(u2stats_complete["joined_rooms"], 1)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5d5e324df2..1f2ef5d01f 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -99,7 +99,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.event_source = hs.get_event_sources().sources["typing"]
self.datastore = hs.get_datastore()
- retry_timings_res = {"destination": "", "retry_last_ts": 0, "retry_interval": 0}
+ retry_timings_res = {
+ "destination": "",
+ "retry_last_ts": 0,
+ "retry_interval": 0,
+ "failure_ts": None,
+ }
self.datastore.get_destination_retry_timings.return_value = defer.succeed(
retry_timings_res
)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 1435baede2..71d7025264 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -41,9 +41,9 @@ from synapse.http.federation.well_known_resolver import (
from synapse.logging.context import LoggingContext
from synapse.util.caches.ttlcache import TTLCache
+from tests import unittest
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.unittest import TestCase
from tests.utils import default_config
logger = logging.getLogger(__name__)
@@ -67,14 +67,12 @@ def get_connection_factory():
return test_server_connection_factory
-class MatrixFederationAgentTests(TestCase):
+class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
self.mock_resolver = Mock()
- self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
-
config_dict = default_config("test", parse=False)
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
@@ -82,11 +80,21 @@ class MatrixFederationAgentTests(TestCase):
config.parse_config_dict(config_dict, "", "")
self.tls_factory = ClientTLSOptionsFactory(config)
+
+ self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+ self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+ self.well_known_resolver = WellKnownResolver(
+ self.reactor,
+ Agent(self.reactor, contextFactory=self.tls_factory),
+ well_known_cache=self.well_known_cache,
+ had_well_known_cache=self.had_well_known_cache,
+ )
+
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=self.tls_factory,
_srv_resolver=self.mock_resolver,
- _well_known_cache=self.well_known_cache,
+ _well_known_resolver=self.well_known_resolver,
)
def _make_connection(self, client_factory, expected_sni):
@@ -543,7 +551,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
# check the cache expires
- self.reactor.pump((25 * 3600,))
+ self.reactor.pump((48 * 3600,))
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
@@ -631,7 +639,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
# check the cache expires
- self.reactor.pump((25 * 3600,))
+ self.reactor.pump((48 * 3600,))
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
@@ -701,11 +709,18 @@ class MatrixFederationAgentTests(TestCase):
config = default_config("test", parse=True)
+ # Build a new agent and WellKnownResolver with a different tls factory
+ tls_factory = ClientTLSOptionsFactory(config)
agent = MatrixFederationAgent(
reactor=self.reactor,
- tls_client_options_factory=ClientTLSOptionsFactory(config),
+ tls_client_options_factory=tls_factory,
_srv_resolver=self.mock_resolver,
- _well_known_cache=self.well_known_cache,
+ _well_known_resolver=WellKnownResolver(
+ self.reactor,
+ Agent(self.reactor, contextFactory=tls_factory),
+ well_known_cache=self.well_known_cache,
+ had_well_known_cache=self.had_well_known_cache,
+ ),
)
test_d = agent.request(b"GET", b"matrix://testserv/foo/bar")
@@ -932,15 +947,9 @@ class MatrixFederationAgentTests(TestCase):
self.successResultOf(test_d)
def test_well_known_cache(self):
- well_known_resolver = WellKnownResolver(
- self.reactor,
- Agent(self.reactor, contextFactory=self.tls_factory),
- well_known_cache=self.well_known_cache,
- )
-
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = well_known_resolver.get_well_known(b"testserv")
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -963,7 +972,7 @@ class MatrixFederationAgentTests(TestCase):
well_known_server.loseConnection()
# repeat the request: it should hit the cache
- fetch_d = well_known_resolver.get_well_known(b"testserv")
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
@@ -971,7 +980,7 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((1000.0,))
# now it should connect again
- fetch_d = well_known_resolver.get_well_known(b"testserv")
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -987,8 +996,137 @@ class MatrixFederationAgentTests(TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"other-server")
+ def test_well_known_cache_with_temp_failure(self):
+ """Test that we refetch well-known before the cache expires, and that
+ it ignores transient errors.
+ """
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ well_known_server = self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ response_headers={b"Cache-Control": b"max-age=1000"},
+ content=b'{ "m.server": "target-server" }',
+ )
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r.delegated_server, b"target-server")
+
+ # close the tcp connection
+ well_known_server.loseConnection()
+
+ # Get close to the cache expiry, this will cause the resolver to do
+ # another lookup.
+ self.reactor.pump((900.0,))
+
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+
+ # The resolver may retry a few times, so fonx all requests that come along
+ attempts = 0
+ while self.reactor.tcpClients:
+ clients = self.reactor.tcpClients
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+
+ attempts += 1
+
+ # fonx the connection attempt, this will be treated as a temporary
+ # failure.
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There's a few sleeps involved, so we have to pump the reactor a
+ # bit.
+ self.reactor.pump((1.0, 1.0))
+
+ # We expect to see more than one attempt as there was previously a valid
+ # well known.
+ self.assertGreater(attempts, 1)
+
+ # Resolver should return cached value, despite the lookup failing.
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r.delegated_server, b"target-server")
+
+ # Expire both caches and repeat the request
+ self.reactor.pump((10000.0,))
+
+ # Repated the request, this time it should fail if the lookup fails.
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+
+ clients = self.reactor.tcpClients
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+ self.reactor.pump((0.4,))
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r.delegated_server, None)
+
+ def test_srv_fallbacks(self):
+ """Test that other SRV results are tried if the first one fails.
+ """
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ]
+ self.reactor.lookups["target.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv"
+ )
+
+ # We should see an attempt to connect to the first server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # Fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There's a 300ms delay in HostnameEndpoint
+ self.reactor.pump((0.4,))
+
+ # Hasn't failed yet
+ self.assertNoResult(test_d)
+
+ # We shouldnow see an attempt to connect to the second server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8444)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
-class TestCachePeriodFromHeaders(TestCase):
+class TestCachePeriodFromHeaders(unittest.TestCase):
def test_cache_control(self):
# uppercase
self.assertEqual(
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 3b885ef64b..df034ab237 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
- entry = Mock(spec_set=["expires"])
+ entry = Mock(spec_set=["expires", "priority", "weight"])
entry.expires = 0
+ entry.priority = 0
+ entry.weight = 0
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
- entry = Mock(spec_set=["expires"])
+ entry = Mock(spec_set=["expires", "priority", "weight"])
entry.expires = 999999999
+ entry.priority = 0
+ entry.weight = 0
cache = {service_name: [entry]}
resolver = SrvResolver(
diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/logging/__init__.py
diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py
new file mode 100644
index 0000000000..451d05c0f0
--- /dev/null
+++ b/tests/logging/test_structured.py
@@ -0,0 +1,214 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import os.path
+import shutil
+import sys
+import textwrap
+
+from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile
+
+from synapse.config.logger import setup_logging
+from synapse.logging._structured import setup_structured_logging
+from synapse.logging.context import LoggingContext
+
+from tests.unittest import DEBUG, HomeserverTestCase
+
+
+class FakeBeginner(object):
+ def beginLoggingTo(self, observers, **kwargs):
+ self.observers = observers
+
+
+class StructuredLoggingTestBase(object):
+ """
+ Test base that registers a cleanup handler to reset the stdlib log handler
+ to 'unset'.
+ """
+
+ def prepare(self, reactor, clock, hs):
+ def _cleanup():
+ logging.getLogger("synapse").setLevel(logging.NOTSET)
+
+ self.addCleanup(_cleanup)
+
+
+class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase):
+ """
+ Tests for Synapse's structured logging support.
+ """
+
+ def test_output_to_json_round_trip(self):
+ """
+ Synapse logs can be outputted to JSON and then read back again.
+ """
+ temp_dir = self.mktemp()
+ os.mkdir(temp_dir)
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json"))
+
+ log_config = {
+ "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}}
+ }
+
+ # Begin the logger with our config
+ beginner = FakeBeginner()
+ setup_structured_logging(
+ self.hs, self.hs.config, log_config, logBeginner=beginner
+ )
+
+ # Make a logger and send an event
+ logger = Logger(
+ namespace="tests.logging.test_structured", observer=beginner.observers[0]
+ )
+ logger.info("Hello there, {name}!", name="wally")
+
+ # Read the log file and check it has the event we sent
+ with open(json_log_file, "r") as f:
+ logged_events = list(eventsFromJSONLogFile(f))
+ self.assertEqual(len(logged_events), 1)
+
+ # The event pulled from the file should render fine
+ self.assertEqual(
+ eventAsText(logged_events[0], includeTimestamp=False),
+ "[tests.logging.test_structured#info] Hello there, wally!",
+ )
+
+ def test_output_to_text(self):
+ """
+ Synapse logs can be outputted to text.
+ """
+ temp_dir = self.mktemp()
+ os.mkdir(temp_dir)
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ log_file = os.path.abspath(os.path.join(temp_dir, "out.log"))
+
+ log_config = {"drains": {"file": {"type": "file", "location": log_file}}}
+
+ # Begin the logger with our config
+ beginner = FakeBeginner()
+ setup_structured_logging(
+ self.hs, self.hs.config, log_config, logBeginner=beginner
+ )
+
+ # Make a logger and send an event
+ logger = Logger(
+ namespace="tests.logging.test_structured", observer=beginner.observers[0]
+ )
+ logger.info("Hello there, {name}!", name="wally")
+
+ # Read the log file and check it has the event we sent
+ with open(log_file, "r") as f:
+ logged_events = f.read().strip().split("\n")
+ self.assertEqual(len(logged_events), 1)
+
+ # The event pulled from the file should render fine
+ self.assertTrue(
+ logged_events[0].endswith(
+ " - tests.logging.test_structured - INFO - None - Hello there, wally!"
+ )
+ )
+
+ def test_collects_logcontext(self):
+ """
+ Test that log outputs have the attached logging context.
+ """
+ log_config = {"drains": {}}
+
+ # Begin the logger with our config
+ beginner = FakeBeginner()
+ publisher = setup_structured_logging(
+ self.hs, self.hs.config, log_config, logBeginner=beginner
+ )
+
+ logs = []
+
+ publisher.addObserver(logs.append)
+
+ # Make a logger and send an event
+ logger = Logger(
+ namespace="tests.logging.test_structured", observer=beginner.observers[0]
+ )
+
+ with LoggingContext("testcontext", request="somereq"):
+ logger.info("Hello there, {name}!", name="steve")
+
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(logs[0]["request"], "somereq")
+
+
+class StructuredLoggingConfigurationFileTestCase(
+ StructuredLoggingTestBase, HomeserverTestCase
+):
+ def make_homeserver(self, reactor, clock):
+
+ tempdir = self.mktemp()
+ os.mkdir(tempdir)
+ log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml"))
+ self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log"))
+
+ config = self.default_config()
+ config["log_config"] = log_config_file
+
+ with open(log_config_file, "w") as f:
+ f.write(
+ textwrap.dedent(
+ """\
+ structured: true
+
+ drains:
+ file:
+ type: file_json
+ location: %s
+ """
+ % (self.homeserver_log,)
+ )
+ )
+
+ self.addCleanup(self._sys_cleanup)
+
+ return self.setup_test_homeserver(config=config)
+
+ def _sys_cleanup(self):
+ sys.stdout = sys.__stdout__
+ sys.stderr = sys.__stderr__
+
+ # Do not remove! We need the logging system to be set other than WARNING.
+ @DEBUG
+ def test_log_output(self):
+ """
+ When a structured logging config is given, Synapse will use it.
+ """
+ beginner = FakeBeginner()
+ publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner)
+
+ # Make a logger and send an event
+ logger = Logger(namespace="tests.logging.test_structured", observer=publisher)
+
+ with LoggingContext("testcontext", request="somereq"):
+ logger.info("Hello there, {name}!", name="steve")
+
+ with open(self.homeserver_log, "r") as f:
+ logged_events = [
+ eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f)
+ ]
+
+ logs = "\n".join(logged_events)
+ self.assertTrue("***** STARTING SERVER *****" in logs)
+ self.assertTrue("Hello there, steve!" in logs)
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
new file mode 100644
index 0000000000..4cf81f7128
--- /dev/null
+++ b/tests/logging/test_terse_json.py
@@ -0,0 +1,234 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from collections import Counter
+
+from twisted.logger import Logger
+
+from synapse.logging._structured import setup_structured_logging
+
+from tests.server import connect_client
+from tests.unittest import HomeserverTestCase
+
+from .test_structured import FakeBeginner, StructuredLoggingTestBase
+
+
+class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
+ def test_log_output(self):
+ """
+ The Terse JSON outputter delivers simplified structured logs over TCP.
+ """
+ log_config = {
+ "drains": {
+ "tersejson": {
+ "type": "network_json_terse",
+ "host": "127.0.0.1",
+ "port": 8000,
+ }
+ }
+ }
+
+ # Begin the logger with our config
+ beginner = FakeBeginner()
+ setup_structured_logging(
+ self.hs, self.hs.config, log_config, logBeginner=beginner
+ )
+
+ logger = Logger(
+ namespace="tests.logging.test_terse_json", observer=beginner.observers[0]
+ )
+ logger.info("Hello there, {name}!", name="wally")
+
+ # Trigger the connection
+ self.pump()
+
+ _, server = connect_client(self.reactor, 0)
+
+ # Trigger data being sent
+ self.pump()
+
+ # One log message, with a single trailing newline
+ logs = server.data.decode("utf8").splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(server.data.count(b"\n"), 1)
+
+ log = json.loads(logs[0])
+
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "time",
+ "level",
+ "log_namespace",
+ "request",
+ "scope",
+ "server_name",
+ "name",
+ ]
+ self.assertEqual(set(log.keys()), set(expected_log_keys))
+
+ # It contains the data we expect.
+ self.assertEqual(log["name"], "wally")
+
+ def test_log_backpressure_debug(self):
+ """
+ When backpressure is hit, DEBUG logs will be shed.
+ """
+ log_config = {
+ "loggers": {"synapse": {"level": "DEBUG"}},
+ "drains": {
+ "tersejson": {
+ "type": "network_json_terse",
+ "host": "127.0.0.1",
+ "port": 8000,
+ "maximum_buffer": 10,
+ }
+ },
+ }
+
+ # Begin the logger with our config
+ beginner = FakeBeginner()
+ setup_structured_logging(
+ self.hs,
+ self.hs.config,
+ log_config,
+ logBeginner=beginner,
+ redirect_stdlib_logging=False,
+ )
+
+ logger = Logger(
+ namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
+ )
+
+ # Send some debug messages
+ for i in range(0, 3):
+ logger.debug("debug %s" % (i,))
+
+ # Send a bunch of useful messages
+ for i in range(0, 7):
+ logger.info("test message %s" % (i,))
+
+ # The last debug message pushes it past the maximum buffer
+ logger.debug("too much debug")
+
+ # Allow the reconnection
+ _, server = connect_client(self.reactor, 0)
+ self.pump()
+
+ # Only the 7 infos made it through, the debugs were elided
+ logs = server.data.splitlines()
+ self.assertEqual(len(logs), 7)
+
+ def test_log_backpressure_info(self):
+ """
+ When backpressure is hit, DEBUG and INFO logs will be shed.
+ """
+ log_config = {
+ "loggers": {"synapse": {"level": "DEBUG"}},
+ "drains": {
+ "tersejson": {
+ "type": "network_json_terse",
+ "host": "127.0.0.1",
+ "port": 8000,
+ "maximum_buffer": 10,
+ }
+ },
+ }
+
+ # Begin the logger with our config
+ beginner = FakeBeginner()
+ setup_structured_logging(
+ self.hs,
+ self.hs.config,
+ log_config,
+ logBeginner=beginner,
+ redirect_stdlib_logging=False,
+ )
+
+ logger = Logger(
+ namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
+ )
+
+ # Send some debug messages
+ for i in range(0, 3):
+ logger.debug("debug %s" % (i,))
+
+ # Send a bunch of useful messages
+ for i in range(0, 10):
+ logger.warn("test warn %s" % (i,))
+
+ # Send a bunch of info messages
+ for i in range(0, 3):
+ logger.info("test message %s" % (i,))
+
+ # The last debug message pushes it past the maximum buffer
+ logger.debug("too much debug")
+
+ # Allow the reconnection
+ client, server = connect_client(self.reactor, 0)
+ self.pump()
+
+ # The 10 warnings made it through, the debugs and infos were elided
+ logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
+ self.assertEqual(len(logs), 10)
+
+ self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
+
+ def test_log_backpressure_cut_middle(self):
+ """
+ When backpressure is hit, and no more DEBUG and INFOs cannot be culled,
+ it will cut the middle messages out.
+ """
+ log_config = {
+ "loggers": {"synapse": {"level": "DEBUG"}},
+ "drains": {
+ "tersejson": {
+ "type": "network_json_terse",
+ "host": "127.0.0.1",
+ "port": 8000,
+ "maximum_buffer": 10,
+ }
+ },
+ }
+
+ # Begin the logger with our config
+ beginner = FakeBeginner()
+ setup_structured_logging(
+ self.hs,
+ self.hs.config,
+ log_config,
+ logBeginner=beginner,
+ redirect_stdlib_logging=False,
+ )
+
+ logger = Logger(
+ namespace="synapse.logging.test_terse_json", observer=beginner.observers[0]
+ )
+
+ # Send a bunch of useful messages
+ for i in range(0, 20):
+ logger.warn("test warn", num=i)
+
+ # Allow the reconnection
+ client, server = connect_client(self.reactor, 0)
+ self.pump()
+
+ # The first five and last five warnings made it through, the debugs and
+ # infos were elided
+ logs = list(map(json.loads, server.data.decode("utf8").splitlines()))
+ self.assertEqual(len(logs), 10)
+ self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10})
+ self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs])
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index fe66e397c4..d2bcf256fa 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -30,6 +30,14 @@ class RedactionsTestCase(HomeserverTestCase):
sync.register_servlets,
]
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["rc_message"] = {"per_second": 0.2, "burst_count": 10}
+ config["rc_admin_redaction"] = {"per_second": 1, "burst_count": 100}
+
+ return self.setup_test_homeserver(config=config)
+
def prepare(self, reactor, clock, hs):
# register a couple of users
self.mod_user_id = self.register_user("user1", "pass")
@@ -177,3 +185,20 @@ class RedactionsTestCase(HomeserverTestCase):
self._redact_event(
self.other_access_token, self.room_id, create_event_id, expect_code=403
)
+
+ def test_redact_event_as_moderator_ratelimit(self):
+ """Tests that the correct ratelimiting is applied to redactions
+ """
+
+ message_ids = []
+ # as a regular user, send messages to redact
+ for _ in range(20):
+ b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
+ message_ids.append(b["event_id"])
+ self.reactor.advance(10) # To get around ratelimits
+
+ # as the moderator, send a bunch of redactions
+ for msg_id in message_ids:
+ # These should all succeed, even though this would be denied by
+ # the standard message ratelimiter
+ self._redact_event(self.mod_access_token, self.room_id, msg_id)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 9915367144..cdded88b7f 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -128,8 +128,12 @@ class RestHelper(object):
return channel.json_body
- def send_state(self, room_id, event_type, body, tok, expect_code=200):
- path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type)
+ def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
if tok:
path = path + "?access_token=%s" % tok
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index b9ef46e8fb..b6df1396ad 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -18,11 +18,22 @@ from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
+from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client.v2_alpha import auth, register
from tests import unittest
+class DummyRecaptchaChecker(UserInteractiveAuthChecker):
+ def __init__(self, hs):
+ super().__init__(hs)
+ self.recaptcha_attempts = []
+
+ def check_auth(self, authdict, clientip):
+ self.recaptcha_attempts.append((authdict, clientip))
+ return succeed(True)
+
+
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
@@ -44,15 +55,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
+ self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
-
- self.recaptcha_attempts = []
-
- def _recaptcha(authdict, clientip):
- self.recaptcha_attempts.append((authdict, clientip))
- return succeed(True)
-
- auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+ auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
@unittest.INFO
def test_fallback_captcha(self):
@@ -89,8 +94,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(request.code, 200)
# The recaptcha handler is called with the response given
- self.assertEqual(len(self.recaptcha_attempts), 1)
- self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+ attempts = self.recaptcha_checker.recaptcha_attempts
+ self.assertEqual(len(attempts), 1)
+ self.assertEqual(attempts[0][0]["response"], "a")
# also complete the dummy auth
request, channel = self.make_request(
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index bb867150f4..dab87e5edf 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -34,19 +34,12 @@ from tests import unittest
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]
+ url = b"/_matrix/client/r0/register"
- def make_homeserver(self, reactor, clock):
-
- self.url = b"/_matrix/client/r0/register"
-
- self.hs = self.setup_test_homeserver()
- self.hs.config.enable_registration = True
- self.hs.config.registrations_require_3pid = []
- self.hs.config.auto_join_rooms = []
- self.hs.config.enable_registration_captcha = False
- self.hs.config.allow_guest_access = True
-
- return self.hs
+ def default_config(self, name="test"):
+ config = super().default_config(name)
+ config["allow_guest_access"] = True
+ return config
def test_POST_appservice_registration_valid(self):
user_id = "@as_user_kermit:test"
@@ -199,6 +192,73 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ def test_advertised_flows(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ # with the stock config, we only expect the dummy flow
+ self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
+
+ @unittest.override_config(
+ {
+ "enable_registration_captcha": True,
+ "user_consent": {
+ "version": "1",
+ "template_dir": "/",
+ "require_at_registration": True,
+ },
+ "account_threepid_delegates": {
+ "email": "https://id_server",
+ "msisdn": "https://id_server",
+ },
+ }
+ )
+ def test_advertised_flows_captcha_and_terms_and_3pids(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ self.assertCountEqual(
+ [
+ ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
+ [
+ "m.login.recaptcha",
+ "m.login.terms",
+ "m.login.msisdn",
+ "m.login.email.identity",
+ ],
+ ],
+ (f["stages"] for f in flows),
+ )
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://test_server",
+ "registrations_require_3pid": ["email"],
+ "disable_msisdn_registration": True,
+ "email": {
+ "smtp_host": "mail_server",
+ "smtp_port": 2525,
+ "notif_from": "sender@host",
+ },
+ }
+ )
+ def test_advertised_flows_no_msisdn_email_required(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ # with the stock config, we expect all four combinations of 3pid
+ self.assertCountEqual(
+ [["m.login.email.identity"]], (f["stages"] for f in flows)
+ )
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
@@ -472,7 +532,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
added_at=now,
)
)
- return (user_id, tok)
+ return user_id, tok
def test_manual_email_send_expired_account(self):
user_id = self.register_user("kermit", "monkey")
diff --git a/tests/server.py b/tests/server.py
index e573c4e4c5..e397ebe8fa 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -11,9 +11,13 @@ from twisted.internet import address, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver, IResolverSimple
+from twisted.internet.interfaces import (
+ IReactorPluggableNameResolver,
+ IReactorTCP,
+ IResolverSimple,
+)
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http import unquote
from twisted.web.http_headers import Headers
@@ -334,7 +338,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
def get_clock():
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock)
- return (clock, hs_clock)
+ return clock, hs_clock
@attr.s(cmp=False)
@@ -465,3 +469,22 @@ class FakeTransport(object):
self.buffer = self.buffer[len(to_write) :]
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
+
+
+def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
+ """
+ Connect a client to a fake TCP transport.
+
+ Args:
+ reactor
+ factory: The connecting factory to build.
+ """
+ factory = reactor.tcpClients[client_id][2]
+ client = factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, reactor))
+ client.makeConnection(FakeTransport(server, reactor))
+
+ reactor.tcpClients.pop(client_id)
+
+ return client, server
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index e9e2d5337c..34f9c72709 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -14,7 +14,13 @@
# limitations under the License.
import os.path
+from unittest.mock import patch
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import login, room
from synapse.storage import prepare_database
from synapse.types import Requester, UserID
@@ -225,6 +231,14 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
+ CONSENT_VERSION = "1"
+ EXTREMITIES_COUNT = 50
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
def make_homeserver(self, reactor, clock):
config = self.default_config()
config["cleanup_extremities_with_dummy_events"] = True
@@ -233,28 +247,39 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
self.room_creator = homeserver.get_room_creation_handler()
+ self.event_creator_handler = homeserver.get_event_creation_handler()
# Create a test user and room
- self.user = UserID("alice", "test")
+ self.user = UserID.from_string(self.register_user("user1", "password"))
+ self.token1 = self.login("user1", "password")
self.requester = Requester(self.user, None, False, None, None)
info = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
+ self.event_creator = homeserver.get_event_creation_handler()
+ homeserver.config.user_consent_version = self.CONSENT_VERSION
def test_send_dummy_event(self):
- # Create a bushy graph with 50 extremities.
-
- event_id_start = self.create_and_send_event(self.room_id, self.user)
+ self._create_extremity_rich_graph()
- for _ in range(50):
- self.create_and_send_event(
- self.room_id, self.user, prev_event_ids=[event_id_start]
- )
+ # Pump the reactor repeatedly so that the background updates have a
+ # chance to run.
+ self.pump(10 * 60)
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
- self.assertEqual(len(latest_event_ids), 50)
+ self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
+ @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
+ def test_send_dummy_events_when_insufficient_power(self):
+ self._create_extremity_rich_graph()
+ # Criple power levels
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.PowerLevels,
+ body={"users": {str(self.user): -1}},
+ tok=self.token1,
+ )
# Pump the reactor repeatedly so that the background updates have a
# chance to run.
self.pump(10 * 60)
@@ -262,4 +287,108 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
+ # Check that the room has not been pruned
+ self.assertTrue(len(latest_event_ids) > 10)
+
+ # New user with regular levels
+ user2 = self.register_user("user2", "password")
+ token2 = self.login("user2", "password")
+ self.helper.join(self.room_id, user2, tok=token2)
+ self.pump(10 * 60)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
+
+ @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
+ def test_send_dummy_event_without_consent(self):
+ self._create_extremity_rich_graph()
+ self._enable_consent_checking()
+
+ # Pump the reactor repeatedly so that the background updates have a
+ # chance to run. Attempt to add dummy event with user that has not consented
+ # Check that dummy event send fails.
+ self.pump(10 * 60)
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT)
+
+ # Create new user, and add consent
+ user2 = self.register_user("user2", "password")
+ token2 = self.login("user2", "password")
+ self.get_success(
+ self.store.user_set_consent_version(user2, self.CONSENT_VERSION)
+ )
+ self.helper.join(self.room_id, user2, tok=token2)
+
+ # Background updates should now cause a dummy event to be added to the graph
+ self.pump(10 * 60)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
+
+ @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
+ def test_expiry_logic(self):
+ """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
+ expires old entries correctly.
+ """
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+ "1"
+ ] = 100000
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+ "2"
+ ] = 200000
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+ "3"
+ ] = 300000
+ self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
+ # All entries within time frame
+ self.assertEqual(
+ len(
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion
+ ),
+ 3,
+ )
+ # Oldest room to expire
+ self.pump(1)
+ self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
+ self.assertEqual(
+ len(
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion
+ ),
+ 2,
+ )
+ # All rooms to expire
+ self.pump(2)
+ self.assertEqual(
+ len(
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion
+ ),
+ 0,
+ )
+
+ def _create_extremity_rich_graph(self):
+ """Helper method to create bushy graph on demand"""
+
+ event_id_start = self.create_and_send_event(self.room_id, self.user)
+
+ for _ in range(self.EXTREMITIES_COUNT):
+ self.create_and_send_event(
+ self.room_id, self.user, prev_event_ids=[event_id_start]
+ )
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(len(latest_event_ids), 50)
+
+ def _enable_consent_checking(self):
+ """Helper method to enable consent checking"""
+ self.event_creator._block_events_without_consent_error = "No consent from user"
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creator._consent_uri_builder = consent_uri_builder
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 09305c3bf1..afac5dec7f 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -55,7 +55,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
{
"user_id": user_id,
"device_id": "device_id",
- "access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
@@ -201,6 +200,156 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
+ def test_devices_last_seen_bg_update(self):
+ # First make sure we have completed all updates.
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ # Insert a user IP
+ user_id = "@user:id"
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
+ )
+
+ # Force persisting to disk
+ self.reactor.advance(200)
+
+ # But clear the associated entry in devices table
+ self.get_success(
+ self.store._simple_update(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": "device_id"},
+ updatevalues={"last_seen": None, "ip": None, "user_agent": None},
+ desc="test_devices_last_seen_bg_update",
+ )
+ )
+
+ # We should now get nulls when querying
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, "device_id")
+ )
+
+ r = result[(user_id, "device_id")]
+ self.assertDictContainsSubset(
+ {
+ "user_id": user_id,
+ "device_id": "device_id",
+ "ip": None,
+ "user_agent": None,
+ "last_seen": None,
+ },
+ r,
+ )
+
+ # Register the background update to run again.
+ self.get_success(
+ self.store._simple_insert(
+ table="background_updates",
+ values={
+ "update_name": "devices_last_seen",
+ "progress_json": "{}",
+ "depends_on": None,
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store._all_done = False
+
+ # Now let's actually drive the updates to completion
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ # We should now get the correct result again
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, "device_id")
+ )
+
+ r = result[(user_id, "device_id")]
+ self.assertDictContainsSubset(
+ {
+ "user_id": user_id,
+ "device_id": "device_id",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 0,
+ },
+ r,
+ )
+
+ def test_old_user_ips_pruned(self):
+ # First make sure we have completed all updates.
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ # Insert a user IP
+ user_id = "@user:id"
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
+ )
+
+ # Force persisting to disk
+ self.reactor.advance(200)
+
+ # We should see that in the DB
+ result = self.get_success(
+ self.store._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+ desc="get_user_ip_and_agents",
+ )
+ )
+
+ self.assertEqual(
+ result,
+ [
+ {
+ "access_token": "access_token",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "device_id": "device_id",
+ "last_seen": 0,
+ }
+ ],
+ )
+
+ # Now advance by a couple of months
+ self.reactor.advance(60 * 24 * 60 * 60)
+
+ # We should get no results.
+ result = self.get_success(
+ self.store._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+ desc="get_user_ip_and_agents",
+ )
+ )
+
+ self.assertEqual(result, [])
+
+ # But we should still get the correct values for the device
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, "device_id")
+ )
+
+ r = result[(user_id, "device_id")]
+ self.assertDictContainsSubset(
+ {
+ "user_id": user_id,
+ "device_id": "device_id",
+ "ip": "ip",
+ "user_agent": "user_agent",
+ "last_seen": 0,
+ },
+ r,
+ )
+
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 86c7ac350d..b58386994e 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -75,3 +75,43 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
el = r[i]
depth = el[2]
self.assertLessEqual(5, depth)
+
+ @defer.inlineCallbacks
+ def test_get_rooms_with_many_extremities(self):
+ room1 = "#room1"
+ room2 = "#room2"
+ room3 = "#room3"
+
+ def insert_event(txn, i, room_id):
+ event_id = "$event_%i:local" % i
+ txn.execute(
+ (
+ "INSERT INTO event_forward_extremities (room_id, event_id) "
+ "VALUES (?, ?)"
+ ),
+ (room_id, event_id),
+ )
+
+ for i in range(0, 20):
+ yield self.store.runInteraction("insert", insert_event, i, room1)
+ yield self.store.runInteraction("insert", insert_event, i, room2)
+ yield self.store.runInteraction("insert", insert_event, i, room3)
+
+ # Test simple case
+ r = yield self.store.get_rooms_with_many_extremities(5, 5, [])
+ self.assertEqual(len(r), 3)
+
+ # Does filter work?
+
+ r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1])
+ self.assertTrue(room2 in r)
+ self.assertTrue(room3 in r)
+ self.assertEqual(len(r), 2)
+
+ r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
+ self.assertEqual(r, [room3])
+
+ # Does filter and limit work?
+
+ r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1])
+ self.assertTrue(r == [room2] or r == [room3])
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d961b81d48..427d3c49c5 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -17,6 +17,8 @@
from mock import Mock
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
@@ -29,8 +31,10 @@ from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["redaction_retention_period"] = "30d"
return self.setup_test_homeserver(
- resource_for_federation=Mock(), http_client=None
+ resource_for_federation=Mock(), http_client=None, config=config
)
def prepare(self, reactor, clock, hs):
@@ -114,6 +118,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.persist_event(event, context))
+ return event
+
def test_redact(self):
self.get_success(
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -286,3 +292,108 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.assertEqual(
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
)
+
+ def test_redact_censor(self):
+ """Test that a redacted event gets censored in the DB after a month
+ """
+
+ self.get_success(
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ )
+
+ msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+
+ # Check event has not been redacted:
+ event = self.get_success(self.store.get_event(msg_event.event_id))
+
+ self.assertObjectHasAttributes(
+ {
+ "type": EventTypes.Message,
+ "user_id": self.u_alice.to_string(),
+ "content": {"body": "t", "msgtype": "message"},
+ },
+ event,
+ )
+
+ self.assertFalse("redacted_because" in event.unsigned)
+
+ # Redact event
+ reason = "Because I said so"
+ self.get_success(
+ self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
+ )
+
+ event = self.get_success(self.store.get_event(msg_event.event_id))
+
+ self.assertTrue("redacted_because" in event.unsigned)
+
+ self.assertObjectHasAttributes(
+ {
+ "type": EventTypes.Message,
+ "user_id": self.u_alice.to_string(),
+ "content": {},
+ },
+ event,
+ )
+
+ event_json = self.get_success(
+ self.store._simple_select_one_onecol(
+ table="event_json",
+ keyvalues={"event_id": msg_event.event_id},
+ retcol="json",
+ )
+ )
+
+ self.assert_dict(
+ {"content": {"body": "t", "msgtype": "message"}}, json.loads(event_json)
+ )
+
+ # Advance by 30 days, then advance again to ensure that the looping call
+ # for updating the stream position gets called and then the looping call
+ # for the censoring gets called.
+ self.reactor.advance(60 * 60 * 24 * 31)
+ self.reactor.advance(60 * 60 * 2)
+
+ event_json = self.get_success(
+ self.store._simple_select_one_onecol(
+ table="event_json",
+ keyvalues={"event_id": msg_event.event_id},
+ retcol="json",
+ )
+ )
+
+ self.assert_dict({"content": {}}, json.loads(event_json))
+
+ def test_redact_redaction(self):
+ """Tests that we can redact a redaction and can fetch it again.
+ """
+
+ self.get_success(
+ self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+ )
+
+ msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+
+ first_redact_event = self.get_success(
+ self.inject_redaction(
+ self.room1, msg_event.event_id, self.u_alice, "Redacting message"
+ )
+ )
+
+ self.get_success(
+ self.inject_redaction(
+ self.room1,
+ first_redact_event.event_id,
+ self.u_alice,
+ "Redacting redaction",
+ )
+ )
+
+ # Now lets jump to the future where we have censored the redaction event
+ # in the DB.
+ self.reactor.advance(60 * 60 * 24 * 31)
+
+ # We just want to check that fetching the event doesn't raise an exception.
+ self.get_success(
+ self.store.get_event(first_redact_event.event_id, allow_none=True)
+ )
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 0253c4ac05..4578cc3b60 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -49,6 +49,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_server_notice_sent": None,
"appservice_id": None,
"creation_ts": 1000,
+ "user_type": None,
},
(yield self.store.get_user_by_id(self.user_id)),
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 64cb294c37..447a3c6ffb 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,78 +14,129 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
-
-from twisted.internet import defer
+from unittest.mock import Mock
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.types import Requester, RoomID, UserID
+from synapse.rest.admin import register_servlets_for_client_rest_resource
+from synapse.rest.client.v1 import login, room
+from synapse.types import Requester, UserID
from tests import unittest
-from tests.utils import create_room, setup_test_homeserver
-class RoomMemberStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(
- self.addCleanup, resource_for_federation=Mock(), http_client=None
+class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ login.register_servlets,
+ register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(
+ resource_for_federation=Mock(), http_client=None
)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
- self.u_alice = UserID.from_string("@alice:test")
- self.u_bob = UserID.from_string("@bob:test")
+ self.u_alice = self.register_user("alice", "pass")
+ self.t_alice = self.login("alice", "pass")
+ self.u_bob = self.register_user("bob", "pass")
# User elsewhere on another host
self.u_charlie = UserID.from_string("@charlie:elsewhere")
- self.room = RoomID.from_string("!abc123:test")
-
- yield create_room(hs, self.room.to_string(), self.u_alice.to_string())
-
- @defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Member,
- "sender": user.to_string(),
- "state_key": user.to_string(),
- "room_id": room.to_string(),
+ "sender": user,
+ "state_key": user,
+ "room_id": room,
"content": {"membership": membership},
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.store.persist_event(event, context)
+ self.get_success(self.store.persist_event(event, context))
return event
- @defer.inlineCallbacks
def test_one_member(self):
- yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
-
- self.assertEquals(
- [self.room.to_string()],
- [
- m.room_id
- for m in (
- yield self.store.get_rooms_for_user_where_membership_is(
- self.u_alice.to_string(), [Membership.JOIN]
- )
- )
- ],
+
+ # Alice creates the room, and is automatically joined
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+
+ rooms_for_user = self.get_success(
+ self.store.get_rooms_for_user_where_membership_is(
+ self.u_alice, [Membership.JOIN]
+ )
)
+ self.assertEquals([self.room], [m.room_id for m in rooms_for_user])
+
+ def test_count_known_servers(self):
+ """
+ _count_known_servers will calculate how many servers are in a room.
+ """
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
+ self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
+
+ servers = self.get_success(self.store._count_known_servers())
+ self.assertEqual(servers, 2)
+
+ def test_count_known_servers_stat_counter_disabled(self):
+ """
+ If enabled, the metrics for how many servers are known will be counted.
+ """
+ self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
+
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
+ self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
+
+ self.pump(20)
+
+ self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
+
+ @unittest.override_config(
+ {"enable_metrics": True, "metrics_flags": {"known_servers": True}}
+ )
+ def test_count_known_servers_stat_counter_enabled(self):
+ """
+ If enabled, the metrics for how many servers are known will be counted.
+ """
+ # Initialises to 1 -- itself
+ self.assertEqual(self.store._known_servers_count, 1)
+
+ self.pump(20)
+
+ # No rooms have been joined, so technically the SQL returns 0, but it
+ # will still say it knows about itself.
+ self.assertEqual(self.store._known_servers_count, 1)
+
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
+ self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
+
+ self.pump(20)
+
+ # It now knows about Charlie's server.
+ self.assertEqual(self.store._known_servers_count, 2)
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index 14169afa96..8e817e2c7f 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.util.retryutils import MAX_RETRY_INTERVAL
+
from tests.unittest import HomeserverTestCase
@@ -29,17 +31,28 @@ class TransactionStoreTestCase(HomeserverTestCase):
r = self.get_success(d)
self.assertIsNone(r)
- d = self.store.set_destination_retry_timings("example.com", 50, 100)
+ d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d)
d = self.store.get_destination_retry_timings("example.com")
r = self.get_success(d)
- self.assert_dict({"retry_last_ts": 50, "retry_interval": 100}, r)
+ self.assert_dict(
+ {"retry_last_ts": 50, "retry_interval": 100, "failure_ts": 1000}, r
+ )
def test_initial_set_transactions(self):
"""Tests that we can successfully set the destination retries (there
was a bug around invalidating the cache that broke this)
"""
- d = self.store.set_destination_retry_timings("example.com", 50, 100)
+ d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
+ self.get_success(d)
+
+ def test_large_destination_retry(self):
+ d = self.store.set_destination_retry_timings(
+ "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
+ )
+ self.get_success(d)
+
+ d = self.store.get_destination_retry_timings("example.com")
self.get_success(d)
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 2edbae5c6d..270f853d60 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,8 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from synapse.metrics import InFlightGauge
+from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
from tests import unittest
@@ -111,3 +111,21 @@ class TestMauLimit(unittest.TestCase):
}
return results
+
+
+class BuildInfoTests(unittest.TestCase):
+ def test_get_build(self):
+ """
+ The synapse_build_info metric reports the OS version, Python version,
+ and Synapse version.
+ """
+ items = list(
+ filter(
+ lambda x: b"synapse_build_info{" in x,
+ generate_latest(REGISTRY).split(b"\n"),
+ )
+ )
+ self.assertEqual(len(items), 1)
+ self.assertTrue(b"osversion=" in items[0])
+ self.assertTrue(b"pythonversion=" in items[0])
+ self.assertTrue(b"version=" in items[0])
diff --git a/tests/test_server.py b/tests/test_server.py
index 2a7d407c98..98fef21d55 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -57,7 +57,7 @@ class JsonResourceTests(unittest.TestCase):
def _callback(request, **kwargs):
got_kwargs.update(kwargs)
- return (200, kwargs)
+ return 200, kwargs
res = JsonResource(self.homeserver)
res.register_paths(
diff --git a/tests/test_state.py b/tests/test_state.py
index 6d33566f47..610ec9fb46 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -106,7 +106,7 @@ class StateGroupStore(object):
}
def get_state_group_delta(self, name):
- return (None, None)
+ return None, None
def register_events(self, events):
for e in events:
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 52739fbabc..5ec5d2b358 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -28,6 +28,21 @@ from tests import unittest
class TermsTestCase(unittest.HomeserverTestCase):
servlets = [register_servlets]
+ def default_config(self, name="test"):
+ config = super().default_config(name)
+ config.update(
+ {
+ "public_baseurl": "https://example.org/",
+ "user_consent": {
+ "version": "1.0",
+ "policy_name": "My Cool Privacy Policy",
+ "template_dir": "/",
+ "require_at_registration": True,
+ },
+ }
+ )
+ return config
+
def prepare(self, reactor, clock, hs):
self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
@@ -35,17 +50,8 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.registration_handler = Mock()
self.auth_handler = Mock()
self.device_handler = Mock()
- hs.config.enable_registration = True
- hs.config.registrations_require_3pid = []
- hs.config.auto_join_rooms = []
- hs.config.enable_registration_captcha = False
def test_ui_auth(self):
- self.hs.config.user_consent_at_registration = True
- self.hs.config.user_consent_policy_name = "My Cool Privacy Policy"
- self.hs.config.public_baseurl = "https://example.org/"
- self.hs.config.user_consent_version = "1.0"
-
# Do a UI auth request
request, channel = self.make_request(b"POST", self.url, b"{}")
self.render(request)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index e0605dac2f..18f1a0035d 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -74,7 +74,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["a"], "b")
- @tests.unittest.DEBUG
@defer.inlineCallbacks
def test_erased_user(self):
# 4 message events, from erased and unerased users, with a membership
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index c94cbb662b..816795c136 100644
--- a/tests/util/caches/test_ttlcache.py
+++ b/tests/util/caches/test_ttlcache.py
@@ -36,7 +36,7 @@ class CacheTestCase(unittest.TestCase):
self.assertTrue("one" in self.cache)
self.assertEqual(self.cache.get("one"), "1")
self.assertEqual(self.cache["one"], "1")
- self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110))
+ self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110, 10))
self.assertEqual(self.cache._metrics.hits, 3)
self.assertEqual(self.cache._metrics.misses, 0)
@@ -77,7 +77,7 @@ class CacheTestCase(unittest.TestCase):
self.assertEqual(self.cache["two"], "2")
self.assertEqual(self.cache["three"], "3")
- self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120))
+ self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120, 20))
self.assertEqual(self.cache._metrics.hits, 5)
self.assertEqual(self.cache._metrics.misses, 0)
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
new file mode 100644
index 0000000000..9e348694ad
--- /dev/null
+++ b/tests/util/test_retryutils.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.util.retryutils import (
+ MIN_RETRY_INTERVAL,
+ RETRY_MULTIPLIER,
+ NotRetryingDestination,
+ get_retry_limiter,
+)
+
+from tests.unittest import HomeserverTestCase
+
+
+class RetryLimiterTestCase(HomeserverTestCase):
+ def test_new_destination(self):
+ """A happy-path case with a new destination and a successful operation"""
+ store = self.hs.get_datastore()
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ limiter = self.successResultOf(d)
+
+ # advance the clock a bit before making the request
+ self.pump(1)
+
+ with limiter:
+ pass
+
+ d = store.get_destination_retry_timings("test_dest")
+ self.pump()
+ new_timings = self.successResultOf(d)
+ self.assertIsNone(new_timings)
+
+ def test_limiter(self):
+ """General test case which walks through the process of a failing request"""
+ store = self.hs.get_datastore()
+
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ limiter = self.successResultOf(d)
+
+ self.pump(1)
+ try:
+ with limiter:
+ self.pump(1)
+ failure_ts = self.clock.time_msec()
+ raise AssertionError("argh")
+ except AssertionError:
+ pass
+
+ # wait for the update to land
+ self.pump()
+
+ d = store.get_destination_retry_timings("test_dest")
+ self.pump()
+ new_timings = self.successResultOf(d)
+ self.assertEqual(new_timings["failure_ts"], failure_ts)
+ self.assertEqual(new_timings["retry_last_ts"], failure_ts)
+ self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
+
+ # now if we try again we should get a failure
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ self.failureResultOf(d, NotRetryingDestination)
+
+ #
+ # advance the clock and try again
+ #
+
+ self.pump(MIN_RETRY_INTERVAL)
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ limiter = self.successResultOf(d)
+
+ self.pump(1)
+ try:
+ with limiter:
+ self.pump(1)
+ retry_ts = self.clock.time_msec()
+ raise AssertionError("argh")
+ except AssertionError:
+ pass
+
+ # wait for the update to land
+ self.pump()
+
+ d = store.get_destination_retry_timings("test_dest")
+ self.pump()
+ new_timings = self.successResultOf(d)
+ self.assertEqual(new_timings["failure_ts"], failure_ts)
+ self.assertEqual(new_timings["retry_last_ts"], retry_ts)
+ self.assertGreaterEqual(
+ new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5
+ )
+ self.assertLessEqual(
+ new_timings["retry_interval"], MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0
+ )
+
+ #
+ # one more go, with success
+ #
+ self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
+ d = get_retry_limiter("test_dest", self.clock, store)
+ self.pump()
+ limiter = self.successResultOf(d)
+
+ self.pump(1)
+ with limiter:
+ self.pump(1)
+
+ # wait for the update to land
+ self.pump()
+
+ d = store.get_destination_retry_timings("test_dest")
+ self.pump()
+ new_timings = self.successResultOf(d)
+ self.assertIsNone(new_timings)
diff --git a/tests/utils.py b/tests/utils.py
index f1eb9a545c..46ef2959f2 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -464,7 +464,7 @@ class MockHttpResource(HttpServer):
args = [urlparse.unquote(u) for u in matcher.groups()]
(code, response) = yield func(mock_request, *args)
- return (code, response)
+ return code, response
except CodeMessageException as e:
return (e.code, cs_error(e.msg, code=e.errcode))
|