diff --git a/changelog.d/8565.misc b/changelog.d/8565.misc
new file mode 100644
index 0000000000..7bef422618
--- /dev/null
+++ b/changelog.d/8565.misc
@@ -0,0 +1 @@
+Simplify the way the `HomeServer` object caches its internal attributes.
diff --git a/changelog.d/8820.feature b/changelog.d/8820.feature
new file mode 100644
index 0000000000..9e35861b11
--- /dev/null
+++ b/changelog.d/8820.feature
@@ -0,0 +1 @@
+Add a config option, `push.group_by_unread_count`, which controls whether unread message counts in push notifications are defined as "the number of rooms with unread messages" or "total unread messages".
diff --git a/changelog.d/8845.misc b/changelog.d/8845.misc
new file mode 100644
index 0000000000..7db1c31520
--- /dev/null
+++ b/changelog.d/8845.misc
@@ -0,0 +1 @@
+Drop redundant database index on `event_json`.
diff --git a/changelog.d/8847.misc b/changelog.d/8847.misc
new file mode 100644
index 0000000000..5028997b04
--- /dev/null
+++ b/changelog.d/8847.misc
@@ -0,0 +1 @@
+Simplify `uk.half-shot.msc2778.login.application_service` login handler.
diff --git a/changelog.d/8851.misc b/changelog.d/8851.misc
new file mode 100644
index 0000000000..7bef422618
--- /dev/null
+++ b/changelog.d/8851.misc
@@ -0,0 +1 @@
+Simplify the way the `HomeServer` object caches its internal attributes.
diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md
index 7d98d9f255..d2cdb9b2f4 100644
--- a/docs/password_auth_providers.md
+++ b/docs/password_auth_providers.md
@@ -26,6 +26,7 @@ Password auth provider classes must provide the following methods:
It should perform any appropriate sanity checks on the provided
configuration, and return an object which is then passed into
+ `__init__`.
This method should have the `@staticmethod` decoration.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 11267a77ba..c84a61e539 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -2449,6 +2449,16 @@ push:
#
#include_content: false
+ # When a push notification is received, an unread count is also sent.
+ # This number can either be calculated as the number of unread messages
+ # for the user, or the number of *rooms* the user has unread messages in.
+ #
+ # The default value is "true", meaning push clients will see the number of
+ # rooms with unread messages in them. Uncomment to instead send the number
+ # of unread messages.
+ #
+ #group_unread_count_by_room: false
+
# Spam checkers are third-party modules that can block specific actions
# of local users, such as creating rooms and registering undesirable
diff --git a/synapse/config/push.py b/synapse/config/push.py
index a71baac89c..3adbfb73e6 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -23,6 +23,9 @@ class PushConfig(Config):
def read_config(self, config, **kwargs):
push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
+ self.push_group_unread_count_by_room = push_config.get(
+ "group_unread_count_by_room", True
+ )
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
@@ -68,4 +71,14 @@ class PushConfig(Config):
# include the event ID and room ID in push notification payloads.
#
#include_content: false
+
+ # When a push notification is received, an unread count is also sent.
+ # This number can either be calculated as the number of unread messages
+ # for the user, or the number of *rooms* the user has unread messages in.
+ #
+ # The default value is "true", meaning push clients will see the number of
+ # rooms with unread messages in them. Uncomment to instead send the number
+ # of unread messages.
+ #
+ #group_unread_count_by_room: false
"""
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index c5ff072d5a..b8a6b1d491 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -380,7 +380,8 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "An error was encountered when sending the email")
token_expires = (
- self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
+ self.hs.get_clock().time_msec()
+ + self.hs.config.email_validation_token_lifetime
)
await self.store.start_or_continue_validation_session(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 793d0db2d9..eff0975b6a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -75,6 +75,7 @@ class HttpPusher:
self.failing_since = pusherdict["failing_since"]
self.timed_call = None
self._is_processing = False
+ self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
@@ -136,7 +137,11 @@ class HttpPusher:
async def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
- badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
+ badge = await push_tools.get_badge_count(
+ self.hs.get_datastore(),
+ self.user_id,
+ group_by_room=self._group_unread_count_by_room,
+ )
await self._send_badge(badge)
def on_timer(self):
@@ -283,7 +288,11 @@ class HttpPusher:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
- badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
+ badge = await push_tools.get_badge_count(
+ self.hs.get_datastore(),
+ self.user_id,
+ group_by_room=self._group_unread_count_by_room,
+ )
event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..6e7c880dc0 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
+from synapse.storage.databases.main import DataStore
-async def get_badge_count(store, user_id):
+async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
@@ -34,9 +34,15 @@ async def get_badge_count(store, user_id):
room_id, user_id, last_unread_event_id
)
)
- # return one badge count per conversation, as count per
- # message is so noisy as to be almost useless
- badge += 1 if notifs["notify_count"] else 0
+ if notifs["notify_count"] == 0:
+ continue
+
+ if group_by_room:
+ # return one badge count per conversation
+ badge += 1
+ else:
+ # increment the badge count by the number of unread messages in the room
+ badge += notifs["notify_count"]
return badge
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 94452fcbf5..074bdd66c9 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -154,13 +154,28 @@ class LoginRestServlet(RestServlet):
async def _do_appservice_login(
self, login_submission: JsonDict, appservice: ApplicationService
):
- logger.info(
- "Got appservice login request with identifier: %r",
- login_submission.get("identifier"),
- )
+ identifier = login_submission.get("identifier")
+ logger.info("Got appservice login request with identifier: %r", identifier)
- identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
- qualified_user_id = self._get_qualified_user_id(identifier)
+ if not isinstance(identifier, dict):
+ raise SynapseError(
+ 400, "Invalid identifier in login submission", Codes.INVALID_PARAM
+ )
+
+ # this login flow only supports identifiers of type "m.id.user".
+ if identifier.get("type") != "m.id.user":
+ raise SynapseError(
+ 400, "Unknown login identifier type", Codes.INVALID_PARAM
+ )
+
+ user = identifier.get("user")
+ if not isinstance(user, str):
+ raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
+
+ if user.startswith("@"):
+ qualified_user_id = user
+ else:
+ qualified_user_id = UserID(user, self.hs.hostname).to_string()
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index f5cd397493..8e4a01a7e8 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -117,7 +117,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -417,7 +417,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -496,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 27770393b2..36807458de 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -137,7 +137,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -218,7 +218,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index c16280f668..d8e8e48c1c 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -66,7 +66,7 @@ class LocalKey(Resource):
def __init__(self, hs):
self.config = hs.config
- self.clock = hs.clock
+ self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
diff --git a/synapse/server.py b/synapse/server.py
index c82d8f9fad..b017e3489f 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -147,7 +147,8 @@ def cache_in_self(builder: T) -> T:
"@cache_in_self can only be used on functions starting with `get_`"
)
- depname = builder.__name__[len("get_") :]
+ # get_attr -> _attr
+ depname = builder.__name__[len("get") :]
building = [False]
@@ -235,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta):
self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master"
- self.clock = Clock(reactor)
- self.distributor = Distributor()
-
- self.registration_ratelimiter = Ratelimiter(
- clock=self.clock,
- rate_hz=config.rc_registration.per_second,
- burst_count=config.rc_registration.burst_count,
- )
-
self.version_string = version_string
self.datastores = None # type: Optional[Databases]
@@ -301,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta):
def is_mine_id(self, string: str) -> bool:
return string.split(":", 1)[1] == self.hostname
+ @cache_in_self
def get_clock(self) -> Clock:
- return self.clock
+ return Clock(self._reactor)
def get_datastore(self) -> DataStore:
if not self.datastores:
@@ -319,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_config(self) -> HomeServerConfig:
return self.config
+ @cache_in_self
def get_distributor(self) -> Distributor:
- return self.distributor
+ return Distributor()
+ @cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter:
- return self.registration_ratelimiter
+ return Ratelimiter(
+ clock=self.get_clock(),
+ rate_hz=self.config.rc_registration.per_second,
+ burst_count=self.config.rc_registration.burst_count,
+ )
@cache_in_self
def get_federation_client(self) -> FederationClient:
@@ -687,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_federation_ratelimiter(self) -> FederationRateLimiter:
- return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+ return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
@cache_in_self
def get_module_api(self) -> ModuleApi:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc6717b3..5d668aadb2 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
for table in (
"event_auth",
"event_edges",
+ "event_json",
"event_push_actions_staging",
"event_reference_hashes",
"event_relations",
@@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
- "event_json",
"event_push_actions",
"event_search",
"events",
diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
new file mode 100644
index 0000000000..8a39d54aed
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this index is essentially redundant. The only time it was ever used was when purging
+-- rooms - and Synapse 1.24 will change that.
+
+DROP INDEX IF EXISTS event_json_room_id;
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b5055e018c..e24ce81284 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
- self.hs.clock.now = 5000
+ self.hs.get_clock().now = 5000
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
- self.hs.clock.now = 1000
+ self.hs.get_clock().now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield defer.ensureDeferred(
@@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
- self.hs.clock.now = 6000
+ self.hs.get_clock().now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 826cebbf0c..cd46c485dd 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from mock import Mock
from twisted.internet.defer import Deferred
@@ -20,8 +19,9 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import receipts
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class HTTPPusherTests(HomeserverTestCase):
@@ -29,6 +29,7 @@ class HTTPPusherTests(HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
+ receipts.register_servlets,
]
user_id = True
hijack_auth = False
@@ -499,3 +500,161 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is high-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
+
+ def test_push_unread_count_group_by_room(self):
+ """
+ The HTTP pusher will group unread count by number of unread rooms.
+ """
+ # Carry out common push count tests and setup
+ self._test_push_unread_count()
+
+ # Carry out our option-value specific test
+ #
+ # This push should still only contain an unread count of 1 (for 1 unread room)
+ self.assertEqual(
+ self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
+ )
+
+ @override_config({"push": {"group_unread_count_by_room": False}})
+ def test_push_unread_count_message_count(self):
+ """
+ The HTTP pusher will send the total unread message count.
+ """
+ # Carry out common push count tests and setup
+ self._test_push_unread_count()
+
+ # Carry out our option-value specific test
+ #
+ # We're counting every unread message, so there should now be 4 since the
+ # last read receipt
+ self.assertEqual(
+ self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
+ )
+
+ def _test_push_unread_count(self):
+ """
+ Tests that the correct unread count appears in sent push notifications
+
+ Note that:
+ * Sending messages will cause push notifications to go out to relevant users
+ * Sending a read receipt will cause a "badge update" notification to go out to
+ the user that sent the receipt
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("other_user", "pass")
+ other_access_token = self.login("other_user", "pass")
+
+ # Create a room (as other_user)
+ room_id = self.helper.create_room_as(other_user_id, tok=other_access_token)
+
+ # The user to get notified joins
+ self.helper.join(room=room_id, user=user_id, tok=access_token)
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple.token_id
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Send a message
+ response = self.helper.send(
+ room_id, body="Hello there!", tok=other_access_token
+ )
+ # To get an unread count, the user who is getting notified has to have a read
+ # position in the room. We'll set the read position to this event in a moment
+ first_message_event_id = response["event_id"]
+
+ # Advance time a bit (so the pusher will register something has happened) and
+ # make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # Check our push made it
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(self.push_attempts[0][1], "example.com")
+
+ # Check that the unread count for the room is 0
+ #
+ # The unread count is zero as the user has no read receipt in the room yet
+ self.assertEqual(
+ self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
+ )
+
+ # Now set the user's read receipt position to the first event
+ #
+ # This will actually trigger a new notification to be sent out so that
+ # even if the user does not receive another message, their unread
+ # count goes down
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
+ {},
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Advance time and make the push succeed
+ self.push_attempts[1][0].callback({})
+ self.pump()
+
+ # Unread count is still zero as we've read the only message in the room
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(
+ self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
+ )
+
+ # Send another message
+ self.helper.send(
+ room_id, body="How's the weather today?", tok=other_access_token
+ )
+
+ # Advance time and make the push succeed
+ self.push_attempts[2][0].callback({})
+ self.pump()
+
+ # This push should contain an unread count of 1 as there's now been one
+ # message since our last read receipt
+ self.assertEqual(len(self.push_attempts), 3)
+ self.assertEqual(
+ self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
+ )
+
+ # Since we're grouping by room, sending more messages shouldn't increase the
+ # unread count, as they're all being sent in the same room
+ self.helper.send(room_id, body="Hello?", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[3][0].callback({})
+
+ self.helper.send(room_id, body="Hello??", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[4][0].callback({})
+
+ self.helper.send(room_id, body="HELLO???", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[5][0].callback({})
+
+ self.assertEqual(len(self.push_attempts), 6)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 516db4c30a..295c5d58a6 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
- self.worker_hs.replication_data_handler = self.test_handler
+ self.worker_hs._replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 2a65ab33bd..dadf9db660 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -192,7 +192,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_device_handler()
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
- self.clock = hs.clock
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index b84f86d28c..5d5c24d01c 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -33,13 +33,16 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ presence_handler = Mock()
+ presence_handler.set_state.return_value = defer.succeed(None)
+
hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
+ "red",
+ http_client=None,
+ federation_client=Mock(),
+ presence_handler=presence_handler,
)
- hs.presence_handler = Mock()
- hs.presence_handler.set_state.return_value = defer.succeed(None)
-
return hs
def test_put_presence(self):
@@ -55,7 +58,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200)
- self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
+ self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
def test_put_presence_disabled(self):
"""
@@ -70,4 +73,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200)
- self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)
+ self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2272caa048..b53b47a0cc 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -769,7 +769,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do
# nothing.
- now = self.hs.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@@ -787,7 +787,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# We need to manually add an email address otherwise the handler will do
# nothing.
- now = self.hs.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@@ -849,7 +849,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
"""
user_id = self.register_user("kermit_delta", "user")
- now_ms = self.hs.clock.time_msec()
+ now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
diff --git a/tests/utils.py b/tests/utils.py
index 6e7a6fd3cf..1584eacb12 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -273,7 +273,7 @@ def setup_test_homeserver(
# Install @cache_in_self attributes
for key, val in kwargs.items():
- setattr(hs, key, val)
+ setattr(hs, "_" + key, val)
# Mock TLS
hs.tls_server_context_factory = Mock()
|