From 49af4020190eae6b0c65897d96cd2be286364d2b Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Mon, 9 Jul 2018 16:09:20 +1000 Subject: run isort --- tests/handlers/test_appservice.py | 8 +++++--- tests/handlers/test_auth.py | 2 ++ tests/handlers/test_device.py | 2 +- tests/handlers/test_directory.py | 6 +++--- tests/handlers/test_e2e_keys.py | 5 +++-- tests/handlers/test_presence.py | 12 ++++++++---- tests/handlers/test_profile.py | 6 +++--- tests/handlers/test_register.py | 5 +++-- tests/handlers/test_typing.py | 19 ++++++++++++------- 9 files changed, 40 insertions(+), 25 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index b753455943..57c0771cf3 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer -from .. import unittest -from tests.utils import MockClock from synapse.handlers.appservice import ApplicationServicesHandler -from mock import Mock +from tests.utils import MockClock + +from .. import unittest class AppServiceHandlerTestCase(unittest.TestCase): diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 1822dcf1e0..2e5e8e4dec 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -14,11 +14,13 @@ # limitations under the License. import pymacaroons + from twisted.internet import defer import synapse import synapse.api.errors from synapse.handlers.auth import AuthHandler + from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 778ff2f6e9..633a0b7f36 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -17,8 +17,8 @@ from twisted.internet import defer import synapse.api.errors import synapse.handlers.device - import synapse.storage + from tests import unittest, utils user1 = "@boris:aaa" diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 7e5332e272..a353070316 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -14,14 +14,14 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock +from twisted.internet import defer + from synapse.handlers.directory import DirectoryHandler from synapse.types import RoomAlias +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index d1bd87b898..ca1542236d 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,13 +14,14 @@ # limitations under the License. import mock -from synapse.api import errors + from twisted.internet import defer import synapse.api.errors import synapse.handlers.e2e_keys - import synapse.storage +from synapse.api import errors + from tests import unittest, utils diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index de06a6ad30..121ce78634 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -14,18 +14,22 @@ # limitations under the License. -from tests import unittest - from mock import Mock, call from synapse.api.constants import PresenceState from synapse.handlers.presence import ( - handle_update, handle_timeout, - IDLE_TIMER, SYNC_ONLINE_TIMEOUT, LAST_ACTIVE_GRANULARITY, FEDERATION_TIMEOUT, FEDERATION_PING_INTERVAL, + FEDERATION_TIMEOUT, + IDLE_TIMER, + LAST_ACTIVE_GRANULARITY, + SYNC_ONLINE_TIMEOUT, + handle_timeout, + handle_update, ) from synapse.storage.presence import UserPresenceState +from tests import unittest + class PresenceUpdateTestCase(unittest.TestCase): def test_offline_to_online(self): diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 458296ee4c..dc17918a3d 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock, NonCallableMock +from twisted.internet import defer + import synapse.types from synapse.api.errors import AuthError from synapse.handlers.profile import ProfileHandler from synapse.types import UserID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e990e45220..025fa1be81 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer -from .. import unittest from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester from tests.utils import setup_test_homeserver -from mock import Mock +from .. import unittest class RegistrationHandlers(object): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index a433bbfa8a..b08856f763 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -14,19 +14,24 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - -from mock import Mock, call, ANY import json -from ..utils import ( - MockHttpResource, MockClock, DeferredMockCallable, setup_test_homeserver -) +from mock import ANY, Mock, call + +from twisted.internet import defer from synapse.api.errors import AuthError from synapse.types import UserID +from tests import unittest + +from ..utils import ( + DeferredMockCallable, + MockClock, + MockHttpResource, + setup_test_homeserver, +) + def _expect_edu(destination, edu_type, content, origin="test"): return { -- cgit 1.4.1 From 251e6c1210087069a6133140519de80a4ddf218a Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 30 Jul 2018 15:55:57 +0100 Subject: limit register and sign in on number of monthly users --- synapse/api/errors.py | 1 + synapse/config/server.py | 5 +++++ synapse/handlers/auth.py | 13 +++++++++++ synapse/handlers/register.py | 18 +++++++++++++-- synapse/storage/__init__.py | 34 ++++++++++++++++++++++++++++ tests/handlers/test_auth.py | 49 ++++++++++++++++++++++++++++++++++++++++- tests/handlers/test_register.py | 49 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 166 insertions(+), 3 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 6074df292f..14f5540280 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -55,6 +55,7 @@ class Codes(object): SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" + MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED" class CodeMessageException(RuntimeError): diff --git a/synapse/config/server.py b/synapse/config/server.py index 18102656b0..8b335bff3f 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -67,6 +67,11 @@ class ServerConfig(Config): "block_non_admin_invites", False, ) + # Options to control access by tracking MAU + self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) + self.max_mau_value = config.get( + "max_mau_value", 0, + ) # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 402e44cdef..f3734f11bd 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -519,6 +519,7 @@ class AuthHandler(BaseHandler): """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) + self._check_mau_limits() # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -729,6 +730,7 @@ class AuthHandler(BaseHandler): defer.returnValue(access_token) def validate_short_term_login_token_and_get_user_id(self, login_token): + self._check_mau_limits() auth_api = self.hs.get_auth() try: macaroon = pymacaroons.Macaroon.deserialize(login_token) @@ -892,6 +894,17 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) + def _check_mau_limits(self): + """ + Ensure that if mau blocking is enabled that invalid users cannot + log in. + """ + if self.hs.config.limit_usage_by_mau is True: + current_mau = self.store.count_monthly_users() + if current_mau >= self.hs.config.max_mau_value: + raise AuthError( + 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) @attr.s class MacaroonGenerator(object): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7caff0cbc8..f46b8355c0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler): hs (synapse.server.HomeServer): """ super(RegistrationHandler, self).__init__(hs) - + self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() @@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ + self._check_mau_limits() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) + self._check_mau_limits() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - + self._check_mau_limits() need_register = True try: @@ -531,3 +533,15 @@ class RegistrationHandler(BaseHandler): remote_room_hosts=remote_room_hosts, action="join", ) + + def _check_mau_limits(self): + """ + Do not accept registrations if monthly active user limits exceeded + and limiting is enabled + """ + if self.hs.config.limit_usage_by_mau is True: + current_mau = self.store.count_monthly_users() + if current_mau >= self.hs.config.max_mau_value: + raise RegistrationError( + 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED + ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index ba88a54979..6a75bf0e52 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -19,6 +19,7 @@ import logging import time from dateutil import tz +from prometheus_client import Gauge from synapse.api.constants import PresenceState from synapse.storage.devices import DeviceStore @@ -60,6 +61,13 @@ from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerat logger = logging.getLogger(__name__) +# Gauges to expose monthly active user control metrics +current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU") +max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit") +limit_usage_by_mau_gauge = Gauge( + "synapse_admin_limit_usage_by_mau", "MAU Limiting enabled" +) + class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, @@ -266,6 +274,32 @@ class DataStore(RoomMemberStore, RoomStore, return self.runInteraction("count_users", _count_users) + def count_monthly_users(self): + """ + Counts the number of users who used this homeserver in the last 30 days + This method should be refactored with count_daily_users - the only + reason not to is waiting on definition of mau + returns: + int: count of current monthly active users + """ + def _count_monthly_users(txn): + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + sql = """ + SELECT COUNT(*) FROM user_ips + WHERE last_seen > ? + """ + txn.execute(sql, (thirty_days_ago,)) + count, = txn.fetchone() + + self._current_mau = count + current_mau_gauge.set(self._current_mau) + max_mau_value_gauge.set(self.hs.config.max_mau_value) + limit_usage_by_mau_gauge.set(self.hs.config.limit_usage_by_mau) + logger.info("calling mau stats") + return count + return self.runInteraction("count_monthly_users", _count_monthly_users) + + def count_r30_users(self): """ Counts the number of 30 day retained users, defined as:- diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 2e5e8e4dec..57f78a6bec 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -12,15 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from mock import Mock import pymacaroons from twisted.internet import defer import synapse +from synapse.api.errors import AuthError import synapse.api.errors from synapse.handlers.auth import AuthHandler + from tests import unittest from tests.utils import setup_test_homeserver @@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() + # MAU tests + self.hs.config.max_mau_value = 50 + self.small_number_of_users = 1 + self.large_number_of_users = 100 def test_token_is_a_macaroon(self): token = self.macaroon_generator.generate_access_token("some_user") @@ -113,3 +119,44 @@ class AuthTestCase(unittest.TestCase): self.auth_handler.validate_short_term_login_token_and_get_user_id( macaroon.serialize() ) + + @defer.inlineCallbacks + def test_mau_limits_disabled(self): + self.hs.config.limit_usage_by_mau = False + # Ensure does not throw exception + yield self.auth_handler.get_access_token_for_user_id('user_a') + + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + + @defer.inlineCallbacks + def test_mau_limits_exceeded(self): + self.hs.config.limit_usage_by_mau = True + self.hs.get_datastore().count_monthly_users = Mock( + return_value=self.large_number_of_users + ) + with self.assertRaises(AuthError): + yield self.auth_handler.get_access_token_for_user_id('user_a') + with self.assertRaises(AuthError): + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + + @defer.inlineCallbacks + def test_mau_limits_not_exceeded(self): + self.hs.config.limit_usage_by_mau = True + self.hs.get_datastore().count_monthly_users = Mock( + return_value=self.small_number_of_users + ) + # Ensure does not raise exception + yield self.auth_handler.get_access_token_for_user_id('user_a') + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + + def _get_macaroon(self): + token = self.macaroon_generator.generate_short_term_login_token( + "user_a", 5000 + ) + return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 025fa1be81..a5a8e7c954 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,6 +17,7 @@ from mock import Mock from twisted.internet import defer +from synapse.api.errors import RegistrationError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -77,3 +78,51 @@ class RegistrationTestCase(unittest.TestCase): requester, local_part, display_name) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') + + @defer.inlineCallbacks + def test_cannot_register_when_mau_limits_exceeded(self): + local_part = "someone" + display_name = "someone" + requester = create_requester("@as:test") + store = self.hs.get_datastore() + self.hs.config.limit_usage_by_mau = False + self.hs.config.max_mau_value = 50 + lots_of_users = 100 + small_number_users = 1 + + store.count_monthly_users = Mock(return_value=lots_of_users) + + # Ensure does not throw exception + yield self.handler.get_or_create_user(requester, 'a', display_name) + + self.hs.config.limit_usage_by_mau = True + + with self.assertRaises(RegistrationError): + yield self.handler.get_or_create_user(requester, 'b', display_name) + + store.count_monthly_users = Mock(return_value=small_number_users) + + self._macaroon_mock_generator("another_secret") + + # Ensure does not throw exception + yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil") + + self._macaroon_mock_generator("another another secret") + store.count_monthly_users = Mock(return_value=lots_of_users) + with self.assertRaises(RegistrationError): + yield self.handler.register(localpart=local_part) + + self._macaroon_mock_generator("another another secret") + store.count_monthly_users = Mock(return_value=lots_of_users) + with self.assertRaises(RegistrationError): + yield self.handler.register_saml2(local_part) + + def _macaroon_mock_generator(self, secret): + """ + Reset macaroon generator in the case where the test creates multiple users + """ + macaroon_generator = Mock( + generate_access_token=Mock(return_value=secret)) + self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator) + self.hs.handlers = RegistrationHandlers(self.hs) + self.handler = self.hs.get_handlers().registration_handler -- cgit 1.4.1 From e908b8683248328dd92479fc81be350336b9c8f4 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Mon, 30 Jul 2018 16:24:02 -0600 Subject: Remove pdu_failures from transactions The field is never read from, and all the opportunities given to populate it are not utilized. It should be very safe to remove this. --- changelog.d/3628.misc | 1 + synapse/federation/federation_server.py | 4 --- synapse/federation/send_queue.py | 63 +-------------------------------- synapse/federation/transaction_queue.py | 32 +++-------------- synapse/federation/transport/server.py | 3 +- synapse/federation/units.py | 1 - tests/handlers/test_typing.py | 1 - 7 files changed, 8 insertions(+), 97 deletions(-) create mode 100644 changelog.d/3628.misc (limited to 'tests/handlers') diff --git a/changelog.d/3628.misc b/changelog.d/3628.misc new file mode 100644 index 0000000000..1aebefbe18 --- /dev/null +++ b/changelog.d/3628.misc @@ -0,0 +1 @@ +Remove unused field "pdu_failures" from transactions. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e501251b6e..657935d1ac 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -207,10 +207,6 @@ class FederationServer(FederationBase): edu.content ) - pdu_failures = getattr(transaction, "pdu_failures", []) - for fail in pdu_failures: - logger.info("Got failure %r", fail) - response = { "pdus": pdu_results, } diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 5157c3860d..0bb468385d 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object): self.edus = SortedDict() # stream position -> Edu - self.failures = SortedDict() # stream position -> (destination, Failure) - self.device_messages = SortedDict() # stream position -> destination self.pos = 1 @@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object): for queue_name in [ "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "failures", "device_messages", "pos_time", + "edus", "device_messages", "pos_time", ]: register(queue_name, getattr(self, queue_name)) @@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object): for key in keys[:i]: del self.edus[key] - # Delete things out of failure map - keys = self.failures.keys() - i = self.failures.bisect_left(position_to_delete) - for key in keys[:i]: - del self.failures[key] - # Delete things out of device map keys = self.device_messages.keys() i = self.device_messages.bisect_left(position_to_delete) @@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object): self.notifier.on_new_replication_data() - def send_failure(self, failure, destination): - """As per TransactionQueue""" - pos = self._next_pos() - - self.failures[pos] = (destination, str(failure)) - self.notifier.on_new_replication_data() - def send_device_messages(self, destination): """As per TransactionQueue""" pos = self._next_pos() @@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object): for (pos, edu) in edus: rows.append((pos, EduRow(edu))) - # Fetch changed failures - i = self.failures.bisect_right(from_token) - j = self.failures.bisect_right(to_token) + 1 - failures = self.failures.items()[i:j] - - for (pos, (destination, failure)) in failures: - rows.append((pos, FailureRow( - destination=destination, - failure=failure, - ))) - # Fetch changed device messages i = self.device_messages.bisect_right(from_token) j = self.device_messages.bisect_right(to_token) + 1 @@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ( buff.edus.setdefault(self.edu.destination, []).append(self.edu) -class FailureRow(BaseFederationRow, namedtuple("FailureRow", ( - "destination", # str - "failure", -))): - """Streams failures to a remote server. Failures are issued when there was - something wrong with a transaction the remote sent us, e.g. it included - an event that was invalid. - """ - - TypeId = "f" - - @staticmethod - def from_data(data): - return FailureRow( - destination=data["destination"], - failure=data["failure"], - ) - - def to_data(self): - return { - "destination": self.destination, - "failure": self.failure, - } - - def add_to_buffer(self, buff): - buff.failures.setdefault(self.destination, []).append(self.failure) - - class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( "destination", # str ))): @@ -471,7 +417,6 @@ TypeToRow = { PresenceRow, KeyedEduRow, EduRow, - FailureRow, DeviceRow, ) } @@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( "presence", # list(UserPresenceState) "keyed_edus", # dict of destination -> { key -> Edu } "edus", # dict of destination -> [Edu] - "failures", # dict of destination -> [failures] "device_destinations", # set of destinations )) @@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows): presence=[], keyed_edus={}, edus={}, - failures={}, device_destinations=set(), ) @@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows): edu.destination, edu.edu_type, edu.content, key=None, ) - for destination, failure_list in iteritems(buff.failures): - for failure in failure_list: - transaction_queue.send_failure(destination, failure) - for destination in buff.device_destinations: transaction_queue.send_device_messages(destination) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 6996d6b695..78f9d40a3a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -116,9 +116,6 @@ class TransactionQueue(object): ), ) - # destination -> list of tuple(failure, deferred) - self.pending_failures_by_dest = {} - # destination -> stream_id of last successfully sent to-device message. # NB: may be a long or an int. self.last_device_stream_id_by_dest = {} @@ -382,19 +379,6 @@ class TransactionQueue(object): self._attempt_new_transaction(destination) - def send_failure(self, failure, destination): - if destination == self.server_name or destination == "localhost": - return - - if not self.can_send_to(destination): - return - - self.pending_failures_by_dest.setdefault( - destination, [] - ).append(failure) - - self._attempt_new_transaction(destination) - def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": return @@ -469,7 +453,6 @@ class TransactionQueue(object): pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_presence = self.pending_presence_by_dest.pop(destination, {}) - pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_edus.extend( self.pending_edus_keyed_by_dest.pop(destination, {}).values() @@ -497,7 +480,7 @@ class TransactionQueue(object): logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", destination, len(pending_pdus)) - if not pending_pdus and not pending_edus and not pending_failures: + if not pending_pdus and not pending_edus: logger.debug("TX [%s] Nothing to send", destination) self.last_device_stream_id_by_dest[destination] = ( device_stream_id @@ -507,7 +490,7 @@ class TransactionQueue(object): # END CRITICAL SECTION success = yield self._send_new_transaction( - destination, pending_pdus, pending_edus, pending_failures, + destination, pending_pdus, pending_edus, ) if success: sent_transactions_counter.inc() @@ -584,14 +567,12 @@ class TransactionQueue(object): @measure_func("_send_new_transaction") @defer.inlineCallbacks - def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures): + def _send_new_transaction(self, destination, pending_pdus, pending_edus): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) pdus = [x[0] for x in pending_pdus] edus = pending_edus - failures = [x.get_dict() for x in pending_failures] success = True @@ -601,11 +582,10 @@ class TransactionQueue(object): logger.debug( "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", + " (pdus: %d, edus: %d)", destination, txn_id, len(pdus), len(edus), - len(failures) ) logger.debug("TX [%s] Persisting transaction...", destination) @@ -617,7 +597,6 @@ class TransactionQueue(object): destination=destination, pdus=pdus, edus=edus, - pdu_failures=failures, ) self._next_txn_id += 1 @@ -627,12 +606,11 @@ class TransactionQueue(object): logger.debug("TX [%s] Persisted transaction", destination) logger.info( "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d, failures: %d)", + " (PDUs: %d, EDUs: %d)", destination, txn_id, transaction.transaction_id, len(pdus), len(edus), - len(failures), ) # Actually send the transaction diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8574898f0c..3b5ea9515a 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet): ) logger.info( - "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)", + "Received txn %s from %s. (PDUs: %d, EDUs: %d)", transaction_id, origin, len(transaction_data.get("pdus", [])), len(transaction_data.get("edus", [])), - len(transaction_data.get("failures", [])), ) # We should ideally be getting this from the security layer. diff --git a/synapse/federation/units.py b/synapse/federation/units.py index bb1b3b13f7..c5ab14314e 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject): "previous_ids", "pdus", "edus", - "pdu_failures", ] internal_keys = [ diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index b08856f763..2c263af1a3 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -44,7 +44,6 @@ def _expect_edu(destination, edu_type, content, origin="test"): "content": content, } ], - "pdu_failures": [], } -- cgit 1.4.1 From df2235e7fab44a5155134a336a4c27424398c1be Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 31 Jul 2018 13:16:20 +0100 Subject: coding style --- synapse/app/homeserver.py | 6 +++++- synapse/config/server.py | 2 +- synapse/handlers/auth.py | 3 ++- synapse/storage/__init__.py | 3 +-- synapse/storage/schema/delta/50/make_event_content_nullable.py | 2 +- tests/handlers/test_auth.py | 4 ++-- tests/storage/test__init__.py | 1 - 7 files changed, 12 insertions(+), 9 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 96c45b7209..82979e7d1b 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -18,9 +18,10 @@ import logging import os import sys -from prometheus_client import Gauge from six import iteritems +from prometheus_client import Gauge + from twisted.application import service from twisted.internet import defer, reactor from twisted.web.resource import EncodingResourceWrapper, NoResource @@ -300,12 +301,15 @@ class SynapseHomeServer(HomeServer): except IncorrectDatabaseSetup as e: quit_with_error(e.message) + # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU") max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit") limit_usage_by_mau_gauge = Gauge( "synapse_admin_limit_usage_by_mau", "MAU Limiting enabled" ) + + def setup(config_options): """ Args: diff --git a/synapse/config/server.py b/synapse/config/server.py index 8b335bff3f..9af42a93ad 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -70,7 +70,7 @@ class ServerConfig(Config): # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = config.get( - "max_mau_value", 0, + "max_mau_value", 0, ) # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f3734f11bd..28f1c1afbb 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -903,9 +903,10 @@ class AuthHandler(BaseHandler): current_mau = self.store.count_monthly_users() if current_mau >= self.hs.config.max_mau_value: raise AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED ) + @attr.s class MacaroonGenerator(object): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 044e988e92..4747118ed7 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -60,6 +60,7 @@ from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerat logger = logging.getLogger(__name__) + class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, PresenceStore, TransactionStore, @@ -291,8 +292,6 @@ class DataStore(RoomMemberStore, RoomStore, finally: txn.close() - - def count_r30_users(self): """ Counts the number of 30 day retained users, defined as:- diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py index 7d27342e39..6dd467b6c5 100644 --- a/synapse/storage/schema/delta/50/make_event_content_nullable.py +++ b/synapse/storage/schema/delta/50/make_event_content_nullable.py @@ -88,5 +88,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs): "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", (sql, ), ) - cur.execute("PRAGMA schema_version=%i" % (oldver+1,)) + cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) cur.execute("PRAGMA writable_schema=OFF") diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 57f78a6bec..e01f14a10a 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -13,16 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from mock import Mock + import pymacaroons from twisted.internet import defer import synapse -from synapse.api.errors import AuthError import synapse.api.errors +from synapse.api.errors import AuthError from synapse.handlers.auth import AuthHandler - from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py index c9ae349871..fe6eeeaf10 100644 --- a/tests/storage/test__init__.py +++ b/tests/storage/test__init__.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import sys from twisted.internet import defer -- cgit 1.4.1 From 7931393495c76eef0af9b91c7904c88943197054 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 1 Aug 2018 10:21:56 +0100 Subject: make count_monthly_users async synapse/handlers/auth.py --- synapse/handlers/register.py | 9 +++++---- synapse/storage/__init__.py | 26 +++++++++++++------------- tests/handlers/test_auth.py | 39 ++++++++++++++++++++++----------------- tests/handlers/test_register.py | 10 ++++++---- 4 files changed, 46 insertions(+), 38 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f46b8355c0..cc935a5e84 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -144,7 +144,7 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - self._check_mau_limits() + yield self._check_mau_limits() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -289,7 +289,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) - self._check_mau_limits() + yield self._check_mau_limits() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -439,7 +439,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - self._check_mau_limits() + yield self._check_mau_limits() need_register = True try: @@ -534,13 +534,14 @@ class RegistrationHandler(BaseHandler): action="join", ) + @defer.inlineCallbacks def _check_mau_limits(self): """ Do not accept registrations if monthly active user limits exceeded and limiting is enabled """ if self.hs.config.limit_usage_by_mau is True: - current_mau = self.store.count_monthly_users() + current_mau = yield self.store.count_monthly_users() if current_mau >= self.hs.config.max_mau_value: raise RegistrationError( 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4747118ed7..f9682832ca 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -273,24 +273,24 @@ class DataStore(RoomMemberStore, RoomStore, This method should be refactored with count_daily_users - the only reason not to is waiting on definition of mau returns: - int: count of current monthly active users + defered: resolves to int """ + def _count_monthly_users(txn): + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - try: - txn = self.db_conn.cursor() txn.execute(sql, (thirty_days_ago,)) count, = txn.fetchone() + print "Count is %d" % (count,) return count - finally: - txn.close() + + return self.runInteraction("count_monthly_users", _count_monthly_users) def count_r30_users(self): """ diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index e01f14a10a..440a453082 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -77,38 +77,37 @@ class AuthTestCase(unittest.TestCase): v.satisfy_general(verify_nonce) v.verify(macaroon, self.hs.config.macaroon_secret_key) + @defer.inlineCallbacks def test_short_term_login_token_gives_user_id(self): self.hs.clock.now = 1000 token = self.macaroon_generator.generate_short_term_login_token( "a_user", 5000 ) - - self.assertEqual( - "a_user", - self.auth_handler.validate_short_term_login_token_and_get_user_id( - token - ) + user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + token ) + self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected self.hs.clock.now = 6000 with self.assertRaises(synapse.api.errors.AuthError): - self.auth_handler.validate_short_term_login_token_and_get_user_id( + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( token ) + @defer.inlineCallbacks def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( "a_user", 5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) + user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) self.assertEqual( - "a_user", - self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() - ) + "a_user", user_id ) # add another "user_id" caveat, which might allow us to override the @@ -116,7 +115,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = b_user") with self.assertRaises(synapse.api.errors.AuthError): - self.auth_handler.validate_short_term_login_token_and_get_user_id( + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( macaroon.serialize() ) @@ -126,7 +125,7 @@ class AuthTestCase(unittest.TestCase): # Ensure does not throw exception yield self.auth_handler.get_access_token_for_user_id('user_a') - self.auth_handler.validate_short_term_login_token_and_get_user_id( + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) @@ -134,24 +133,30 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_exceeded(self): self.hs.config.limit_usage_by_mau = True self.hs.get_datastore().count_monthly_users = Mock( - return_value=self.large_number_of_users + return_value=defer.succeed(self.large_number_of_users) ) + with self.assertRaises(AuthError): yield self.auth_handler.get_access_token_for_user_id('user_a') + + self.hs.get_datastore().count_monthly_users = Mock( + return_value=defer.succeed(self.large_number_of_users) + ) with self.assertRaises(AuthError): - self.auth_handler.validate_short_term_login_token_and_get_user_id( + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): self.hs.config.limit_usage_by_mau = True + self.hs.get_datastore().count_monthly_users = Mock( - return_value=self.small_number_of_users + return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception yield self.auth_handler.get_access_token_for_user_id('user_a') - self.auth_handler.validate_short_term_login_token_and_get_user_id( + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a5a8e7c954..0937d71cf6 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -90,7 +90,7 @@ class RegistrationTestCase(unittest.TestCase): lots_of_users = 100 small_number_users = 1 - store.count_monthly_users = Mock(return_value=lots_of_users) + store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) # Ensure does not throw exception yield self.handler.get_or_create_user(requester, 'a', display_name) @@ -100,7 +100,7 @@ class RegistrationTestCase(unittest.TestCase): with self.assertRaises(RegistrationError): yield self.handler.get_or_create_user(requester, 'b', display_name) - store.count_monthly_users = Mock(return_value=small_number_users) + store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users)) self._macaroon_mock_generator("another_secret") @@ -108,12 +108,14 @@ class RegistrationTestCase(unittest.TestCase): yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil") self._macaroon_mock_generator("another another secret") - store.count_monthly_users = Mock(return_value=lots_of_users) + store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + with self.assertRaises(RegistrationError): yield self.handler.register(localpart=local_part) self._macaroon_mock_generator("another another secret") - store.count_monthly_users = Mock(return_value=lots_of_users) + store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + with self.assertRaises(RegistrationError): yield self.handler.register_saml2(local_part) -- cgit 1.4.1 From 6eed16d8a2335c97675b6b8661869848f397ea29 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 1 Aug 2018 14:02:10 +0100 Subject: fix test for py3 --- tests/handlers/test_auth.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'tests/handlers') diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 440a453082..55eab9e9cf 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -156,6 +156,10 @@ class AuthTestCase(unittest.TestCase): ) # Ensure does not raise exception yield self.auth_handler.get_access_token_for_user_id('user_a') + + self.hs.get_datastore().count_monthly_users = Mock( + return_value=defer.succeed(self.small_number_of_users) + ) yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) -- cgit 1.4.1 From 74b1d46ad9ae692774f2e9d71cbbe1cea91b4070 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 2 Aug 2018 16:57:35 +0100 Subject: do mau checks based on monthly_active_users table --- synapse/api/auth.py | 13 ++++++++ synapse/handlers/auth.py | 10 +++--- synapse/handlers/register.py | 10 +++--- synapse/storage/client_ips.py | 15 +++++---- tests/api/test_auth.py | 31 +++++++++++++++++- tests/handlers/test_auth.py | 8 ++--- tests/handlers/test_register.py | 71 ++++++++++++++++++++--------------------- 7 files changed, 97 insertions(+), 61 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index d8022bcf8e..943a488339 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -773,3 +773,16 @@ class Auth(object): raise AuthError( 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN ) + + @defer.inlineCallbacks + def check_auth_blocking(self, error): + """Checks if the user should be rejected for some external reason, + such as monthly active user limiting or global disable flag + Args: + error (Error): The error that should be raised if user is to be + blocked + """ + if self.hs.config.limit_usage_by_mau is True: + current_mau = yield self.store.get_monthly_active_count() + if current_mau >= self.hs.config.max_mau_value: + raise error diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 184eef09d0..8f9cff92e8 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -913,12 +913,10 @@ class AuthHandler(BaseHandler): Ensure that if mau blocking is enabled that invalid users cannot log in. """ - if self.hs.config.limit_usage_by_mau is True: - current_mau = yield self.store.count_monthly_users() - if current_mau >= self.hs.config.max_mau_value: - raise AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) + error = AuthError( + 403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) + yield self.auth.check_auth_blocking(error) @attr.s diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 289704b241..706ed8c292 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -540,9 +540,7 @@ class RegistrationHandler(BaseHandler): Do not accept registrations if monthly active user limits exceeded and limiting is enabled """ - if self.hs.config.limit_usage_by_mau is True: - current_mau = yield self.store.count_monthly_users() - if current_mau >= self.hs.config.max_mau_value: - raise RegistrationError( - 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED - ) + error = RegistrationError( + 403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) + yield self.auth.check_auth_blocking(error) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 506915a1ef..83d64d1563 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -97,21 +97,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): @defer.inlineCallbacks def _populate_monthly_active_users(self, user_id): + """Checks on the state of monthly active user limits and optionally + add the user to the monthly active tables + + Args: + user_id(str): the user_id to query + """ + store = self.hs.get_datastore() - print "entering _populate_monthly_active_users" if self.hs.config.limit_usage_by_mau: - print "self.hs.config.limit_usage_by_mau is TRUE" is_user_monthly_active = yield store.is_user_monthly_active(user_id) - print "is_user_monthly_active is %r" % is_user_monthly_active if is_user_monthly_active: yield store.upsert_monthly_active_user(user_id) else: count = yield store.get_monthly_active_count() - print "count is %d" % count if count < self.hs.config.max_mau_value: - print "count is less than self.hs.config.max_mau_value " - res = yield store.upsert_monthly_active_user(user_id) - print "upsert response is %r" % res + yield store.upsert_monthly_active_user(user_id) def _update_client_ips_batch(self): def update(): diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index a82d737e71..54bdf28663 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -21,7 +21,7 @@ from twisted.internet import defer import synapse.handlers.auth from synapse.api.auth import Auth -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, Codes from synapse.types import UserID from tests import unittest @@ -444,3 +444,32 @@ class AuthTestCase(unittest.TestCase): self.assertEqual("Guest access token used for regular user", cm.exception.msg) self.store.get_user_by_id.assert_called_with(USER_ID) + + @defer.inlineCallbacks + def test_blocking_mau(self): + self.hs.config.limit_usage_by_mau = False + self.hs.config.max_mau_value = 50 + lots_of_users = 100 + small_number_of_users = 1 + + error = AuthError( + 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) + + # Ensure no error thrown + yield self.auth.check_auth_blocking(error) + + self.hs.config.limit_usage_by_mau = True + + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(lots_of_users) + ) + + with self.assertRaises(AuthError): + yield self.auth.check_auth_blocking(error) + + # Ensure does not throw an error + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(small_number_of_users) + ) + yield self.auth.check_auth_blocking(error) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 55eab9e9cf..8a9bf2d5fd 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -132,14 +132,14 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_exceeded(self): self.hs.config.limit_usage_by_mau = True - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(AuthError): yield self.auth_handler.get_access_token_for_user_id('user_a') - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(AuthError): @@ -151,13 +151,13 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_not_exceeded(self): self.hs.config.limit_usage_by_mau = True - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception yield self.auth_handler.get_access_token_for_user_id('user_a') - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) yield self.auth_handler.validate_short_term_login_token_and_get_user_id( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 0937d71cf6..6b5b8b3772 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -50,6 +50,10 @@ class RegistrationTestCase(unittest.TestCase): self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler + self.store = self.hs.get_datastore() + self.hs.config.max_mau_value = 50 + self.lots_of_users = 100 + self.small_number_of_users = 1 @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): @@ -80,51 +84,44 @@ class RegistrationTestCase(unittest.TestCase): self.assertEquals(result_token, 'secret') @defer.inlineCallbacks - def test_cannot_register_when_mau_limits_exceeded(self): - local_part = "someone" - display_name = "someone" - requester = create_requester("@as:test") - store = self.hs.get_datastore() + def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 - lots_of_users = 100 - small_number_users = 1 - - store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) - # Ensure does not throw exception - yield self.handler.get_or_create_user(requester, 'a', display_name) + yield self.handler.get_or_create_user("requester", 'a', "display_name") + @defer.inlineCallbacks + def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True - - with self.assertRaises(RegistrationError): - yield self.handler.get_or_create_user(requester, 'b', display_name) - - store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users)) - - self._macaroon_mock_generator("another_secret") - + self.store.count_monthly_users = Mock( + return_value=defer.succeed(self.small_number_of_users) + ) # Ensure does not throw exception - yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil") + yield self.handler.get_or_create_user("@user:server", 'c', "User") - self._macaroon_mock_generator("another another secret") - store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + @defer.inlineCallbacks + def test_get_or_create_user_mau_blocked(self): + self.hs.config.limit_usage_by_mau = True + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.lots_of_users) + ) with self.assertRaises(RegistrationError): - yield self.handler.register(localpart=local_part) + yield self.handler.get_or_create_user("requester", 'b', "display_name") - self._macaroon_mock_generator("another another secret") - store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + @defer.inlineCallbacks + def test_register_mau_blocked(self): + self.hs.config.limit_usage_by_mau = True + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.lots_of_users) + ) + with self.assertRaises(RegistrationError): + yield self.handler.register(localpart="local_part") + @defer.inlineCallbacks + def test_register_saml2_mau_blocked(self): + self.hs.config.limit_usage_by_mau = True + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.lots_of_users) + ) with self.assertRaises(RegistrationError): - yield self.handler.register_saml2(local_part) - - def _macaroon_mock_generator(self, secret): - """ - Reset macaroon generator in the case where the test creates multiple users - """ - macaroon_generator = Mock( - generate_access_token=Mock(return_value=secret)) - self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator) - self.hs.handlers = RegistrationHandlers(self.hs) - self.handler = self.hs.get_handlers().registration_handler + yield self.handler.register_saml2(localpart="local_part") -- cgit 1.4.1 From 886be75ad1bc60e016611b453b9644e8db17a9f1 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Fri, 3 Aug 2018 22:29:03 +0100 Subject: bug fixes --- synapse/handlers/auth.py | 15 ++------------- synapse/handlers/register.py | 8 ++++---- synapse/storage/monthly_active_users.py | 3 +-- tests/api/test_auth.py | 10 +++------- tests/handlers/test_register.py | 1 - 5 files changed, 10 insertions(+), 27 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8f9cff92e8..7ea8ce9f94 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -520,7 +520,7 @@ class AuthHandler(BaseHandler): """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -734,7 +734,7 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() auth_api = self.hs.get_auth() user_id = None try: @@ -907,17 +907,6 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) - @defer.inlineCallbacks - def _check_mau_limits(self): - """ - Ensure that if mau blocking is enabled that invalid users cannot - log in. - """ - error = AuthError( - 403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) - yield self.auth.check_auth_blocking(error) - @attr.s class MacaroonGenerator(object): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 706ed8c292..8cf0a36a8f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -540,7 +540,7 @@ class RegistrationHandler(BaseHandler): Do not accept registrations if monthly active user limits exceeded and limiting is enabled """ - error = RegistrationError( - 403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) - yield self.auth.check_auth_blocking(error) + try: + yield self.auth.check_auth_blocking() + except AuthError as e: + raise RegistrationError(e.code, e.message, e.errcode) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 6def6830d0..135837507a 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -54,7 +54,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ txn.execute(sql, (self.hs.config.max_mau_value,)) - res = yield self.runInteraction("reap_monthly_active_users", _reap_users) + yield self.runInteraction("reap_monthly_active_users", _reap_users) # It seems poor to invalidate the whole cache, Postgres supports # 'Returning' which would allow me to invalidate only the # specific users, but sqlite has no way to do this and instead @@ -64,7 +64,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): # something about it if and when the perf becomes significant self._user_last_seen_monthly_active.invalidate_all() self.get_monthly_active_count.invalidate_all() - return res @cached(num_args=0) def get_monthly_active_count(self): diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 54bdf28663..e963963c73 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -452,12 +452,8 @@ class AuthTestCase(unittest.TestCase): lots_of_users = 100 small_number_of_users = 1 - error = AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) - # Ensure no error thrown - yield self.auth.check_auth_blocking(error) + yield self.auth.check_auth_blocking() self.hs.config.limit_usage_by_mau = True @@ -466,10 +462,10 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(AuthError): - yield self.auth.check_auth_blocking(error) + yield self.auth.check_auth_blocking() # Ensure does not throw an error self.store.get_monthly_active_count = Mock( return_value=defer.succeed(small_number_of_users) ) - yield self.auth.check_auth_blocking(error) + yield self.auth.check_auth_blocking() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 6b5b8b3772..4ea59a58de 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -104,7 +104,6 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): yield self.handler.get_or_create_user("requester", 'b', "display_name") -- cgit 1.4.1 From e92fb00f32c63de6ea50ba1cbbadf74060ea143d Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 8 Aug 2018 17:54:49 +0100 Subject: sync auth blocking --- synapse/handlers/sync.py | 16 +++++++++++----- tests/handlers/test_sync.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 tests/handlers/test_sync.py (limited to 'tests/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dff1f67dcb..f748d9afb0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -191,6 +191,7 @@ class SyncHandler(object): self.clock = hs.get_clock() self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() + self.auth = hs.get_auth() # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( @@ -198,18 +199,23 @@ class SyncHandler(object): max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) + @defer.inlineCallbacks def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. Returns: - A Deferred SyncResult. + Deferred[SyncResult] """ - return self.response_cache.wrap( - sync_config.request_key, - self._wait_for_sync_for_user, - sync_config, since_token, timeout, full_state, + yield self.auth.check_auth_blocking() + + defer.returnValue( + self.response_cache.wrap( + sync_config.request_key, + self._wait_for_sync_for_user, + sync_config, since_token, timeout, full_state, + ) ) @defer.inlineCallbacks diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py new file mode 100644 index 0000000000..3b1b4d4923 --- /dev/null +++ b/tests/handlers/test_sync.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +import tests.unittest +import tests.utils +from tests.utils import setup_test_homeserver +from synapse.handlers.sync import SyncHandler, SyncConfig +from synapse.types import UserID + + +class SyncTestCase(tests.unittest.TestCase): + """ Tests Sync Handler. """ + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver() + self.sync_handler = SyncHandler(self.hs) + + @defer.inlineCallbacks + def test_wait_for_sync_for_user_auth_blocking(self): + sync_config = SyncConfig( + user=UserID("@user","server"), + filter_collection=None, + is_guest=False, + request_key="request_key", + device_id="device_id", + ) + res = yield self.sync_handler.wait_for_sync_for_user(sync_config) + print res -- cgit 1.4.1 From 2511f3f8a082585e680dad200069dec77b066a6a Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 9 Aug 2018 12:22:01 +1000 Subject: Test fixes for Python 3 (#3647) --- changelog.d/3647.misc | 1 + tests/handlers/test_typing.py | 4 +++- tests/rest/client/test_transactions.py | 4 ++-- tests/rest/client/v1/test_admin.py | 10 +++++----- tests/rest/client/v1/test_profile.py | 8 ++++---- tests/rest/client/v1/utils.py | 10 +++++----- tests/rest/client/v2_alpha/test_filter.py | 22 +++++++++++----------- tests/rest/client/v2_alpha/test_register.py | 14 +++++++------- tests/rest/client/v2_alpha/test_sync.py | 4 ++-- tests/server.py | 22 ++++++++++++++++++++-- tests/storage/test_event_federation.py | 2 +- tests/storage/test_state.py | 2 +- tests/test_server.py | 11 ++++------- tests/utils.py | 9 +++++---- 14 files changed, 71 insertions(+), 52 deletions(-) create mode 100644 changelog.d/3647.misc (limited to 'tests/handlers') diff --git a/changelog.d/3647.misc b/changelog.d/3647.misc new file mode 100644 index 0000000000..dbc66dae60 --- /dev/null +++ b/changelog.d/3647.misc @@ -0,0 +1 @@ +Tests now correctly execute on Python 3. diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 2c263af1a3..f422cf3c5a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -48,7 +48,9 @@ def _expect_edu(destination, edu_type, content, origin="test"): def _make_edu_json(origin, edu_type, content): - return json.dumps(_expect_edu("test", edu_type, content, origin=origin)) + return json.dumps( + _expect_edu("test", edu_type, content, origin=origin) + ).encode('utf8') class TypingNotificationsTestCase(unittest.TestCase): diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 34e68ae82f..d46c27e7e9 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -85,7 +85,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): try: yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: - self.assertEqual(e.message, "boo") + self.assertEqual(e.args[0], "boo") self.assertIs(LoggingContext.current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) @@ -111,7 +111,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): try: yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: - self.assertEqual(e.message, "boo") + self.assertEqual(e.args[0], "boo") self.assertIs(LoggingContext.current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py index 8c90145601..fb28883d30 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py @@ -140,7 +140,7 @@ class UserRegisterTestCase(unittest.TestCase): "admin": True, "mac": want_mac, } - ).encode('utf8') + ) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) @@ -168,7 +168,7 @@ class UserRegisterTestCase(unittest.TestCase): "admin": True, "mac": want_mac, } - ).encode('utf8') + ) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) @@ -195,7 +195,7 @@ class UserRegisterTestCase(unittest.TestCase): "admin": True, "mac": want_mac, } - ).encode('utf8') + ) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) @@ -253,7 +253,7 @@ class UserRegisterTestCase(unittest.TestCase): self.assertEqual('Invalid username', channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"}) + body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"}) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) @@ -289,7 +289,7 @@ class UserRegisterTestCase(unittest.TestCase): self.assertEqual('Invalid password', channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"}) + body = json.dumps({"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index d71cc8e0db..0516ce3cfb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -80,7 +80,7 @@ class ProfileTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", "/profile/%s/displayname" % (myid), - '{"displayname": "Frank Jr."}' + b'{"displayname": "Frank Jr."}' ) self.assertEquals(200, code) @@ -95,7 +95,7 @@ class ProfileTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", "/profile/%s/displayname" % ("@4567:test"), - '{"displayname": "Frank Jr."}' + b'{"displayname": "Frank Jr."}' ) self.assertTrue( @@ -122,7 +122,7 @@ class ProfileTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"), - '{"displayname":"bob"}' + b'{"displayname":"bob"}' ) self.assertTrue( @@ -151,7 +151,7 @@ class ProfileTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", "/profile/%s/avatar_url" % (myid), - '{"avatar_url": "http://my.server/pic.gif"}' + b'{"avatar_url": "http://my.server/pic.gif"}' ) self.assertEquals(200, code) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 41de8e0762..e3bc5f378d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -105,7 +105,7 @@ class RestTestCase(unittest.TestCase): "password": "test", "type": "m.login.password" })) - self.assertEquals(200, code) + self.assertEquals(200, code, msg=response) defer.returnValue(response) @defer.inlineCallbacks @@ -149,14 +149,14 @@ class RestHelper(object): def create_room_as(self, room_creator, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator - path = b"/_matrix/client/r0/createRoom" + path = "/_matrix/client/r0/createRoom" content = {} if not is_public: content["visibility"] = "private" if tok: - path = path + b"?access_token=%s" % tok.encode('ascii') + path = path + "?access_token=%s" % tok - request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8')) + request, channel = make_request("POST", path, json.dumps(content).encode('utf8')) request.render(self.resource) wait_until_result(self.hs.get_reactor(), channel) @@ -205,7 +205,7 @@ class RestHelper(object): data = {"membership": membership} request, channel = make_request( - b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8') + "PUT", path, json.dumps(data).encode('utf8') ) request.render(self.resource) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index e890f0feac..de33b10a5f 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -33,7 +33,7 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.TestCase): - USER_ID = b"@apple:test" + USER_ID = "@apple:test" EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' TO_REGISTER = [filter] @@ -72,8 +72,8 @@ class FilterTestCase(unittest.TestCase): def test_add_filter(self): request, channel = make_request( - b"POST", - b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON, ) request.render(self.resource) @@ -87,8 +87,8 @@ class FilterTestCase(unittest.TestCase): def test_add_filter_for_other_user(self): request, channel = make_request( - b"POST", - b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"), + "POST", + "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), self.EXAMPLE_FILTER_JSON, ) request.render(self.resource) @@ -101,8 +101,8 @@ class FilterTestCase(unittest.TestCase): _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False request, channel = make_request( - b"POST", - b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON, ) request.render(self.resource) @@ -119,7 +119,7 @@ class FilterTestCase(unittest.TestCase): self.clock.advance(1) filter_id = filter_id.result request, channel = make_request( - b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) + "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) ) request.render(self.resource) wait_until_result(self.clock, channel) @@ -129,7 +129,7 @@ class FilterTestCase(unittest.TestCase): def test_get_filter_non_existant(self): request, channel = make_request( - b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) + "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) ) request.render(self.resource) wait_until_result(self.clock, channel) @@ -141,7 +141,7 @@ class FilterTestCase(unittest.TestCase): # in errors.py def test_get_filter_invalid_id(self): request, channel = make_request( - b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) + "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) ) request.render(self.resource) wait_until_result(self.clock, channel) @@ -151,7 +151,7 @@ class FilterTestCase(unittest.TestCase): # No ID also returns an invalid_id error def test_get_filter_no_id(self): request, channel = make_request( - b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) + "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) ) request.render(self.resource) wait_until_result(self.clock, channel) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index e004d8fc73..f6293f11a8 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -81,7 +81,7 @@ class RegisterRestServletTestCase(unittest.TestCase): "access_token": token, "home_server": self.hs.hostname, } - self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) + self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_appservice_registration_invalid(self): self.appservice = None # no application service exists @@ -102,7 +102,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals( - json.loads(channel.result["body"])["error"], "Invalid password" + channel.json_body["error"], "Invalid password" ) def test_POST_bad_username(self): @@ -113,7 +113,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals( - json.loads(channel.result["body"])["error"], "Invalid username" + channel.json_body["error"], "Invalid username" ) def test_POST_user_valid(self): @@ -140,7 +140,7 @@ class RegisterRestServletTestCase(unittest.TestCase): "device_id": device_id, } self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) + self.assertDictContainsSubset(det_data, channel.json_body) self.auth_handler.get_login_tuple_for_user_id( user_id, device_id=device_id, initial_device_display_name=None ) @@ -158,7 +158,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - json.loads(channel.result["body"])["error"], + channel.json_body["error"], "Registration has been disabled", ) @@ -178,7 +178,7 @@ class RegisterRestServletTestCase(unittest.TestCase): "device_id": "guest_device", } self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) + self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): self.hs.config.allow_guest_access = False @@ -189,5 +189,5 @@ class RegisterRestServletTestCase(unittest.TestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - json.loads(channel.result["body"])["error"], "Guest access is disabled" + channel.json_body["error"], "Guest access is disabled" ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 03ec3993b2..bafc0d1df0 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -32,7 +32,7 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.TestCase): - USER_ID = b"@apple:test" + USER_ID = "@apple:test" TO_REGISTER = [sync] def setUp(self): @@ -68,7 +68,7 @@ class FilterTestCase(unittest.TestCase): r.register_servlets(self.hs, self.resource) def test_sync_argless(self): - request, channel = make_request(b"GET", b"/_matrix/client/r0/sync") + request, channel = make_request("GET", "/_matrix/client/r0/sync") request.render(self.resource) wait_until_result(self.clock, channel) diff --git a/tests/server.py b/tests/server.py index c611dd6059..e249668d21 100644 --- a/tests/server.py +++ b/tests/server.py @@ -11,6 +11,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactorClock from synapse.http.site import SynapseRequest +from synapse.util import Clock from tests.utils import setup_test_homeserver as _sth @@ -28,7 +29,13 @@ class FakeChannel(object): def json_body(self): if not self.result: raise Exception("No result yet.") - return json.loads(self.result["body"]) + return json.loads(self.result["body"].decode('utf8')) + + @property + def code(self): + if not self.result: + raise Exception("No result yet.") + return int(self.result["code"]) def writeHeaders(self, version, code, reason, headers): self.result["version"] = version @@ -79,11 +86,16 @@ def make_request(method, path, content=b""): Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. """ + if not isinstance(method, bytes): + method = method.encode('ascii') + + if not isinstance(path, bytes): + path = path.encode('ascii') # Decorate it to be the full path if not path.startswith(b"/_matrix"): path = b"/_matrix/client/r0/" + path - path = path.replace("//", "/") + path = path.replace(b"//", b"/") if isinstance(content, text_type): content = content.encode('utf8') @@ -191,3 +203,9 @@ def setup_test_homeserver(*args, **kwargs): clock.threadpool = ThreadPool() pool.threadpool = ThreadPool() return d + + +def get_clock(): + clock = ThreadedMemoryReactorClock() + hs_clock = Clock(clock) + return (clock, hs_clock) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 30683e7888..69412c5aad 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -49,7 +49,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): 'INSERT INTO event_reference_hashes ' '(event_id, algorithm, hash) ' "VALUES (?, 'sha256', ?)" - ), (event_id, 'ffff')) + ), (event_id, b'ffff')) for i in range(0, 11): yield self.store.runInteraction("insert", insert_event, i) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 7a76d67b8c..f7871cd426 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -176,7 +176,7 @@ class StateStoreTestCase(tests.unittest.TestCase): room_id = self.room.to_string() group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) - group = group_ids.keys()[0] + group = list(group_ids.keys())[0] # test _get_some_state_from_cache correctly filters out members with types=[] (state_dict, is_all) = yield self.store._get_some_state_from_cache( diff --git a/tests/test_server.py b/tests/test_server.py index 7e063c0290..fc396226ea 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,4 +1,3 @@ -import json import re from twisted.internet.defer import Deferred @@ -104,9 +103,8 @@ class JsonResourceTests(unittest.TestCase): request.render(res) self.assertEqual(channel.result["code"], b'403') - reply_body = json.loads(channel.result["body"]) - self.assertEqual(reply_body["error"], "Forbidden!!one!") - self.assertEqual(reply_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.json_body["error"], "Forbidden!!one!") + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") def test_no_handler(self): """ @@ -126,6 +124,5 @@ class JsonResourceTests(unittest.TestCase): request.render(res) self.assertEqual(channel.result["code"], b'400') - reply_body = json.loads(channel.result["body"]) - self.assertEqual(reply_body["error"], "Unrecognized request") - self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED") + self.assertEqual(channel.json_body["error"], "Unrecognized request") + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") diff --git a/tests/utils.py b/tests/utils.py index 151db4b890..5d49692c58 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -153,8 +153,9 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest() - hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h + hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest() + hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5( + p.encode('utf8')).hexdigest() == h fed = kargs.get("resource_for_federation", None) if fed: @@ -227,8 +228,8 @@ class MockHttpResource(HttpServer): mock_content.configure_mock(**config) mock_request.content = mock_content - mock_request.method = http_method - mock_request.uri = path + mock_request.method = http_method.encode('ascii') + mock_request.uri = path.encode('ascii') mock_request.getClientIP.return_value = "-" -- cgit 1.4.1 From 69ce057ea613f425d5ef6ace03d0019a8e4fdf49 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 9 Aug 2018 12:26:27 +0100 Subject: block sync if auth checks fail --- synapse/handlers/sync.py | 12 +++++------- tests/handlers/test_sync.py | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 13 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f748d9afb0..776ddca638 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -209,14 +209,12 @@ class SyncHandler(object): Deferred[SyncResult] """ yield self.auth.check_auth_blocking() - - defer.returnValue( - self.response_cache.wrap( - sync_config.request_key, - self._wait_for_sync_for_user, - sync_config, since_token, timeout, full_state, - ) + res = yield self.response_cache.wrap( + sync_config.request_key, + self._wait_for_sync_for_user, + sync_config, since_token, timeout, full_state, ) + defer.returnValue(res) @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 3b1b4d4923..497e4bd933 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -14,11 +14,14 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.errors import AuthError +from synapse.api.filtering import DEFAULT_FILTER_COLLECTION +from synapse.handlers.sync import SyncConfig, SyncHandler +from synapse.types import UserID + import tests.unittest import tests.utils from tests.utils import setup_test_homeserver -from synapse.handlers.sync import SyncHandler, SyncConfig -from synapse.types import UserID class SyncTestCase(tests.unittest.TestCase): @@ -32,11 +35,15 @@ class SyncTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_wait_for_sync_for_user_auth_blocking(self): sync_config = SyncConfig( - user=UserID("@user","server"), - filter_collection=None, + user=UserID("@user", "server"), + filter_collection=DEFAULT_FILTER_COLLECTION, is_guest=False, request_key="request_key", device_id="device_id", ) - res = yield self.sync_handler.wait_for_sync_for_user(sync_config) - print res + # Ensure that an exception is not thrown + yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.hs.config.hs_disabled = True + + with self.assertRaises(AuthError): + yield self.sync_handler.wait_for_sync_for_user(sync_config) -- cgit 1.4.1 From 09cf13089858902f3cdcb49b9f9bc3d214ba6337 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 9 Aug 2018 17:39:12 +0100 Subject: only block on sync where user is not part of the mau cohort --- synapse/api/auth.py | 13 +++++++++++-- synapse/handlers/sync.py | 7 ++++++- tests/handlers/test_sync.py | 40 +++++++++++++++++++++++++++++++--------- 3 files changed, 48 insertions(+), 12 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9c62ec4374..170039fc82 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -775,17 +775,26 @@ class Auth(object): ) @defer.inlineCallbacks - def check_auth_blocking(self): + def check_auth_blocking(self, user_id=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag + + Args: + user_id(str): If present, checks for presence against existing MAU cohort """ if self.hs.config.hs_disabled: raise AuthError( 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED ) if self.hs.config.limit_usage_by_mau is True: + # If the user is already part of the MAU cohort + if user_id: + timestamp = yield self.store._user_last_seen_monthly_active(user_id) + if timestamp: + return + # Else if there is no room in the MAU bucket, bail current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise AuthError( 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 776ddca638..d3b26a4106 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -208,7 +208,12 @@ class SyncHandler(object): Returns: Deferred[SyncResult] """ - yield self.auth.check_auth_blocking() + # If the user is not part of the mau group, then check that limits have + # not been exceeded (if not part of the group by this point, almost certain + # auth_blocking will occur) + user_id = sync_config.user.to_string() + yield self.auth.check_auth_blocking(user_id) + res = yield self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 497e4bd933..b95a8743a7 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from twisted.internet import defer +from synapse.api.errors import AuthError, Codes -from synapse.api.errors import AuthError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.handlers.sync import SyncConfig, SyncHandler from synapse.types import UserID @@ -31,19 +31,41 @@ class SyncTestCase(tests.unittest.TestCase): def setUp(self): self.hs = yield setup_test_homeserver() self.sync_handler = SyncHandler(self.hs) + self.store = self.hs.get_datastore() @defer.inlineCallbacks def test_wait_for_sync_for_user_auth_blocking(self): - sync_config = SyncConfig( - user=UserID("@user", "server"), + + user_id1 = "@user1:server" + user_id2 = "@user2:server" + sync_config = self._generate_sync_config(user_id1) + + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 1 + + # Check that the happy case does not throw errors + yield self.store.upsert_monthly_active_user(user_id1) + yield self.sync_handler.wait_for_sync_for_user(sync_config) + + # Test that global lock works + self.hs.config.hs_disabled = True + with self.assertRaises(AuthError) as e: + yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) + + self.hs.config.hs_disabled = False + + sync_config = self._generate_sync_config(user_id2) + + with self.assertRaises(AuthError) as e: + yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED) + + def _generate_sync_config(self, user_id): + return SyncConfig( + user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), filter_collection=DEFAULT_FILTER_COLLECTION, is_guest=False, request_key="request_key", device_id="device_id", ) - # Ensure that an exception is not thrown - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.hs.config.hs_disabled = True - - with self.assertRaises(AuthError): - yield self.sync_handler.wait_for_sync_for_user(sync_config) -- cgit 1.4.1 From 04df7142598b531e8e400611e3f92b21afeabab6 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 9 Aug 2018 17:41:52 +0100 Subject: fix imports --- tests/handlers/test_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index b95a8743a7..cfd37f3138 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from twisted.internet import defer -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.handlers.sync import SyncConfig, SyncHandler from synapse.types import UserID -- cgit 1.4.1 From 8b3d9b6b199abb87246f982d5db356f1966db925 Mon Sep 17 00:00:00 2001 From: black Date: Fri, 10 Aug 2018 23:54:09 +1000 Subject: Run black. --- tests/api/test_auth.py | 76 ++--- tests/api/test_filtering.py | 361 ++++++--------------- tests/api/test_ratelimiting.py | 7 +- tests/appservice/test_appservice.py | 86 ++--- tests/appservice/test_scheduler.py | 28 +- tests/config/test_generate.py | 38 ++- tests/config/test_load.py | 38 ++- tests/crypto/test_event_signing.py | 19 +- tests/crypto/test_keyring.py | 46 +-- tests/events/test_utils.py | 135 +++----- tests/federation/test_federation_server.py | 26 +- tests/handlers/test_appservice.py | 39 +-- tests/handlers/test_auth.py | 16 +- tests/handlers/test_device.py | 98 +++--- tests/handlers/test_directory.py | 22 +- tests/handlers/test_e2e_keys.py | 99 ++---- tests/handlers/test_presence.py | 216 ++++++------ tests/handlers/test_profile.py | 25 +- tests/handlers/test_register.py | 12 +- tests/handlers/test_typing.py | 231 +++++++------ tests/http/test_endpoint.py | 6 +- tests/replication/slave/storage/_base.py | 5 +- .../replication/slave/storage/test_account_data.py | 14 +- tests/replication/slave/storage/test_events.py | 105 +++--- tests/replication/slave/storage/test_receipts.py | 6 +- tests/rest/client/test_transactions.py | 30 +- tests/rest/client/v1/test_admin.py | 5 +- tests/rest/client/v1/test_events.py | 33 +- tests/rest/client/v1/test_profile.py | 42 ++- tests/rest/client/v1/test_register.py | 1 + tests/rest/client/v1/test_typing.py | 47 +-- tests/rest/client/v1/utils.py | 70 ++-- tests/rest/client/v2_alpha/test_register.py | 17 +- tests/rest/media/v1/test_media_storage.py | 6 +- tests/server.py | 2 + tests/storage/test__base.py | 5 +- tests/storage/test_appservice.py | 178 ++++------ tests/storage/test_background_update.py | 17 +- tests/storage/test_base.py | 46 +-- tests/storage/test_client_ips.py | 15 +- tests/storage/test_devices.py | 69 ++-- tests/storage/test_directory.py | 22 +- tests/storage/test_end_to_end_keys.py | 61 ++-- tests/storage/test_event_federation.py | 39 ++- tests/storage/test_event_push_actions.py | 78 ++--- tests/storage/test_keys.py | 11 +- tests/storage/test_monthly_active_users.py | 12 +- tests/storage/test_presence.py | 69 ++-- tests/storage/test_profile.py | 18 +- tests/storage/test_redaction.py | 70 ++-- tests/storage/test_registration.py | 38 +-- tests/storage/test_room.py | 47 +-- tests/storage/test_roommember.py | 31 +- tests/storage/test_state.py | 249 ++++++++------ tests/storage/test_user_directory.py | 40 +-- tests/test_distributor.py | 10 +- tests/test_dns.py | 22 +- tests/test_event_auth.py | 98 +++--- tests/test_preview.py | 55 +--- tests/test_state.py | 245 ++++++-------- tests/test_test_utils.py | 5 +- tests/test_types.py | 5 +- tests/test_visibility.py | 145 +++++---- tests/unittest.py | 7 +- tests/util/caches/test_descriptors.py | 46 ++- tests/util/test_dict_cache.py | 19 +- tests/util/test_expiring_cache.py | 1 - tests/util/test_file_consumer.py | 5 +- tests/util/test_linearizer.py | 7 +- tests/util/test_logcontext.py | 14 +- tests/util/test_lrucache.py | 2 - tests/util/test_rwlock.py | 9 +- tests/util/test_snapshot_cache.py | 1 - tests/util/test_stream_change_cache.py | 13 +- tests/utils.py | 78 ++--- 75 files changed, 1629 insertions(+), 2280 deletions(-) (limited to 'tests/handlers') diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index fbb96361a8..f8e28876bb 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -34,7 +34,6 @@ class TestHandlers(object): class AuthTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): self.state_handler = Mock() @@ -53,11 +52,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): - user_info = { - "name": self.test_user, - "token_id": "ditto", - "device_id": "device", - } + user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) @@ -76,10 +71,7 @@ class AuthTestCase(unittest.TestCase): self.failureResultOf(d, AuthError) def test_get_user_by_req_user_missing_token(self): - user_info = { - "name": self.test_user, - "token_id": "ditto", - } + user_info = {"name": self.test_user, "token_id": "ditto"} self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) @@ -90,8 +82,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token(self): app_service = Mock( - token="foobar", url="a_url", sender=self.test_user, - ip_range_whitelist=None, + token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) @@ -106,8 +97,11 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_good_ip(self): from netaddr import IPSet + app_service = Mock( - token="foobar", url="a_url", sender=self.test_user, + token="foobar", + url="a_url", + sender=self.test_user, ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) @@ -122,8 +116,11 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_valid_token_bad_ip(self): from netaddr import IPSet + app_service = Mock( - token="foobar", url="a_url", sender=self.test_user, + token="foobar", + url="a_url", + sender=self.test_user, ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) @@ -160,8 +157,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_valid_token_valid_user_id(self): masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( - token="foobar", url="a_url", sender=self.test_user, - ip_range_whitelist=None, + token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) @@ -174,15 +170,13 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals( - requester.user.to_string(), - masquerading_user_id.decode('utf8') + requester.user.to_string(), masquerading_user_id.decode('utf8') ) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( - token="foobar", url="a_url", sender=self.test_user, - ip_range_whitelist=None, + token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) @@ -201,17 +195,15 @@ class AuthTestCase(unittest.TestCase): # TODO(danielwh): Remove this mock when we remove the # get_user_by_access_token fallback. self.store.get_user_by_access_token = Mock( - return_value={ - "name": "@baldrick:matrix.org", - "device_id": "device", - } + return_value={"name": "@baldrick:matrix.org", "device_id": "device"} ) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) @@ -225,15 +217,14 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): - self.store.get_user_by_id = Mock(return_value={ - "is_guest": True, - }) + self.store.get_user_by_id = Mock(return_value={"is_guest": True}) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) @@ -257,7 +248,8 @@ class AuthTestCase(unittest.TestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) @@ -277,7 +269,8 @@ class AuthTestCase(unittest.TestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") @@ -298,7 +291,8 @@ class AuthTestCase(unittest.TestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key + "wrong") + key=self.hs.config.macaroon_secret_key + "wrong", + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) @@ -320,7 +314,8 @@ class AuthTestCase(unittest.TestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) @@ -347,7 +342,8 @@ class AuthTestCase(unittest.TestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) @@ -380,7 +376,8 @@ class AuthTestCase(unittest.TestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) @@ -401,9 +398,7 @@ class AuthTestCase(unittest.TestCase): token = yield self.hs.handlers.auth_handler.issue_access_token( USER_ID, "DEVICE" ) - self.store.add_access_token_to_user.assert_called_with( - USER_ID, token, "DEVICE" - ) + self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE") def get_user(tok): if token != tok: @@ -414,10 +409,9 @@ class AuthTestCase(unittest.TestCase): "token_id": 1234, "device_id": "DEVICE", } + self.store.get_user_by_access_token = get_user - self.store.get_user_by_id = Mock(return_value={ - "is_guest": False, - }) + self.store.get_user_by_id = Mock(return_value={"is_guest": False}) # check the token works request = Mock(args={}) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 836a23fb54..1c2d71052c 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -38,7 +38,6 @@ def MockEvent(**kwargs): class FilteringTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): self.mock_federation_resource = MockHttpResource() @@ -47,9 +46,7 @@ class FilteringTestCase(unittest.TestCase): self.mock_http_client.put_json = DeferredMockCallable() hs = yield setup_test_homeserver( - handlers=None, - http_client=self.mock_http_client, - keyring=Mock(), + handlers=None, http_client=self.mock_http_client, keyring=Mock() ) self.filtering = hs.get_filtering() @@ -64,7 +61,7 @@ class FilteringTestCase(unittest.TestCase): {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, {"event_format": "other"}, {"room": {"not_rooms": ["#foo:pik-test"]}}, - {"presence": {"senders": ["@bar;pik.test.com"]}} + {"presence": {"senders": ["@bar;pik.test.com"]}}, ] for filter in invalid_filters: with self.assertRaises(SynapseError) as check_filter_error: @@ -81,34 +78,34 @@ class FilteringTestCase(unittest.TestCase): "include_leave": False, "rooms": ["!dee:pik-test"], "not_rooms": ["!gee:pik-test"], - "account_data": {"limit": 0, "types": ["*"]} + "account_data": {"limit": 0, "types": ["*"]}, } }, { "room": { "state": { "types": ["m.room.*"], - "not_rooms": ["!726s6s6q:example.com"] + "not_rooms": ["!726s6s6q:example.com"], }, "timeline": { "limit": 10, "types": ["m.room.message"], "not_rooms": ["!726s6s6q:example.com"], - "not_senders": ["@spam:example.com"] + "not_senders": ["@spam:example.com"], }, "ephemeral": { "types": ["m.receipt", "m.typing"], "not_rooms": ["!726s6s6q:example.com"], - "not_senders": ["@spam:example.com"] - } + "not_senders": ["@spam:example.com"], + }, }, "presence": { "types": ["m.presence"], - "not_senders": ["@alice:example.com"] + "not_senders": ["@alice:example.com"], }, "event_format": "client", - "event_fields": ["type", "content", "sender"] - } + "event_fields": ["type", "content", "sender"], + }, ] for filter in valid_filters: try: @@ -121,229 +118,131 @@ class FilteringTestCase(unittest.TestCase): pass def test_definition_types_works_with_literals(self): - definition = { - "types": ["m.room.message", "org.matrix.foo.bar"] - } - event = MockEvent( - sender="@foo:bar", - type="m.room.message", - room_id="!foo:bar" - ) + definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} + event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") - self.assertTrue( - Filter(definition).check(event) - ) + self.assertTrue(Filter(definition).check(event)) def test_definition_types_works_with_wildcards(self): - definition = { - "types": ["m.*", "org.matrix.foo.bar"] - } - event = MockEvent( - sender="@foo:bar", - type="m.room.message", - room_id="!foo:bar" - ) - self.assertTrue( - Filter(definition).check(event) - ) + definition = {"types": ["m.*", "org.matrix.foo.bar"]} + event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") + self.assertTrue(Filter(definition).check(event)) def test_definition_types_works_with_unknowns(self): - definition = { - "types": ["m.room.message", "org.matrix.foo.bar"] - } + definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent( sender="@foo:bar", type="now.for.something.completely.different", - room_id="!foo:bar" - ) - self.assertFalse( - Filter(definition).check(event) + room_id="!foo:bar", ) + self.assertFalse(Filter(definition).check(event)) def test_definition_not_types_works_with_literals(self): - definition = { - "not_types": ["m.room.message", "org.matrix.foo.bar"] - } - event = MockEvent( - sender="@foo:bar", - type="m.room.message", - room_id="!foo:bar" - ) - self.assertFalse( - Filter(definition).check(event) - ) + definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]} + event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") + self.assertFalse(Filter(definition).check(event)) def test_definition_not_types_works_with_wildcards(self): - definition = { - "not_types": ["m.room.message", "org.matrix.*"] - } + definition = {"not_types": ["m.room.message", "org.matrix.*"]} event = MockEvent( - sender="@foo:bar", - type="org.matrix.custom.event", - room_id="!foo:bar" - ) - self.assertFalse( - Filter(definition).check(event) + sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" ) + self.assertFalse(Filter(definition).check(event)) def test_definition_not_types_works_with_unknowns(self): - definition = { - "not_types": ["m.*", "org.*"] - } - event = MockEvent( - sender="@foo:bar", - type="com.nom.nom.nom", - room_id="!foo:bar" - ) - self.assertTrue( - Filter(definition).check(event) - ) + definition = {"not_types": ["m.*", "org.*"]} + event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar") + self.assertTrue(Filter(definition).check(event)) def test_definition_not_types_takes_priority_over_types(self): definition = { "not_types": ["m.*", "org.*"], - "types": ["m.room.message", "m.room.topic"] + "types": ["m.room.message", "m.room.topic"], } - event = MockEvent( - sender="@foo:bar", - type="m.room.topic", - room_id="!foo:bar" - ) - self.assertFalse( - Filter(definition).check(event) - ) + event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") + self.assertFalse(Filter(definition).check(event)) def test_definition_senders_works_with_literals(self): - definition = { - "senders": ["@flibble:wibble"] - } + definition = {"senders": ["@flibble:wibble"]} event = MockEvent( - sender="@flibble:wibble", - type="com.nom.nom.nom", - room_id="!foo:bar" - ) - self.assertTrue( - Filter(definition).check(event) + sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) + self.assertTrue(Filter(definition).check(event)) def test_definition_senders_works_with_unknowns(self): - definition = { - "senders": ["@flibble:wibble"] - } + definition = {"senders": ["@flibble:wibble"]} event = MockEvent( - sender="@challenger:appears", - type="com.nom.nom.nom", - room_id="!foo:bar" - ) - self.assertFalse( - Filter(definition).check(event) + sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) + self.assertFalse(Filter(definition).check(event)) def test_definition_not_senders_works_with_literals(self): - definition = { - "not_senders": ["@flibble:wibble"] - } + definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( - sender="@flibble:wibble", - type="com.nom.nom.nom", - room_id="!foo:bar" - ) - self.assertFalse( - Filter(definition).check(event) + sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) + self.assertFalse(Filter(definition).check(event)) def test_definition_not_senders_works_with_unknowns(self): - definition = { - "not_senders": ["@flibble:wibble"] - } + definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( - sender="@challenger:appears", - type="com.nom.nom.nom", - room_id="!foo:bar" - ) - self.assertTrue( - Filter(definition).check(event) + sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) + self.assertTrue(Filter(definition).check(event)) def test_definition_not_senders_takes_priority_over_senders(self): definition = { "not_senders": ["@misspiggy:muppets"], - "senders": ["@kermit:muppets", "@misspiggy:muppets"] + "senders": ["@kermit:muppets", "@misspiggy:muppets"], } event = MockEvent( - sender="@misspiggy:muppets", - type="m.room.topic", - room_id="!foo:bar" - ) - self.assertFalse( - Filter(definition).check(event) + sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar" ) + self.assertFalse(Filter(definition).check(event)) def test_definition_rooms_works_with_literals(self): - definition = { - "rooms": ["!secretbase:unknown"] - } + definition = {"rooms": ["!secretbase:unknown"]} event = MockEvent( - sender="@foo:bar", - type="m.room.message", - room_id="!secretbase:unknown" - ) - self.assertTrue( - Filter(definition).check(event) + sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" ) + self.assertTrue(Filter(definition).check(event)) def test_definition_rooms_works_with_unknowns(self): - definition = { - "rooms": ["!secretbase:unknown"] - } + definition = {"rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", type="m.room.message", - room_id="!anothersecretbase:unknown" - ) - self.assertFalse( - Filter(definition).check(event) + room_id="!anothersecretbase:unknown", ) + self.assertFalse(Filter(definition).check(event)) def test_definition_not_rooms_works_with_literals(self): - definition = { - "not_rooms": ["!anothersecretbase:unknown"] - } + definition = {"not_rooms": ["!anothersecretbase:unknown"]} event = MockEvent( sender="@foo:bar", type="m.room.message", - room_id="!anothersecretbase:unknown" - ) - self.assertFalse( - Filter(definition).check(event) + room_id="!anothersecretbase:unknown", ) + self.assertFalse(Filter(definition).check(event)) def test_definition_not_rooms_works_with_unknowns(self): - definition = { - "not_rooms": ["!secretbase:unknown"] - } + definition = {"not_rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", type="m.room.message", - room_id="!anothersecretbase:unknown" - ) - self.assertTrue( - Filter(definition).check(event) + room_id="!anothersecretbase:unknown", ) + self.assertTrue(Filter(definition).check(event)) def test_definition_not_rooms_takes_priority_over_rooms(self): definition = { "not_rooms": ["!secretbase:unknown"], - "rooms": ["!secretbase:unknown"] + "rooms": ["!secretbase:unknown"], } event = MockEvent( - sender="@foo:bar", - type="m.room.message", - room_id="!secretbase:unknown" - ) - self.assertFalse( - Filter(definition).check(event) + sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" ) + self.assertFalse(Filter(definition).check(event)) def test_definition_combined_event(self): definition = { @@ -352,16 +251,14 @@ class FilteringTestCase(unittest.TestCase): "rooms": ["!stage:unknown"], "not_rooms": ["!piggyshouse:muppets"], "types": ["m.room.message", "muppets.kermit.*"], - "not_types": ["muppets.misspiggy.*"] + "not_types": ["muppets.misspiggy.*"], } event = MockEvent( sender="@kermit:muppets", # yup type="m.room.message", # yup - room_id="!stage:unknown" # yup - ) - self.assertTrue( - Filter(definition).check(event) + room_id="!stage:unknown", # yup ) + self.assertTrue(Filter(definition).check(event)) def test_definition_combined_event_bad_sender(self): definition = { @@ -370,16 +267,14 @@ class FilteringTestCase(unittest.TestCase): "rooms": ["!stage:unknown"], "not_rooms": ["!piggyshouse:muppets"], "types": ["m.room.message", "muppets.kermit.*"], - "not_types": ["muppets.misspiggy.*"] + "not_types": ["muppets.misspiggy.*"], } event = MockEvent( sender="@misspiggy:muppets", # nope type="m.room.message", # yup - room_id="!stage:unknown" # yup - ) - self.assertFalse( - Filter(definition).check(event) + room_id="!stage:unknown", # yup ) + self.assertFalse(Filter(definition).check(event)) def test_definition_combined_event_bad_room(self): definition = { @@ -388,16 +283,14 @@ class FilteringTestCase(unittest.TestCase): "rooms": ["!stage:unknown"], "not_rooms": ["!piggyshouse:muppets"], "types": ["m.room.message", "muppets.kermit.*"], - "not_types": ["muppets.misspiggy.*"] + "not_types": ["muppets.misspiggy.*"], } event = MockEvent( sender="@kermit:muppets", # yup type="m.room.message", # yup - room_id="!piggyshouse:muppets" # nope - ) - self.assertFalse( - Filter(definition).check(event) + room_id="!piggyshouse:muppets", # nope ) + self.assertFalse(Filter(definition).check(event)) def test_definition_combined_event_bad_type(self): definition = { @@ -406,37 +299,26 @@ class FilteringTestCase(unittest.TestCase): "rooms": ["!stage:unknown"], "not_rooms": ["!piggyshouse:muppets"], "types": ["m.room.message", "muppets.kermit.*"], - "not_types": ["muppets.misspiggy.*"] + "not_types": ["muppets.misspiggy.*"], } event = MockEvent( sender="@kermit:muppets", # yup type="muppets.misspiggy.kisses", # nope - room_id="!stage:unknown" # yup - ) - self.assertFalse( - Filter(definition).check(event) + room_id="!stage:unknown", # yup ) + self.assertFalse(Filter(definition).check(event)) @defer.inlineCallbacks def test_filter_presence_match(self): - user_filter_json = { - "presence": { - "types": ["m.*"] - } - } + user_filter_json = {"presence": {"types": ["m.*"]}} filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, - user_filter=user_filter_json, - ) - event = MockEvent( - sender="@foo:bar", - type="m.profile", + user_localpart=user_localpart, user_filter=user_filter_json ) + event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, - filter_id=filter_id, + user_localpart=user_localpart, filter_id=filter_id ) results = user_filter.filter_presence(events=events) @@ -444,15 +326,10 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_filter_presence_no_match(self): - user_filter_json = { - "presence": { - "types": ["m.*"] - } - } + user_filter_json = {"presence": {"types": ["m.*"]}} filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart + "2", - user_filter=user_filter_json, + user_localpart=user_localpart + "2", user_filter=user_filter_json ) event = MockEvent( event_id="$asdasd:localhost", @@ -462,8 +339,7 @@ class FilteringTestCase(unittest.TestCase): events = [event] user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart + "2", - filter_id=filter_id, + user_localpart=user_localpart + "2", filter_id=filter_id ) results = user_filter.filter_presence(events=events) @@ -471,27 +347,15 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_filter_room_state_match(self): - user_filter_json = { - "room": { - "state": { - "types": ["m.*"] - } - } - } + user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, - user_filter=user_filter_json, - ) - event = MockEvent( - sender="@foo:bar", - type="m.room.topic", - room_id="!foo:bar" + user_localpart=user_localpart, user_filter=user_filter_json ) + event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, - filter_id=filter_id, + user_localpart=user_localpart, filter_id=filter_id ) results = user_filter.filter_room_state(events=events) @@ -499,27 +363,17 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_filter_room_state_no_match(self): - user_filter_json = { - "room": { - "state": { - "types": ["m.*"] - } - } - } + user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, - user_filter=user_filter_json, + user_localpart=user_localpart, user_filter=user_filter_json ) event = MockEvent( - sender="@foo:bar", - type="org.matrix.custom.event", - room_id="!foo:bar" + sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" ) events = [event] user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, - filter_id=filter_id, + user_localpart=user_localpart, filter_id=filter_id ) results = user_filter.filter_room_state(events) @@ -543,45 +397,32 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_filter(self): - user_filter_json = { - "room": { - "state": { - "types": ["m.*"] - } - } - } + user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = yield self.filtering.add_user_filter( - user_localpart=user_localpart, - user_filter=user_filter_json, + user_localpart=user_localpart, user_filter=user_filter_json ) self.assertEquals(filter_id, 0) - self.assertEquals(user_filter_json, ( - yield self.datastore.get_user_filter( - user_localpart=user_localpart, - filter_id=0, - ) - )) + self.assertEquals( + user_filter_json, + ( + yield self.datastore.get_user_filter( + user_localpart=user_localpart, filter_id=0 + ) + ), + ) @defer.inlineCallbacks def test_get_filter(self): - user_filter_json = { - "room": { - "state": { - "types": ["m.*"] - } - } - } + user_filter_json = {"room": {"state": {"types": ["m.*"]}}} filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, - user_filter=user_filter_json, + user_localpart=user_localpart, user_filter=user_filter_json ) filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, - filter_id=filter_id, + user_localpart=user_localpart, filter_id=filter_id ) self.assertEquals(filter.get_filter_json(), user_filter_json) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index c45b59b36c..8933fe3b72 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -4,17 +4,16 @@ from tests import unittest class TestRatelimiter(unittest.TestCase): - def test_allowed(self): limiter = Ratelimiter() allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1, + user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) self.assertEquals(10., time_allowed) allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1, + user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1 ) self.assertFalse(allowed) self.assertEquals(10., time_allowed) @@ -28,7 +27,7 @@ class TestRatelimiter(unittest.TestCase): def test_pruning(self): limiter = Ratelimiter() allowed, time_allowed = limiter.send_message( - user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1, + user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1 ) self.assertIn("test_id_1", limiter.message_counts) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 891e0cc973..4003869ed6 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -24,14 +24,10 @@ from tests import unittest def _regex(regex, exclusive=True): - return { - "regex": re.compile(regex), - "exclusive": exclusive - } + return {"regex": re.compile(regex), "exclusive": exclusive} class ApplicationServiceTestCase(unittest.TestCase): - def setUp(self): self.service = ApplicationService( id="unique_identifier", @@ -41,8 +37,8 @@ class ApplicationServiceTestCase(unittest.TestCase): namespaces={ ApplicationService.NS_USERS: [], ApplicationService.NS_ROOMS: [], - ApplicationService.NS_ALIASES: [] - } + ApplicationService.NS_ALIASES: [], + }, ) self.event = Mock( type="m.something", room_id="!foo:bar", sender="@someone:somewhere" @@ -52,25 +48,19 @@ class ApplicationServiceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_regex_user_id_prefix_match(self): - self.service.namespaces[ApplicationService.NS_USERS].append( - _regex("@irc_.*") - ) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.assertTrue((yield self.service.is_interested(self.event))) @defer.inlineCallbacks def test_regex_user_id_prefix_no_match(self): - self.service.namespaces[ApplicationService.NS_USERS].append( - _regex("@irc_.*") - ) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.assertFalse((yield self.service.is_interested(self.event))) @defer.inlineCallbacks def test_regex_room_member_is_checked(self): - self.service.namespaces[ApplicationService.NS_USERS].append( - _regex("@irc_.*") - ) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" @@ -98,60 +88,47 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("#irc_.*:matrix.org") ) self.store.get_aliases_for_room.return_value = [ - "#irc_foobar:matrix.org", "#athing:matrix.org" + "#irc_foobar:matrix.org", + "#athing:matrix.org", ] self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested( - self.event, self.store - ))) + self.assertTrue((yield self.service.is_interested(self.event, self.store))) def test_non_exclusive_alias(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org", exclusive=False) ) - self.assertFalse(self.service.is_exclusive_alias( - "#irc_foobar:matrix.org" - )) + self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) def test_non_exclusive_room(self): self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!irc_.*:matrix.org", exclusive=False) ) - self.assertFalse(self.service.is_exclusive_room( - "!irc_foobar:matrix.org" - )) + self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org")) def test_non_exclusive_user(self): self.service.namespaces[ApplicationService.NS_USERS].append( _regex("@irc_.*:matrix.org", exclusive=False) ) - self.assertFalse(self.service.is_exclusive_user( - "@irc_foobar:matrix.org" - )) + self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org")) def test_exclusive_alias(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org", exclusive=True) ) - self.assertTrue(self.service.is_exclusive_alias( - "#irc_foobar:matrix.org" - )) + self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org")) def test_exclusive_user(self): self.service.namespaces[ApplicationService.NS_USERS].append( _regex("@irc_.*:matrix.org", exclusive=True) ) - self.assertTrue(self.service.is_exclusive_user( - "@irc_foobar:matrix.org" - )) + self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org")) def test_exclusive_room(self): self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!irc_.*:matrix.org", exclusive=True) ) - self.assertTrue(self.service.is_exclusive_room( - "!irc_foobar:matrix.org" - )) + self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org")) @defer.inlineCallbacks def test_regex_alias_no_match(self): @@ -159,47 +136,36 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("#irc_.*:matrix.org") ) self.store.get_aliases_for_room.return_value = [ - "#xmpp_foobar:matrix.org", "#athing:matrix.org" + "#xmpp_foobar:matrix.org", + "#athing:matrix.org", ] self.store.get_users_in_room.return_value = [] - self.assertFalse((yield self.service.is_interested( - self.event, self.store - ))) + self.assertFalse((yield self.service.is_interested(self.event, self.store))) @defer.inlineCallbacks def test_regex_multiple_matches(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.service.namespaces[ApplicationService.NS_USERS].append( - _regex("@irc_.*") - ) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested( - self.event, self.store - ))) + self.assertTrue((yield self.service.is_interested(self.event, self.store))) @defer.inlineCallbacks def test_interested_in_self(self): # make sure invites get through self.service.sender = "@appservice:name" - self.service.namespaces[ApplicationService.NS_USERS].append( - _regex("@irc_.*") - ) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.type = "m.room.member" - self.event.content = { - "membership": "invite" - } + self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender self.assertTrue((yield self.service.is_interested(self.event))) @defer.inlineCallbacks def test_member_list_match(self): - self.service.namespaces[ApplicationService.NS_USERS].append( - _regex("@irc_.*") - ) + self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.store.get_users_in_room.return_value = [ "@alice:here", "@irc_fo:here", # AS user @@ -208,6 +174,6 @@ class ApplicationServiceTestCase(unittest.TestCase): self.store.get_aliases_for_room.return_value = [] self.event.sender = "@xmpp_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested( - event=self.event, store=self.store - ))) + self.assertTrue( + (yield self.service.is_interested(event=self.event, store=self.store)) + ) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index b9f4863e9a..db9f86bdac 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -30,7 +30,6 @@ from ..utils import MockClock class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): - def setUp(self): self.clock = MockClock() self.store = Mock() @@ -38,8 +37,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.recoverer = Mock() self.recoverer_fn = Mock(return_value=self.recoverer) self.txnctrl = _TransactionController( - clock=self.clock, store=self.store, as_api=self.as_api, - recoverer_fn=self.recoverer_fn + clock=self.clock, + store=self.store, + as_api=self.as_api, + recoverer_fn=self.recoverer_fn, ) def test_single_service_up_txn_sent(self): @@ -54,9 +55,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): return_value=defer.succeed(ApplicationServiceState.UP) ) txn.send = Mock(return_value=defer.succeed(True)) - self.store.create_appservice_txn = Mock( - return_value=defer.succeed(txn) - ) + self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call self.txnctrl.send(service, events) @@ -77,9 +76,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.get_appservice_state = Mock( return_value=defer.succeed(ApplicationServiceState.DOWN) ) - self.store.create_appservice_txn = Mock( - return_value=defer.succeed(txn) - ) + self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call self.txnctrl.send(service, events) @@ -104,9 +101,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): ) self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) txn.send = Mock(return_value=defer.succeed(False)) # fails to send - self.store.create_appservice_txn = Mock( - return_value=defer.succeed(txn) - ) + self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call self.txnctrl.send(service, events) @@ -124,7 +119,6 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): - def setUp(self): self.clock = MockClock() self.as_api = Mock() @@ -146,6 +140,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): def take_txn(*args, **kwargs): return defer.succeed(txns.pop(0)) + self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.recoverer.recover() @@ -171,6 +166,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): return defer.succeed(txns.pop(0)) else: return defer.succeed(txn) + self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.recoverer.recover() @@ -197,7 +193,6 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): - def setUp(self): self.txn_ctrl = Mock() self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock()) @@ -211,9 +206,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y: make_deferred_yieldable(d), - ) + self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d)) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") @@ -247,6 +240,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): def do_send(x, y): return make_deferred_yieldable(send_return_list.pop(0)) + self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index eb7f0ab12a..f88d28a19d 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -24,7 +24,6 @@ from tests import unittest class ConfigGenerationTestCase(unittest.TestCase): - def setUp(self): self.dir = tempfile.mkdtemp() self.file = os.path.join(self.dir, "homeserver.yaml") @@ -33,23 +32,30 @@ class ConfigGenerationTestCase(unittest.TestCase): shutil.rmtree(self.dir) def test_generate_config_generates_files(self): - HomeServerConfig.load_or_generate_config("", [ - "--generate-config", - "-c", self.file, - "--report-stats=yes", - "-H", "lemurs.win" - ]) + HomeServerConfig.load_or_generate_config( + "", + [ + "--generate-config", + "-c", + self.file, + "--report-stats=yes", + "-H", + "lemurs.win", + ], + ) self.assertSetEqual( - set([ - "homeserver.yaml", - "lemurs.win.log.config", - "lemurs.win.signing.key", - "lemurs.win.tls.crt", - "lemurs.win.tls.dh", - "lemurs.win.tls.key", - ]), - set(os.listdir(self.dir)) + set( + [ + "homeserver.yaml", + "lemurs.win.log.config", + "lemurs.win.signing.key", + "lemurs.win.tls.crt", + "lemurs.win.tls.dh", + "lemurs.win.tls.key", + ] + ), + set(os.listdir(self.dir)), ) self.assert_log_filename_is( diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 5c422eff38..d5f1777093 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -24,7 +24,6 @@ from tests import unittest class ConfigLoadingTestCase(unittest.TestCase): - def setUp(self): self.dir = tempfile.mkdtemp() print(self.dir) @@ -43,15 +42,14 @@ class ConfigLoadingTestCase(unittest.TestCase): def test_generates_and_loads_macaroon_secret_key(self): self.generate_config() - with open(self.file, - "r") as f: + with open(self.file, "r") as f: raw = yaml.load(f) self.assertIn("macaroon_secret_key", raw) config = HomeServerConfig.load_config("", ["-c", self.file]) self.assertTrue( hasattr(config, "macaroon_secret_key"), - "Want config to have attr macaroon_secret_key" + "Want config to have attr macaroon_secret_key", ) if len(config.macaroon_secret_key) < 5: self.fail( @@ -62,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) self.assertTrue( hasattr(config, "macaroon_secret_key"), - "Want config to have attr macaroon_secret_key" + "Want config to have attr macaroon_secret_key", ) if len(config.macaroon_secret_key) < 5: self.fail( @@ -80,10 +78,9 @@ class ConfigLoadingTestCase(unittest.TestCase): def test_disable_registration(self): self.generate_config() - self.add_lines_to_config([ - "enable_registration: true", - "disable_registration: true", - ]) + self.add_lines_to_config( + ["enable_registration: true", "disable_registration: true"] + ) # Check that disable_registration clobbers enable_registration. config = HomeServerConfig.load_config("", ["-c", self.file]) self.assertFalse(config.enable_registration) @@ -92,18 +89,23 @@ class ConfigLoadingTestCase(unittest.TestCase): self.assertFalse(config.enable_registration) # Check that either config value is clobbered by the command line. - config = HomeServerConfig.load_or_generate_config("", [ - "-c", self.file, "--enable-registration" - ]) + config = HomeServerConfig.load_or_generate_config( + "", ["-c", self.file, "--enable-registration"] + ) self.assertTrue(config.enable_registration) def generate_config(self): - HomeServerConfig.load_or_generate_config("", [ - "--generate-config", - "-c", self.file, - "--report-stats=yes", - "-H", "lemurs.win" - ]) + HomeServerConfig.load_or_generate_config( + "", + [ + "--generate-config", + "-c", + self.file, + "--report-stats=yes", + "-H", + "lemurs.win", + ], + ) def generate_config_and_remove_lines_containing(self, needle): self.generate_config() diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index cd11871b80..b2536c1e69 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -24,9 +24,7 @@ from tests import unittest # Perform these tests using given secret key so we get entirely deterministic # signatures output that we can test against. -SIGNING_KEY_SEED = decode_base64( - "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1" -) +SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1") KEY_ALG = "ed25519" KEY_VER = 1 @@ -36,7 +34,6 @@ HOSTNAME = "domain" class EventSigningTestCase(unittest.TestCase): - def setUp(self): self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED) self.signing_key.alg = KEY_ALG @@ -51,7 +48,7 @@ class EventSigningTestCase(unittest.TestCase): 'signatures': {}, 'type': "X", 'unsigned': {'age_ts': 1000000}, - }, + } ) add_hashes_and_signatures(builder, HOSTNAME, self.signing_key) @@ -61,8 +58,7 @@ class EventSigningTestCase(unittest.TestCase): self.assertTrue(hasattr(event, 'hashes')) self.assertIn('sha256', event.hashes) self.assertEquals( - event.hashes['sha256'], - "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI", + event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" ) self.assertTrue(hasattr(event, 'signatures')) @@ -77,9 +73,7 @@ class EventSigningTestCase(unittest.TestCase): def test_sign_message(self): builder = EventBuilder( { - 'content': { - 'body': "Here is the message content", - }, + 'content': {'body': "Here is the message content"}, 'event_id': "$0:domain", 'origin': "domain", 'origin_server_ts': 1000000, @@ -98,8 +92,7 @@ class EventSigningTestCase(unittest.TestCase): self.assertTrue(hasattr(event, 'hashes')) self.assertIn('sha256', event.hashes) self.assertEquals( - event.hashes['sha256'], - "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g", + event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" ) self.assertTrue(hasattr(event, 'signatures')) @@ -108,5 +101,5 @@ class EventSigningTestCase(unittest.TestCase): self.assertEquals( event.signatures[HOSTNAME][KEY_NAME], "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" - "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA" + "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA", ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index a9d37fe084..e40681ed1e 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -36,9 +36,7 @@ class MockPerspectiveServer(object): def get_verify_keys(self): vk = signedjson.key.get_verify_key(self.key) - return { - "%s:%s" % (vk.alg, vk.version): vk, - } + return {"%s:%s" % (vk.alg, vk.version): vk} def get_signed_key(self, server_name, verify_key): key_id = "%s:%s" % (verify_key.alg, verify_key.version) @@ -47,10 +45,8 @@ class MockPerspectiveServer(object): "old_verify_keys": {}, "valid_until_ts": time.time() * 1000 + 3600, "verify_keys": { - key_id: { - "key": signedjson.key.encode_verify_key_base64(verify_key) - } - } + key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)} + }, } signedjson.sign.sign_json(res, self.server_name, self.key) return res @@ -62,18 +58,15 @@ class KeyringTestCase(unittest.TestCase): self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() self.hs = yield utils.setup_test_homeserver( - handlers=None, - http_client=self.http_client, + handlers=None, http_client=self.http_client ) self.hs.config.perspectives = { - self.mock_perspective_server.server_name: - self.mock_perspective_server.get_verify_keys() + self.mock_perspective_server.server_name: self.mock_perspective_server.get_verify_keys() } def check_context(self, _, expected): self.assertEquals( - getattr(LoggingContext.current_context(), "request", None), - expected + getattr(LoggingContext.current_context(), "request", None), expected ) @defer.inlineCallbacks @@ -89,8 +82,7 @@ class KeyringTestCase(unittest.TestCase): context_one.request = "one" wait_1_deferred = kr.wait_for_previous_lookups( - ["server1"], - {"server1": lookup_1_deferred}, + ["server1"], {"server1": lookup_1_deferred} ) # there were no previous lookups, so the deferred should be ready @@ -105,8 +97,7 @@ class KeyringTestCase(unittest.TestCase): # set off another wait. It should block because the first lookup # hasn't yet completed. wait_2_deferred = kr.wait_for_previous_lookups( - ["server1"], - {"server1": lookup_2_deferred}, + ["server1"], {"server1": lookup_2_deferred} ) self.assertFalse(wait_2_deferred.called) # ... so we should have reset the LoggingContext. @@ -132,21 +123,19 @@ class KeyringTestCase(unittest.TestCase): persp_resp = { "server_keys": [ self.mock_perspective_server.get_signed_key( - "server10", - signedjson.key.get_verify_key(key1) - ), + "server10", signedjson.key.get_verify_key(key1) + ) ] } persp_deferred = defer.Deferred() @defer.inlineCallbacks def get_perspectives(**kwargs): - self.assertEquals( - LoggingContext.current_context().request, "11", - ) + self.assertEquals(LoggingContext.current_context().request, "11") with logcontext.PreserveLoggingContext(): yield persp_deferred defer.returnValue(persp_resp) + self.http_client.post_json.side_effect = get_perspectives with LoggingContext("11") as context_11: @@ -154,9 +143,7 @@ class KeyringTestCase(unittest.TestCase): # start off a first set of lookups res_deferreds = kr.verify_json_objects_for_server( - [("server10", json1), - ("server11", {}) - ] + [("server10", json1), ("server11", {})] ) # the unsigned json should be rejected pretty quickly @@ -172,7 +159,7 @@ class KeyringTestCase(unittest.TestCase): # wait a tick for it to send the request to the perspectives server # (it first tries the datastore) - yield clock.sleep(1) # XXX find out why this takes so long! + yield clock.sleep(1) # XXX find out why this takes so long! self.http_client.post_json.assert_called_once() self.assertIs(LoggingContext.current_context(), context_11) @@ -186,7 +173,7 @@ class KeyringTestCase(unittest.TestCase): self.http_client.post_json.return_value = defer.Deferred() res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1)], + [("server10", json1)] ) yield clock.sleep(1) self.http_client.post_json.assert_not_called() @@ -207,8 +194,7 @@ class KeyringTestCase(unittest.TestCase): key1 = signedjson.key.generate_signing_key(1) yield self.hs.datastore.store_server_verify_key( - "server9", "", time.time() * 1000, - signedjson.key.get_verify_key(key1), + "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1) ) json1 = {} signedjson.sign.sign_json(json1, "server9", key1) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index f51d99419e..ff217ca8b9 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -31,25 +31,20 @@ def MockEvent(**kwargs): class PruneEventTestCase(unittest.TestCase): """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. """ + def run_test(self, evdict, matchdict): - self.assertEquals( - prune_event(FrozenEvent(evdict)).get_dict(), - matchdict - ) + self.assertEquals(prune_event(FrozenEvent(evdict)).get_dict(), matchdict) def test_minimal(self): self.run_test( - { - 'type': 'A', - 'event_id': '$test:domain', - }, + {'type': 'A', 'event_id': '$test:domain'}, { 'type': 'A', 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {}, - } + }, ) def test_basic_keys(self): @@ -70,23 +65,19 @@ class PruneEventTestCase(unittest.TestCase): 'content': {}, 'signatures': {}, 'unsigned': {}, - } + }, ) def test_unsigned_age_ts(self): self.run_test( - { - 'type': 'B', - 'event_id': '$test:domain', - 'unsigned': {'age_ts': 20}, - }, + {'type': 'B', 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}}, { 'type': 'B', 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {'age_ts': 20}, - } + }, ) self.run_test( @@ -101,23 +92,19 @@ class PruneEventTestCase(unittest.TestCase): 'content': {}, 'signatures': {}, 'unsigned': {}, - } + }, ) def test_content(self): self.run_test( - { - 'type': 'C', - 'event_id': '$test:domain', - 'content': {'things': 'here'}, - }, + {'type': 'C', 'event_id': '$test:domain', 'content': {'things': 'here'}}, { 'type': 'C', 'event_id': '$test:domain', 'content': {}, 'signatures': {}, 'unsigned': {}, - } + }, ) self.run_test( @@ -132,27 +119,20 @@ class PruneEventTestCase(unittest.TestCase): 'content': {'creator': '@2:domain'}, 'signatures': {}, 'unsigned': {}, - } + }, ) class SerializeEventTestCase(unittest.TestCase): - def serialize(self, ev, fields): return serialize_event(ev, 1479807801915, only_event_fields=fields) def test_event_fields_works_with_keys(self): self.assertEquals( self.serialize( - MockEvent( - sender="@alice:localhost", - room_id="!foo:bar" - ), - ["room_id"] + MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"] ), - { - "room_id": "!foo:bar", - } + {"room_id": "!foo:bar"}, ) def test_event_fields_works_with_nested_keys(self): @@ -161,17 +141,11 @@ class SerializeEventTestCase(unittest.TestCase): MockEvent( sender="@alice:localhost", room_id="!foo:bar", - content={ - "body": "A message", - }, + content={"body": "A message"}, ), - ["content.body"] + ["content.body"], ), - { - "content": { - "body": "A message", - } - } + {"content": {"body": "A message"}}, ) def test_event_fields_works_with_dot_keys(self): @@ -180,17 +154,11 @@ class SerializeEventTestCase(unittest.TestCase): MockEvent( sender="@alice:localhost", room_id="!foo:bar", - content={ - "key.with.dots": {}, - }, + content={"key.with.dots": {}}, ), - ["content.key\.with\.dots"] + ["content.key\.with\.dots"], ), - { - "content": { - "key.with.dots": {}, - } - } + {"content": {"key.with.dots": {}}}, ) def test_event_fields_works_with_nested_dot_keys(self): @@ -201,21 +169,12 @@ class SerializeEventTestCase(unittest.TestCase): room_id="!foo:bar", content={ "not_me": 1, - "nested.dot.key": { - "leaf.key": 42, - "not_me_either": 1, - }, + "nested.dot.key": {"leaf.key": 42, "not_me_either": 1}, }, ), - ["content.nested\.dot\.key.leaf\.key"] + ["content.nested\.dot\.key.leaf\.key"], ), - { - "content": { - "nested.dot.key": { - "leaf.key": 42, - }, - } - } + {"content": {"nested.dot.key": {"leaf.key": 42}}}, ) def test_event_fields_nops_with_unknown_keys(self): @@ -224,17 +183,11 @@ class SerializeEventTestCase(unittest.TestCase): MockEvent( sender="@alice:localhost", room_id="!foo:bar", - content={ - "foo": "bar", - }, + content={"foo": "bar"}, ), - ["content.foo", "content.notexists"] + ["content.foo", "content.notexists"], ), - { - "content": { - "foo": "bar", - } - } + {"content": {"foo": "bar"}}, ) def test_event_fields_nops_with_non_dict_keys(self): @@ -243,13 +196,11 @@ class SerializeEventTestCase(unittest.TestCase): MockEvent( sender="@alice:localhost", room_id="!foo:bar", - content={ - "foo": ["I", "am", "an", "array"], - }, + content={"foo": ["I", "am", "an", "array"]}, ), - ["content.foo.am"] + ["content.foo.am"], ), - {} + {}, ) def test_event_fields_nops_with_array_keys(self): @@ -258,13 +209,11 @@ class SerializeEventTestCase(unittest.TestCase): MockEvent( sender="@alice:localhost", room_id="!foo:bar", - content={ - "foo": ["I", "am", "an", "array"], - }, + content={"foo": ["I", "am", "an", "array"]}, ), - ["content.foo.1"] + ["content.foo.1"], ), - {} + {}, ) def test_event_fields_all_fields_if_empty(self): @@ -274,31 +223,21 @@ class SerializeEventTestCase(unittest.TestCase): type="foo", event_id="test", room_id="!foo:bar", - content={ - "foo": "bar", - }, + content={"foo": "bar"}, ), - [] + [], ), { "type": "foo", "event_id": "test", "room_id": "!foo:bar", - "content": { - "foo": "bar", - }, - "unsigned": {} - } + "content": {"foo": "bar"}, + "unsigned": {}, + }, ) def test_event_fields_fail_if_fields_not_str(self): with self.assertRaises(TypeError): self.serialize( - MockEvent( - room_id="!foo:bar", - content={ - "foo": "bar", - }, - ), - ["room_id", 4] + MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] ) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index c91e25f54f..af15f4cc5a 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -23,10 +23,7 @@ from tests import unittest @unittest.DEBUG class ServerACLsTestCase(unittest.TestCase): def test_blacklisted_server(self): - e = _create_acl_event({ - "allow": ["*"], - "deny": ["evil.com"], - }) + e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) logging.info("ACL event: %s", e.content) self.assertFalse(server_matches_acl_event("evil.com", e)) @@ -36,10 +33,7 @@ class ServerACLsTestCase(unittest.TestCase): self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) def test_block_ip_literals(self): - e = _create_acl_event({ - "allow_ip_literals": False, - "allow": ["*"], - }) + e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]}) logging.info("ACL event: %s", e.content) self.assertFalse(server_matches_acl_event("1.2.3.4", e)) @@ -49,10 +43,12 @@ class ServerACLsTestCase(unittest.TestCase): def _create_acl_event(content): - return FrozenEvent({ - "room_id": "!a:b", - "event_id": "$a:b", - "type": "m.room.server_acls", - "sender": "@a:b", - "content": content - }) + return FrozenEvent( + { + "room_id": "!a:b", + "event_id": "$a:b", + "type": "m.room.server_acls", + "sender": "@a:b", + "content": content, + } + ) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 57c0771cf3..ba7148ec01 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -45,20 +45,18 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [ self._mkservice(is_interested=False), interested_service, - self._mkservice(is_interested=False) + self._mkservice(is_interested=False), ] self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_user_by_id = Mock(return_value=[]) event = Mock( - sender="@someone:anywhere", - type="m.room.message", - room_id="!foo:bar" + sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) self.mock_store.get_new_events_for_appservice.side_effect = [ (0, [event]), - (0, []) + (0, []), ] self.mock_as_api.push = Mock() yield self.handler.notify_interested_services(0) @@ -74,21 +72,15 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_user_by_id = Mock(return_value=None) - event = Mock( - sender=user_id, - type="m.room.message", - room_id="!foo:bar" - ) + event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.push = Mock() self.mock_as_api.query_user = Mock() self.mock_store.get_new_events_for_appservice.side_effect = [ (0, [event]), - (0, []) + (0, []), ] yield self.handler.notify_interested_services(0) - self.mock_as_api.query_user.assert_called_once_with( - services[0], user_id - ) + self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @defer.inlineCallbacks def test_query_user_exists_known_user(self): @@ -96,25 +88,19 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [self._mkservice(is_interested=True)] services[0].is_interested_in_user = Mock(return_value=True) self.mock_store.get_app_services = Mock(return_value=services) - self.mock_store.get_user_by_id = Mock(return_value={ - "name": user_id - }) + self.mock_store.get_user_by_id = Mock(return_value={"name": user_id}) - event = Mock( - sender=user_id, - type="m.room.message", - room_id="!foo:bar" - ) + event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.push = Mock() self.mock_as_api.query_user = Mock() self.mock_store.get_new_events_for_appservice.side_effect = [ (0, [event]), - (0, []) + (0, []), ] yield self.handler.notify_interested_services(0) self.assertFalse( self.mock_as_api.query_user.called, - "query_user called when it shouldn't have been." + "query_user called when it shouldn't have been.", ) @defer.inlineCallbacks @@ -129,7 +115,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [ self._mkservice_alias(is_interested_in_alias=False), interested_service, - self._mkservice_alias(is_interested_in_alias=False) + self._mkservice_alias(is_interested_in_alias=False), ] self.mock_store.get_app_services = Mock(return_value=services) @@ -140,8 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): result = yield self.handler.query_room_alias_exists(room_alias) self.mock_as_api.query_alias.assert_called_once_with( - interested_service, - room_alias_str + interested_service, room_alias_str ) self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 8a9bf2d5fd..ede01f8099 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -81,9 +81,7 @@ class AuthTestCase(unittest.TestCase): def test_short_term_login_token_gives_user_id(self): self.hs.clock.now = 1000 - token = self.macaroon_generator.generate_short_term_login_token( - "a_user", 5000 - ) + token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( token ) @@ -98,17 +96,13 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_short_term_login_token_cannot_replace_user_id(self): - token = self.macaroon_generator.generate_short_term_login_token( - "a_user", 5000 - ) + token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) macaroon = pymacaroons.Macaroon.deserialize(token) user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( macaroon.serialize() ) - self.assertEqual( - "a_user", user_id - ) + self.assertEqual("a_user", user_id) # add another "user_id" caveat, which might allow us to override the # user_id. @@ -165,7 +159,5 @@ class AuthTestCase(unittest.TestCase): ) def _get_macaroon(self): - token = self.macaroon_generator.generate_short_term_login_token( - "user_a", 5000 - ) + token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000) return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 633a0b7f36..d70d645504 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -28,9 +28,9 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore + self.store = None # type: synapse.storage.DataStore self.handler = None # type: synapse.handlers.device.DeviceHandler - self.clock = None # type: utils.MockClock + self.clock = None # type: utils.MockClock @defer.inlineCallbacks def setUp(self): @@ -44,7 +44,7 @@ class DeviceTestCase(unittest.TestCase): res = yield self.handler.check_device_registered( user_id="@boris:foo", device_id="fco", - initial_device_display_name="display name" + initial_device_display_name="display name", ) self.assertEqual(res, "fco") @@ -56,14 +56,14 @@ class DeviceTestCase(unittest.TestCase): res1 = yield self.handler.check_device_registered( user_id="@boris:foo", device_id="fco", - initial_device_display_name="display name" + initial_device_display_name="display name", ) self.assertEqual(res1, "fco") res2 = yield self.handler.check_device_registered( user_id="@boris:foo", device_id="fco", - initial_device_display_name="new display name" + initial_device_display_name="new display name", ) self.assertEqual(res2, "fco") @@ -75,7 +75,7 @@ class DeviceTestCase(unittest.TestCase): device_id = yield self.handler.check_device_registered( user_id="@theresa:foo", device_id=None, - initial_device_display_name="display" + initial_device_display_name="display", ) dev = yield self.handler.store.get_device("@theresa:foo", device_id) @@ -87,43 +87,53 @@ class DeviceTestCase(unittest.TestCase): res = yield self.handler.get_devices_by_user(user1) self.assertEqual(3, len(res)) - device_map = { - d["device_id"]: d for d in res - } - self.assertDictContainsSubset({ - "user_id": user1, - "device_id": "xyz", - "display_name": "display 0", - "last_seen_ip": None, - "last_seen_ts": None, - }, device_map["xyz"]) - self.assertDictContainsSubset({ - "user_id": user1, - "device_id": "fco", - "display_name": "display 1", - "last_seen_ip": "ip1", - "last_seen_ts": 1000000, - }, device_map["fco"]) - self.assertDictContainsSubset({ - "user_id": user1, - "device_id": "abc", - "display_name": "display 2", - "last_seen_ip": "ip3", - "last_seen_ts": 3000000, - }, device_map["abc"]) + device_map = {d["device_id"]: d for d in res} + self.assertDictContainsSubset( + { + "user_id": user1, + "device_id": "xyz", + "display_name": "display 0", + "last_seen_ip": None, + "last_seen_ts": None, + }, + device_map["xyz"], + ) + self.assertDictContainsSubset( + { + "user_id": user1, + "device_id": "fco", + "display_name": "display 1", + "last_seen_ip": "ip1", + "last_seen_ts": 1000000, + }, + device_map["fco"], + ) + self.assertDictContainsSubset( + { + "user_id": user1, + "device_id": "abc", + "display_name": "display 2", + "last_seen_ip": "ip3", + "last_seen_ts": 3000000, + }, + device_map["abc"], + ) @defer.inlineCallbacks def test_get_device(self): yield self._record_users() res = yield self.handler.get_device(user1, "abc") - self.assertDictContainsSubset({ - "user_id": user1, - "device_id": "abc", - "display_name": "display 2", - "last_seen_ip": "ip3", - "last_seen_ts": 3000000, - }, res) + self.assertDictContainsSubset( + { + "user_id": user1, + "device_id": "abc", + "display_name": "display 2", + "last_seen_ip": "ip3", + "last_seen_ts": 3000000, + }, + res, + ) @defer.inlineCallbacks def test_delete_device(self): @@ -153,8 +163,7 @@ class DeviceTestCase(unittest.TestCase): def test_update_unknown_device(self): update = {"display_name": "new_display"} with self.assertRaises(synapse.api.errors.NotFoundError): - yield self.handler.update_device("user_id", "unknown_device_id", - update) + yield self.handler.update_device("user_id", "unknown_device_id", update) @defer.inlineCallbacks def _record_users(self): @@ -168,16 +177,17 @@ class DeviceTestCase(unittest.TestCase): yield self._record_user(user2, "def", "dispkay", "token4", "ip4") @defer.inlineCallbacks - def _record_user(self, user_id, device_id, display_name, - access_token=None, ip=None): + def _record_user( + self, user_id, device_id, display_name, access_token=None, ip=None + ): device_id = yield self.handler.check_device_registered( user_id=user_id, device_id=device_id, - initial_device_display_name=display_name + initial_device_display_name=display_name, ) if ip is not None: yield self.store.insert_client_ip( - user_id, - access_token, ip, "user_agent", device_id) + user_id, access_token, ip, "user_agent", device_id + ) self.clock.advance_time(1000) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index a353070316..06de9f5eca 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -42,6 +42,7 @@ class DirectoryTestCase(unittest.TestCase): def register_query_handler(query_type, handler): self.query_handlers[query_type] = handler + self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( @@ -68,10 +69,7 @@ class DirectoryTestCase(unittest.TestCase): result = yield self.handler.get_association(self.my_room) - self.assertEquals({ - "room_id": "!8765qwer:test", - "servers": ["test"], - }, result) + self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) @defer.inlineCallbacks def test_get_remote_association(self): @@ -81,16 +79,13 @@ class DirectoryTestCase(unittest.TestCase): result = yield self.handler.get_association(self.remote_room) - self.assertEquals({ - "room_id": "!8765qwer:test", - "servers": ["test", "remote"], - }, result) + self.assertEquals( + {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result + ) self.mock_federation.make_query.assert_called_with( destination="remote", query_type="directory", - args={ - "room_alias": "#another:remote", - }, + args={"room_alias": "#another:remote"}, retry_on_dns_fail=False, ignore_backoff=True, ) @@ -105,7 +100,4 @@ class DirectoryTestCase(unittest.TestCase): {"room_alias": "#your-room:test"} ) - self.assertEquals({ - "room_id": "!8765asdf:test", - "servers": ["test"], - }, response) + self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index ca1542236d..57ab228455 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -28,14 +28,13 @@ from tests import unittest, utils class E2eKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer + self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - handlers=None, - federation_client=mock.Mock(), + handlers=None, federation_client=mock.Mock() ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) @@ -54,30 +53,21 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id = "xyz" keys = { "alg1:k1": "key1", - "alg2:k2": { - "key": "key2", - "signatures": {"k1": "sig1"} - }, - "alg2:k3": { - "key": "key3", - }, + "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, + "alg2:k3": {"key": "key3"}, } res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys}, + local_user, device_id, {"one_time_keys": keys} ) - self.assertDictEqual(res, { - "one_time_key_counts": {"alg1": 1, "alg2": 2} - }) + self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys}, + local_user, device_id, {"one_time_keys": keys} ) - self.assertDictEqual(res, { - "one_time_key_counts": {"alg1": 1, "alg2": 2} - }) + self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) @defer.inlineCallbacks def test_change_one_time_keys(self): @@ -87,25 +77,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id = "xyz" keys = { "alg1:k1": "key1", - "alg2:k2": { - "key": "key2", - "signatures": {"k1": "sig1"} - }, - "alg2:k3": { - "key": "key3", - }, + "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, + "alg2:k3": {"key": "key3"}, } res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys}, + local_user, device_id, {"one_time_keys": keys} ) - self.assertDictEqual(res, { - "one_time_key_counts": {"alg1": 1, "alg2": 2} - }) + self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) try: yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}, + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} ) self.fail("No error when changing string key") except errors.SynapseError: @@ -113,7 +96,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): try: yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}, + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} ) self.fail("No error when replacing dict key with string") except errors.SynapseError: @@ -121,9 +104,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): try: yield self.handler.upload_keys_for_user( - local_user, device_id, { - "one_time_keys": {"alg1:k1": {"key": "key"}} - }, + local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}} ) self.fail("No error when replacing string key with dict") except errors.SynapseError: @@ -131,13 +112,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): try: yield self.handler.upload_keys_for_user( - local_user, device_id, { + local_user, + device_id, + { "one_time_keys": { - "alg2:k2": { - "key": "key3", - "signatures": {"k1": "sig1"}, - } - }, + "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} + } }, ) self.fail("No error when replacing dict key") @@ -148,31 +128,20 @@ class E2eKeysHandlerTestCase(unittest.TestCase): def test_claim_one_time_key(self): local_user = "@boris:" + self.hs.hostname device_id = "xyz" - keys = { - "alg1:k1": "key1", - } + keys = {"alg1:k1": "key1"} res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys}, + local_user, device_id, {"one_time_keys": keys} + ) + self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) + + res2 = yield self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + self.assertEqual( + res2, + { + "failures": {}, + "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, + }, ) - self.assertDictEqual(res, { - "one_time_key_counts": {"alg1": 1} - }) - - res2 = yield self.handler.claim_one_time_keys({ - "one_time_keys": { - local_user: { - device_id: "alg1" - } - } - }, timeout=None) - self.assertEqual(res2, { - "failures": {}, - "one_time_keys": { - local_user: { - device_id: { - "alg1:k1": "key1" - } - } - } - }) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 121ce78634..fc2b646ba2 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -39,8 +39,7 @@ class PresenceUpdateTestCase(unittest.TestCase): prev_state = UserPresenceState.default(user_id) new_state = prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=now, + state=PresenceState.ONLINE, last_active_ts=now ) state, persist_and_notify, federation_ping = handle_update( @@ -54,23 +53,22 @@ class PresenceUpdateTestCase(unittest.TestCase): self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(wheel_timer.insert.call_count, 3) - wheel_timer.insert.assert_has_calls([ - call( - now=now, - obj=user_id, - then=new_state.last_active_ts + IDLE_TIMER - ), - call( - now=now, - obj=user_id, - then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT - ), - call( - now=now, - obj=user_id, - then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY - ), - ], any_order=True) + wheel_timer.insert.assert_has_calls( + [ + call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, + ), + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY, + ), + ], + any_order=True, + ) def test_online_to_online(self): wheel_timer = Mock() @@ -79,14 +77,11 @@ class PresenceUpdateTestCase(unittest.TestCase): prev_state = UserPresenceState.default(user_id) prev_state = prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=now, - currently_active=True, + state=PresenceState.ONLINE, last_active_ts=now, currently_active=True ) new_state = prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=now, + state=PresenceState.ONLINE, last_active_ts=now ) state, persist_and_notify, federation_ping = handle_update( @@ -101,23 +96,22 @@ class PresenceUpdateTestCase(unittest.TestCase): self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(wheel_timer.insert.call_count, 3) - wheel_timer.insert.assert_has_calls([ - call( - now=now, - obj=user_id, - then=new_state.last_active_ts + IDLE_TIMER - ), - call( - now=now, - obj=user_id, - then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT - ), - call( - now=now, - obj=user_id, - then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY - ), - ], any_order=True) + wheel_timer.insert.assert_has_calls( + [ + call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, + ), + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY, + ), + ], + any_order=True, + ) def test_online_to_online_last_active_noop(self): wheel_timer = Mock() @@ -132,8 +126,7 @@ class PresenceUpdateTestCase(unittest.TestCase): ) new_state = prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=now, + state=PresenceState.ONLINE, last_active_ts=now ) state, persist_and_notify, federation_ping = handle_update( @@ -148,23 +141,22 @@ class PresenceUpdateTestCase(unittest.TestCase): self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(wheel_timer.insert.call_count, 3) - wheel_timer.insert.assert_has_calls([ - call( - now=now, - obj=user_id, - then=new_state.last_active_ts + IDLE_TIMER - ), - call( - now=now, - obj=user_id, - then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT - ), - call( - now=now, - obj=user_id, - then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY - ), - ], any_order=True) + wheel_timer.insert.assert_has_calls( + [ + call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, + ), + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY, + ), + ], + any_order=True, + ) def test_online_to_online_last_active(self): wheel_timer = Mock() @@ -178,9 +170,7 @@ class PresenceUpdateTestCase(unittest.TestCase): currently_active=True, ) - new_state = prev_state.copy_and_replace( - state=PresenceState.ONLINE, - ) + new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE) state, persist_and_notify, federation_ping = handle_update( prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now @@ -193,18 +183,17 @@ class PresenceUpdateTestCase(unittest.TestCase): self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(wheel_timer.insert.call_count, 2) - wheel_timer.insert.assert_has_calls([ - call( - now=now, - obj=user_id, - then=new_state.last_active_ts + IDLE_TIMER - ), - call( - now=now, - obj=user_id, - then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT - ) - ], any_order=True) + wheel_timer.insert.assert_has_calls( + [ + call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER), + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, + ), + ], + any_order=True, + ) def test_remote_ping_timer(self): wheel_timer = Mock() @@ -213,13 +202,10 @@ class PresenceUpdateTestCase(unittest.TestCase): prev_state = UserPresenceState.default(user_id) prev_state = prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=now, + state=PresenceState.ONLINE, last_active_ts=now ) - new_state = prev_state.copy_and_replace( - state=PresenceState.ONLINE, - ) + new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE) state, persist_and_notify, federation_ping = handle_update( prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now @@ -232,13 +218,16 @@ class PresenceUpdateTestCase(unittest.TestCase): self.assertEquals(new_state.status_msg, state.status_msg) self.assertEquals(wheel_timer.insert.call_count, 1) - wheel_timer.insert.assert_has_calls([ - call( - now=now, - obj=user_id, - then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT - ), - ], any_order=True) + wheel_timer.insert.assert_has_calls( + [ + call( + now=now, + obj=user_id, + then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT, + ) + ], + any_order=True, + ) def test_online_to_offline(self): wheel_timer = Mock() @@ -247,14 +236,10 @@ class PresenceUpdateTestCase(unittest.TestCase): prev_state = UserPresenceState.default(user_id) prev_state = prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=now, - currently_active=True, + state=PresenceState.ONLINE, last_active_ts=now, currently_active=True ) - new_state = prev_state.copy_and_replace( - state=PresenceState.OFFLINE, - ) + new_state = prev_state.copy_and_replace(state=PresenceState.OFFLINE) state, persist_and_notify, federation_ping = handle_update( prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now @@ -273,14 +258,10 @@ class PresenceUpdateTestCase(unittest.TestCase): prev_state = UserPresenceState.default(user_id) prev_state = prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=now, - currently_active=True, + state=PresenceState.ONLINE, last_active_ts=now, currently_active=True ) - new_state = prev_state.copy_and_replace( - state=PresenceState.UNAVAILABLE, - ) + new_state = prev_state.copy_and_replace(state=PresenceState.UNAVAILABLE) state, persist_and_notify, federation_ping = handle_update( prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now @@ -293,13 +274,16 @@ class PresenceUpdateTestCase(unittest.TestCase): self.assertEquals(new_state.status_msg, state.status_msg) self.assertEquals(wheel_timer.insert.call_count, 1) - wheel_timer.insert.assert_has_calls([ - call( - now=now, - obj=user_id, - then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT - ) - ], any_order=True) + wheel_timer.insert.assert_has_calls( + [ + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, + ) + ], + any_order=True, + ) class PresenceTimeoutTestCase(unittest.TestCase): @@ -314,9 +298,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_user_sync_ts=now, ) - new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set(), now=now - ) + new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) @@ -332,9 +314,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, ) - new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set(), now=now - ) + new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) @@ -369,9 +349,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, ) - new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set(), now=now - ) + new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state, new_state) @@ -388,9 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_federation_update_ts=now, ) - new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set(), now=now - ) + new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNone(new_state) @@ -425,9 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_federation_update_ts=now, ) - new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set(), now=now - ) + new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(state, new_state) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index dc17918a3d..9268a6fe2b 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -54,9 +54,7 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]) + ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.ratelimiter = hs.get_ratelimiter() @@ -74,9 +72,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname( - self.frank.localpart, "Frank" - ) + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") displayname = yield self.handler.get_displayname(self.frank) @@ -85,22 +81,18 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name(self): yield self.handler.set_displayname( - self.frank, - synapse.types.create_requester(self.frank), - "Frank Jr." + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." ) self.assertEquals( (yield self.store.get_profile_displayname(self.frank.localpart)), - "Frank Jr." + "Frank Jr.", ) @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( - self.frank, - synapse.types.create_requester(self.bob), - "Frank Jr." + self.frank, synapse.types.create_requester(self.bob), "Frank Jr." ) yield self.assertFailure(d, AuthError) @@ -145,11 +137,12 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): yield self.handler.set_avatar_url( - self.frank, synapse.types.create_requester(self.frank), - "http://my.server/pic.gif" + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", ) self.assertEquals( (yield self.store.get_profile_avatar_url(self.frank.localpart)), - "http://my.server/pic.gif" + "http://my.server/pic.gif", ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 4ea59a58de..dbec81076f 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -46,7 +46,8 @@ class RegistrationTestCase(unittest.TestCase): profile_handler=Mock(), ) self.macaroon_generator = Mock( - generate_access_token=Mock(return_value='secret')) + generate_access_token=Mock(return_value='secret') + ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler @@ -62,7 +63,8 @@ class RegistrationTestCase(unittest.TestCase): user_id = "@someone:test" requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name) + requester, local_part, display_name + ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -73,13 +75,15 @@ class RegistrationTestCase(unittest.TestCase): yield store.register( user_id=frank.to_string(), token="jkv;g498752-43gj['eamb!-5", - password_hash=None) + password_hash=None, + ) local_part = "frank" display_name = "Frank" user_id = "@frank:test" requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name) + requester, local_part, display_name + ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f422cf3c5a..becfa77bfa 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -38,23 +38,19 @@ def _expect_edu(destination, edu_type, content, origin="test"): "origin": origin, "origin_server_ts": 1000000, "pdus": [], - "edus": [ - { - "edu_type": edu_type, - "content": content, - } - ], + "edus": [{"edu_type": edu_type, "content": content}], } def _make_edu_json(origin, edu_type, content): - return json.dumps( - _expect_edu("test", edu_type, content, origin=origin) - ).encode('utf8') + return json.dumps(_expect_edu("test", edu_type, content, origin=origin)).encode( + 'utf8' + ) class TypingNotificationsTestCase(unittest.TestCase): """Tests typing notifications to rooms.""" + @defer.inlineCallbacks def setUp(self): self.clock = MockClock() @@ -74,18 +70,20 @@ class TypingNotificationsTestCase(unittest.TestCase): "test", auth=self.auth, clock=self.clock, - datastore=Mock(spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_devices_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ]), + datastore=Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + ] + ), state_handler=self.state_handler, handlers=Mock(), notifier=mock_notifier, @@ -100,19 +98,16 @@ class TypingNotificationsTestCase(unittest.TestCase): self.event_source = hs.get_event_sources().sources["typing"] self.datastore = hs.get_datastore() - retry_timings_res = { - "destination": "", - "retry_last_ts": 0, - "retry_interval": 0, - } - self.datastore.get_destination_retry_timings.return_value = ( - defer.succeed(retry_timings_res) + retry_timings_res = {"destination": "", "retry_last_ts": 0, "retry_interval": 0} + self.datastore.get_destination_retry_timings.return_value = defer.succeed( + retry_timings_res ) self.datastore.get_devices_by_remote.return_value = (0, []) def get_received_txn_response(*args): return defer.succeed(None) + self.datastore.get_received_txn_response = get_received_txn_response self.room_id = "a-room" @@ -125,10 +120,12 @@ class TypingNotificationsTestCase(unittest.TestCase): def get_joined_hosts_for_room(room_id): return set(member.domain for member in self.room_members) + self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room def get_current_user_in_room(room_id): return set(str(u) for u in self.room_members) + self.state_handler.get_current_user_in_room = get_current_user_in_room self.datastore.get_user_directory_stream_pos.return_value = ( @@ -136,19 +133,13 @@ class TypingNotificationsTestCase(unittest.TestCase): defer.succeed(1) ) - self.datastore.get_current_state_deltas.return_value = ( - None - ) + self.datastore.get_current_state_deltas.return_value = None self.auth.check_joined_room = check_joined_room self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = ( - lambda *args, **kargs: ([], 0) - ) - self.datastore.delete_device_msgs_for_remote = ( - lambda *args, **kargs: None - ) + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) + self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None # Some local users to test with self.u_apple = UserID.from_string("@apple:test") @@ -170,24 +161,23 @@ class TypingNotificationsTestCase(unittest.TestCase): timeout=20000, ) - self.on_new_event.assert_has_calls([ - call('typing_key', 1, rooms=[self.room_id]), - ]) + self.on_new_event.assert_has_calls( + [call('typing_key', 1, rooms=[self.room_id])] + ) self.assertEquals(self.event_source.get_current_key(), 1) events = yield self.event_source.get_new_events( - room_ids=[self.room_id], - from_key=0, + room_ids=[self.room_id], from_key=0 ) self.assertEquals( events[0], [ - {"type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_apple.to_string()], - }}, - ] + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": [self.u_apple.to_string()]}, + } + ], ) @defer.inlineCallbacks @@ -206,13 +196,13 @@ class TypingNotificationsTestCase(unittest.TestCase): "room_id": self.room_id, "user_id": self.u_apple.to_string(), "typing": True, - } + }, ), json_data_callback=ANY, long_retries=True, backoff_on_404=True, ), - defer.succeed((200, "OK")) + defer.succeed((200, "OK")), ) yield self.handler.started_typing( @@ -240,27 +230,29 @@ class TypingNotificationsTestCase(unittest.TestCase): "room_id": self.room_id, "user_id": self.u_onion.to_string(), "typing": True, - } + }, ), federation_auth=True, ) - self.on_new_event.assert_has_calls([ - call('typing_key', 1, rooms=[self.room_id]), - ]) + self.on_new_event.assert_has_calls( + [call('typing_key', 1, rooms=[self.room_id])] + ) self.assertEquals(self.event_source.get_current_key(), 1) events = yield self.event_source.get_new_events( - room_ids=[self.room_id], - from_key=0 + room_ids=[self.room_id], from_key=0 + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": [self.u_onion.to_string()]}, + } + ], ) - self.assertEquals(events[0], [{ - "type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_onion.to_string()], - }, - }]) @defer.inlineCallbacks def test_stopped_typing(self): @@ -278,17 +270,18 @@ class TypingNotificationsTestCase(unittest.TestCase): "room_id": self.room_id, "user_id": self.u_apple.to_string(), "typing": False, - } + }, ), json_data_callback=ANY, long_retries=True, backoff_on_404=True, ), - defer.succeed((200, "OK")) + defer.succeed((200, "OK")), ) # Gut-wrenching from synapse.handlers.typing import RoomMember + member = RoomMember(self.room_id, self.u_apple.to_string()) self.handler._member_typing_until[member] = 1002000 self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()]) @@ -296,29 +289,29 @@ class TypingNotificationsTestCase(unittest.TestCase): self.assertEquals(self.event_source.get_current_key(), 0) yield self.handler.stopped_typing( - target_user=self.u_apple, - auth_user=self.u_apple, - room_id=self.room_id, + target_user=self.u_apple, auth_user=self.u_apple, room_id=self.room_id ) - self.on_new_event.assert_has_calls([ - call('typing_key', 1, rooms=[self.room_id]), - ]) + self.on_new_event.assert_has_calls( + [call('typing_key', 1, rooms=[self.room_id])] + ) yield put_json.await_calls() self.assertEquals(self.event_source.get_current_key(), 1) events = yield self.event_source.get_new_events( - room_ids=[self.room_id], - from_key=0, + room_ids=[self.room_id], from_key=0 + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": []}, + } + ], ) - self.assertEquals(events[0], [{ - "type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [], - }, - }]) @defer.inlineCallbacks def test_typing_timeout(self): @@ -333,42 +326,46 @@ class TypingNotificationsTestCase(unittest.TestCase): timeout=10000, ) - self.on_new_event.assert_has_calls([ - call('typing_key', 1, rooms=[self.room_id]), - ]) + self.on_new_event.assert_has_calls( + [call('typing_key', 1, rooms=[self.room_id])] + ) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) events = yield self.event_source.get_new_events( - room_ids=[self.room_id], - from_key=0, + room_ids=[self.room_id], from_key=0 + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": [self.u_apple.to_string()]}, + } + ], ) - self.assertEquals(events[0], [{ - "type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_apple.to_string()], - }, - }]) self.clock.advance_time(16) - self.on_new_event.assert_has_calls([ - call('typing_key', 2, rooms=[self.room_id]), - ]) + self.on_new_event.assert_has_calls( + [call('typing_key', 2, rooms=[self.room_id])] + ) self.assertEquals(self.event_source.get_current_key(), 2) events = yield self.event_source.get_new_events( - room_ids=[self.room_id], - from_key=1, + room_ids=[self.room_id], from_key=1 + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": []}, + } + ], ) - self.assertEquals(events[0], [{ - "type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [], - }, - }]) # SYN-230 - see if we can still set after timeout @@ -379,20 +376,22 @@ class TypingNotificationsTestCase(unittest.TestCase): timeout=10000, ) - self.on_new_event.assert_has_calls([ - call('typing_key', 3, rooms=[self.room_id]), - ]) + self.on_new_event.assert_has_calls( + [call('typing_key', 3, rooms=[self.room_id])] + ) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) events = yield self.event_source.get_new_events( - room_ids=[self.room_id], - from_key=0, + room_ids=[self.room_id], from_key=0 + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": [self.u_apple.to_string()]}, + } + ], ) - self.assertEquals(events[0], [{ - "type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.u_apple.to_string()], - }, - }]) diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index 60e6a75953..3b0155ed03 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -39,15 +39,13 @@ class ServerNameTestCase(unittest.TestCase): "[1234", "underscore_.com", "percent%65.com", - "1234:5678:80", # too many colons + "1234:5678:80", # too many colons ] for i in test_data: try: parse_and_validate_server_name(i) self.fail( - "Expected parse_and_validate_server_name('%s') to throw" % ( - i, - ), + "Expected parse_and_validate_server_name('%s') to throw" % (i,) ) except ValueError: pass diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index a103e7be80..c23b6e2cfd 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -31,6 +31,7 @@ from tests.utils import setup_test_homeserver class TestReplicationClientHandler(ReplicationClientHandler): """Overrides on_rdata so that we can wait for it to happen""" + def __init__(self, store): super(TestReplicationClientHandler, self).__init__(store) self._rdata_awaiters = [] @@ -56,9 +57,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): "blue", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), + ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.hs.get_ratelimiter().send_message.return_value = (True, 0) diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py index adf226404e..87cc2b2fba 100644 --- a/tests/replication/slave/storage/test_account_data.py +++ b/tests/replication/slave/storage/test_account_data.py @@ -29,20 +29,14 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): @defer.inlineCallbacks def test_user_account_data(self): - yield self.master_store.add_account_data_for_user( - USER_ID, TYPE, {"a": 1} - ) + yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) yield self.replicate() yield self.check( - "get_global_account_data_by_type_for_user", - [TYPE, USER_ID], {"a": 1} + "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} ) - yield self.master_store.add_account_data_for_user( - USER_ID, TYPE, {"a": 2} - ) + yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) yield self.replicate() yield self.check( - "get_global_account_data_by_type_for_user", - [TYPE, USER_ID], {"a": 2} + "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index f5b47f5ec0..622be2eef8 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -38,6 +38,7 @@ def patch__eq__(cls): def unpatch(): if eq is not None: cls.__eq__ = eq + return unpatch @@ -48,10 +49,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def setUp(self): # Patch up the equality operator for events so that we can check # whether lists of events match using assertEquals - self.unpatches = [ - patch__eq__(_EventInternalMetadata), - patch__eq__(FrozenEvent), - ] + self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super(SlavedEventStoreTestCase, self).setUp() def tearDown(self): @@ -61,33 +59,27 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def test_get_latest_event_ids_in_room(self): create = yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.replicate() - yield self.check( - "get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id] - ) + yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) join = yield self.persist( - type="m.room.member", key=USER_ID, membership="join", + type="m.room.member", + key=USER_ID, + membership="join", prev_events=[(create.event_id, {})], ) yield self.replicate() - yield self.check( - "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] - ) + yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) @defer.inlineCallbacks def test_redactions(self): yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist( - type="m.room.message", msgtype="m.text", body="Hello" - ) + msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") yield self.replicate() yield self.check("get_event", [msg.event_id], msg) - redaction = yield self.persist( - type="m.room.redaction", redacts=msg.event_id - ) + redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id) yield self.replicate() msg_dict = msg.get_dict() @@ -102,9 +94,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist( - type="m.room.message", msgtype="m.text", body="Hello" - ) + msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") yield self.replicate() yield self.check("get_event", [msg.event_id], msg) @@ -127,10 +117,19 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", key=USER_ID_2, membership="invite" ) yield self.replicate() - yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser( - ROOM_ID, USER_ID, "invite", event.event_id, - event.internal_metadata.stream_ordering - )]) + yield self.check( + "get_invited_rooms_for_user", + [USER_ID_2], + [ + RoomsForUser( + ROOM_ID, + USER_ID, + "invite", + event.event_id, + event.internal_metadata.stream_ordering, + ) + ], + ) @defer.inlineCallbacks def test_push_actions_for_user(self): @@ -146,40 +145,55 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): yield self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 0} + {"highlight_count": 0, "notify_count": 0}, ) yield self.persist( - type="m.room.message", msgtype="m.text", body="world", + type="m.room.message", + msgtype="m.text", + body="world", push_actions=[(USER_ID_2, ["notify"])], ) yield self.replicate() yield self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 1} + {"highlight_count": 0, "notify_count": 1}, ) yield self.persist( - type="m.room.message", msgtype="m.text", body="world", - push_actions=[(USER_ID_2, [ - "notify", {"set_tweak": "highlight", "value": True} - ])], + type="m.room.message", + msgtype="m.text", + body="world", + push_actions=[ + (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}]) + ], ) yield self.replicate() yield self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2} + {"highlight_count": 1, "notify_count": 2}, ) event_id = 0 @defer.inlineCallbacks def persist( - self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, - state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None, + self, + sender=USER_ID, + room_id=ROOM_ID, + type={}, + key=None, + internal={}, + state=None, + reset_state=False, + backfill=False, + depth=None, + prev_events=[], + auth_events=[], + prev_state=[], + redacts=None, push_actions=[], **content ): @@ -219,34 +233,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.event_id += 1 if state is not None: - state_ids = { - key: e.event_id for key, e in state.items() - } + state_ids = {key: e.event_id for key, e in state.items()} context = EventContext.with_state( - state_group=None, - current_state_ids=state_ids, - prev_state_ids=state_ids + state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids ) else: state_handler = self.hs.get_state_handler() context = yield state_handler.compute_event_context(event) yield self.master_store.add_push_actions_to_staging( - event.event_id, { - user_id: actions - for user_id, actions in push_actions - }, + event.event_id, {user_id: actions for user_id, actions in push_actions} ) ordering = None if backfill: - yield self.master_store.persist_events( - [(event, context)], backfilled=True - ) + yield self.master_store.persist_events([(event, context)], backfilled=True) else: - ordering, _ = yield self.master_store.persist_event( - event, context, - ) + ordering, _ = yield self.master_store.persist_event(event, context) if ordering: event.internal_metadata.stream_ordering = ordering diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index e6d670cc1f..ae1adeded1 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -34,6 +34,6 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} ) yield self.replicate() - yield self.check("get_receipts_for_user", [USER_ID, "m.read"], { - ROOM_ID: EVENT_ID - }) + yield self.check( + "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID} + ) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index d46c27e7e9..708dc26e61 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -11,7 +11,6 @@ from tests.utils import MockClock class HttpTransactionCacheTestCase(unittest.TestCase): - def setUp(self): self.clock = MockClock() self.hs = Mock() @@ -24,9 +23,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_executes_given_function(self): - cb = Mock( - return_value=defer.succeed(self.mock_http_response) - ) + cb = Mock(return_value=defer.succeed(self.mock_http_response)) res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg" ) @@ -35,9 +32,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_deduplicates_based_on_key(self): - cb = Mock( - return_value=defer.succeed(self.mock_http_response) - ) + cb = Mock(return_value=defer.succeed(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg", changing_args=i @@ -120,29 +115,18 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cleans_up(self): - cb = Mock( - return_value=defer.succeed(self.mock_http_response) - ) - yield self.cache.fetch_or_execute( - self.mock_key, cb, "an arg" - ) + cb = Mock(return_value=defer.succeed(self.mock_http_response)) + yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # should NOT have cleaned up yet self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) - yield self.cache.fetch_or_execute( - self.mock_key, cb, "an arg" - ) + yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # still using cache cb.assert_called_once_with("an arg") self.clock.advance_time_msec(CLEANUP_PERIOD_MS) - yield self.cache.fetch_or_execute( - self.mock_key, cb, "an arg" - ) + yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # no longer using cache self.assertEqual(cb.call_count, 2) - self.assertEqual( - cb.call_args_list, - [call("an arg",), call("an arg",)] - ) + self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")]) diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py index fb28883d30..67d9ab94e2 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py @@ -215,6 +215,7 @@ class UserRegisterTestCase(unittest.TestCase): mac. Admin is optional. Additional checks are done for length and type. """ + def nonce(): request, channel = make_request("GET", self.url) render(request, self.resource, self.clock) @@ -289,7 +290,9 @@ class UserRegisterTestCase(unittest.TestCase): self.assertEqual('Invalid password', channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}) + body = json.dumps( + {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"} + ) request, channel = make_request("POST", self.url, body.encode('utf8')) render(request, self.resource, self.clock) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 50418153fa..0316b74fa1 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -43,9 +43,7 @@ class EventStreamPermissionsTestCase(RestTestCase): hs = yield setup_test_homeserver( http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), + ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) @@ -83,7 +81,7 @@ class EventStreamPermissionsTestCase(RestTestCase): # behaviour is used instead to be consistent with the r0 spec. # see issue #2602 (code, response) = yield self.mock_resource.trigger_get( - "/events?access_token=%s" % ("invalid" + self.token, ) + "/events?access_token=%s" % ("invalid" + self.token,) ) self.assertEquals(401, code, msg=str(response)) @@ -98,18 +96,12 @@ class EventStreamPermissionsTestCase(RestTestCase): @defer.inlineCallbacks def test_stream_room_permissions(self): - room_id = yield self.create_room_as( - self.other_user, - tok=self.other_token - ) + room_id = yield self.create_room_as(self.other_user, tok=self.other_token) yield self.send(room_id, tok=self.other_token) # invited to room (expect no content for room) yield self.invite( - room_id, - src=self.other_user, - targ=self.user_id, - tok=self.other_token + room_id, src=self.other_user, targ=self.user_id, tok=self.other_token ) (code, response) = yield self.mock_resource.trigger_get( @@ -120,13 +112,16 @@ class EventStreamPermissionsTestCase(RestTestCase): # We may get a presence event for ourselves down self.assertEquals( 0, - len([ - c for c in response["chunk"] - if not ( - c.get("type") == "m.presence" - and c["content"].get("user_id") == self.user_id - ) - ]) + len( + [ + c + for c in response["chunk"] + if not ( + c.get("type") == "m.presence" + and c["content"].get("user_id") == self.user_id + ) + ] + ), ) # joined room (expect all content for room) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 0516ce3cfb..9ba0ffc19f 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -36,12 +36,14 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.mock_handler = Mock(spec=[ - "get_displayname", - "set_displayname", - "get_avatar_url", - "set_avatar_url", - ]) + self.mock_handler = Mock( + spec=[ + "get_displayname", + "set_displayname", + "get_avatar_url", + "set_avatar_url", + ] + ) hs = yield setup_test_homeserver( "test", @@ -49,7 +51,7 @@ class ProfileTestCase(unittest.TestCase): resource_for_client=self.mock_resource, federation=Mock(), federation_client=Mock(), - profile_handler=self.mock_handler + profile_handler=self.mock_handler, ) def _get_user_by_req(request=None, allow_guest=False): @@ -78,9 +80,7 @@ class ProfileTestCase(unittest.TestCase): mocked_set.return_value = defer.succeed(()) (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/displayname" % (myid), - b'{"displayname": "Frank Jr."}' + "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}' ) self.assertEquals(200, code) @@ -94,14 +94,12 @@ class ProfileTestCase(unittest.TestCase): mocked_set.side_effect = AuthError(400, "message") (code, response) = yield self.mock_resource.trigger( - "PUT", "/profile/%s/displayname" % ("@4567:test"), - b'{"displayname": "Frank Jr."}' + "PUT", + "/profile/%s/displayname" % ("@4567:test"), + b'{"displayname": "Frank Jr."}', ) - self.assertTrue( - 400 <= code < 499, - msg="code %d is in the 4xx range" % (code) - ) + self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code)) @defer.inlineCallbacks def test_get_other_name(self): @@ -121,14 +119,12 @@ class ProfileTestCase(unittest.TestCase): mocked_set.side_effect = SynapseError(400, "message") (code, response) = yield self.mock_resource.trigger( - "PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"), - b'{"displayname":"bob"}' + "PUT", + "/profile/%s/displayname" % ("@opaque:elsewhere"), + b'{"displayname":"bob"}', ) - self.assertTrue( - 400 <= code <= 499, - msg="code %d is in the 4xx range" % (code) - ) + self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code)) @defer.inlineCallbacks def test_get_my_avatar(self): @@ -151,7 +147,7 @@ class ProfileTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "PUT", "/profile/%s/avatar_url" % (myid), - b'{"avatar_url": "http://my.server/pic.gif"}' + b'{"avatar_url": "http://my.server/pic.gif"}', ) self.assertEquals(200, code) diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 83a23cd8fe..6f15d69ecd 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -32,6 +32,7 @@ class CreateUserServletTestCase(unittest.TestCase): """ Tests for CreateUserRestServlet. """ + if PY3: skip = "Not ported to Python 3." diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index bddb3302e4..7f1a435e7b 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -31,6 +31,7 @@ PATH_PREFIX = "/_matrix/client/api/v1" class RoomTypingTestCase(RestTestCase): """ Tests /rooms/$room_id/typing/$user_id REST API. """ + user_id = "@sid:red" user = UserID.from_string(user_id) @@ -47,9 +48,7 @@ class RoomTypingTestCase(RestTestCase): clock=self.clock, http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), + ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.hs = hs @@ -71,6 +70,7 @@ class RoomTypingTestCase(RestTestCase): def _insert_client_ip(*args, **kwargs): return defer.succeed(None) + hs.get_datastore().insert_client_ip = _insert_client_ip def get_room_members(room_id): @@ -94,6 +94,7 @@ class RoomTypingTestCase(RestTestCase): else: if remotedomains is not None: remotedomains.add(member.domain) + hs.get_room_member_handler().fetch_room_distributions_into = ( fetch_room_distributions_into ) @@ -107,37 +108,42 @@ class RoomTypingTestCase(RestTestCase): @defer.inlineCallbacks def test_set_typing(self): (code, _) = yield self.mock_resource.trigger( - "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - '{"typing": true, "timeout": 30000}' + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + '{"typing": true, "timeout": 30000}', ) self.assertEquals(200, code) self.assertEquals(self.event_source.get_current_key(), 1) events = yield self.event_source.get_new_events( - from_key=0, - room_ids=[self.room_id], + from_key=0, room_ids=[self.room_id] + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": [self.user_id]}, + } + ], ) - self.assertEquals(events[0], [{ - "type": "m.typing", - "room_id": self.room_id, - "content": { - "user_ids": [self.user_id], - } - }]) @defer.inlineCallbacks def test_set_not_typing(self): (code, _) = yield self.mock_resource.trigger( - "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - '{"typing": false}' + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + '{"typing": false}', ) self.assertEquals(200, code) @defer.inlineCallbacks def test_typing_timeout(self): (code, _) = yield self.mock_resource.trigger( - "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - '{"typing": true, "timeout": 30000}' + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + '{"typing": true, "timeout": 30000}', ) self.assertEquals(200, code) @@ -148,8 +154,9 @@ class RoomTypingTestCase(RestTestCase): self.assertEquals(self.event_source.get_current_key(), 2) (code, _) = yield self.mock_resource.trigger( - "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - '{"typing": true, "timeout": 30000}' + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + '{"typing": true, "timeout": 30000}', ) self.assertEquals(200, code) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index e3bc5f378d..9f862f9dfa 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -55,25 +55,39 @@ class RestTestCase(unittest.TestCase): @defer.inlineCallbacks def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): - yield self.change_membership(room=room, src=src, targ=targ, tok=tok, - membership=Membership.INVITE, - expect_code=expect_code) + yield self.change_membership( + room=room, + src=src, + targ=targ, + tok=tok, + membership=Membership.INVITE, + expect_code=expect_code, + ) @defer.inlineCallbacks def join(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership(room=room, src=user, targ=user, tok=tok, - membership=Membership.JOIN, - expect_code=expect_code) + yield self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.JOIN, + expect_code=expect_code, + ) @defer.inlineCallbacks def leave(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership(room=room, src=user, targ=user, tok=tok, - membership=Membership.LEAVE, - expect_code=expect_code) + yield self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.LEAVE, + expect_code=expect_code, + ) @defer.inlineCallbacks - def change_membership(self, room, src, targ, membership, tok=None, - expect_code=200): + def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): temp_id = self.auth_user_id self.auth_user_id = src @@ -81,16 +95,15 @@ class RestTestCase(unittest.TestCase): if tok: path = path + "?access_token=%s" % tok - data = { - "membership": membership - } + data = {"membership": membership} (code, response) = yield self.mock_resource.trigger( "PUT", path, json.dumps(data) ) self.assertEquals( - expect_code, code, - msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response) + expect_code, + code, + msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response), ) self.auth_user_id = temp_id @@ -100,17 +113,15 @@ class RestTestCase(unittest.TestCase): (code, response) = yield self.mock_resource.trigger( "POST", "/register", - json.dumps({ - "user": user_id, - "password": "test", - "type": "m.login.password" - })) + json.dumps( + {"user": user_id, "password": "test", "type": "m.login.password"} + ), + ) self.assertEquals(200, code, msg=response) defer.returnValue(response) @defer.inlineCallbacks - def send(self, room_id, body=None, txn_id=None, tok=None, - expect_code=200): + def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): if txn_id is None: txn_id = "m%s" % (str(time.time())) if body is None: @@ -132,8 +143,9 @@ class RestTestCase(unittest.TestCase): actual (dict): The test result. Extra keys will not be checked. """ for key in required: - self.assertEquals(required[key], actual[key], - msg="%s mismatch. %s" % (key, actual)) + self.assertEquals( + required[key], actual[key], msg="%s mismatch. %s" % (key, actual) + ) @attr.s @@ -156,7 +168,9 @@ class RestHelper(object): if tok: path = path + "?access_token=%s" % tok - request, channel = make_request("POST", path, json.dumps(content).encode('utf8')) + request, channel = make_request( + "POST", path, json.dumps(content).encode('utf8') + ) request.render(self.resource) wait_until_result(self.hs.get_reactor(), channel) @@ -204,9 +218,7 @@ class RestHelper(object): data = {"membership": membership} - request, channel = make_request( - "PUT", path, json.dumps(data).encode('utf8') - ) + request, channel = make_request("PUT", path, json.dumps(data).encode('utf8')) request.render(self.resource) wait_until_result(self.hs.get_reactor(), channel) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index f6293f11a8..9487babac3 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -101,9 +101,7 @@ class RegisterRestServletTestCase(unittest.TestCase): wait_until_result(self.clock, channel) self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals( - channel.json_body["error"], "Invalid password" - ) + self.assertEquals(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self): request_data = json.dumps({"username": 777, "password": "monkey"}) @@ -112,9 +110,7 @@ class RegisterRestServletTestCase(unittest.TestCase): wait_until_result(self.clock, channel) self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals( - channel.json_body["error"], "Invalid username" - ) + self.assertEquals(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self): user_id = "@kermit:muppet" @@ -157,10 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase): wait_until_result(self.clock, channel) self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( - channel.json_body["error"], - "Registration has been disabled", - ) + self.assertEquals(channel.json_body["error"], "Registration has been disabled") def test_POST_guest_registration(self): user_id = "a@b" @@ -188,6 +181,4 @@ class RegisterRestServletTestCase(unittest.TestCase): wait_until_result(self.clock, channel) self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( - channel.json_body["error"], "Guest access is disabled" - ) + self.assertEquals(channel.json_body["error"], "Guest access is disabled") diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index bf254a260d..a86901c2d8 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -41,13 +41,11 @@ class MediaStorageTests(unittest.TestCase): hs.get_reactor = Mock(return_value=reactor) hs.config.media_store_path = self.primary_base_path - storage_providers = [FileStorageProviderBackend( - hs, self.secondary_base_path - )] + storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] self.filepaths = MediaFilePaths(self.primary_base_path) self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers, + hs, self.primary_base_path, self.filepaths, storage_providers ) def tearDown(self): diff --git a/tests/server.py b/tests/server.py index e249668d21..05708be8b9 100644 --- a/tests/server.py +++ b/tests/server.py @@ -136,6 +136,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. """ + def callFromThread(self, callback, *args, **kwargs): """ Make the callback fire in the next reactor iteration. @@ -184,6 +185,7 @@ def setup_test_homeserver(*args, **kwargs): """ Threadless thread pool. """ + def start(self): pass diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 52eff2a104..52eb05bfbf 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -25,7 +25,6 @@ from tests import unittest class CacheTestCase(unittest.TestCase): - def setUp(self): self.cache = Cache("test") @@ -97,7 +96,6 @@ class CacheTestCase(unittest.TestCase): class CacheDecoratorTestCase(unittest.TestCase): - @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -180,8 +178,7 @@ class CacheDecoratorTestCase(unittest.TestCase): yield a.func(k) self.assertTrue( - callcount[0] >= 14, - msg="Expected callcount >= 14, got %d" % (callcount[0]) + callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) ) def test_prefill(self): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 099861b27c..fbb25a8844 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -34,7 +34,6 @@ from tests.utils import setup_test_homeserver class ApplicationServiceStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): self.as_yaml_files = [] @@ -44,20 +43,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, - federation_sender=Mock(), - federation_client=Mock(), + config=config, federation_sender=Mock(), federation_client=Mock() ) self.as_token = "token1" self.as_url = "some_url" self.as_id = "as1" self._add_appservice( - self.as_token, - self.as_id, - self.as_url, - "some_hs_token", - "bob" + self.as_token, self.as_id, self.as_url, "some_hs_token", "bob" ) self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") @@ -73,8 +66,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): pass def _add_appservice(self, as_token, id, url, hs_token, sender): - as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token, - id=id, sender_localpart=sender, namespaces={}) + as_yaml = dict( + url=url, + as_token=as_token, + hs_token=hs_token, + id=id, + sender_localpart=sender, + namespaces={}, + ) # use the token as the filename with open(as_token, 'w') as outfile: outfile.write(yaml.dump(as_yaml)) @@ -85,24 +84,13 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self.assertEquals(service, None) def test_retrieval_of_service(self): - stored_service = self.store.get_app_service_by_token( - self.as_token - ) + stored_service = self.store.get_app_service_by_token(self.as_token) self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.id, self.as_id) self.assertEquals(stored_service.url, self.as_url) - self.assertEquals( - stored_service.namespaces[ApplicationService.NS_ALIASES], - [] - ) - self.assertEquals( - stored_service.namespaces[ApplicationService.NS_ROOMS], - [] - ) - self.assertEquals( - stored_service.namespaces[ApplicationService.NS_USERS], - [] - ) + self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], []) + self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) + self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) def test_retrieval_of_all_services(self): services = self.store.get_app_services() @@ -110,7 +98,6 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): self.as_yaml_files = [] @@ -121,33 +108,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, - federation_sender=Mock(), - federation_client=Mock(), + config=config, federation_sender=Mock(), federation_client=Mock() ) self.db_pool = hs.get_db_pool() self.as_list = [ - { - "token": "token1", - "url": "https://matrix-as.org", - "id": "id_1" - }, - { - "token": "alpha_tok", - "url": "https://alpha.com", - "id": "id_alpha" - }, - { - "token": "beta_tok", - "url": "https://beta.com", - "id": "id_beta" - }, - { - "token": "gamma_tok", - "url": "https://gamma.com", - "id": "id_gamma" - }, + {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, + {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, + {"token": "beta_tok", "url": "https://beta.com", "id": "id_beta"}, + {"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"}, ] for s in self.as_list: yield self._add_service(s["url"], s["token"], s["id"]) @@ -157,8 +126,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.store = TestTransactionStore(None, hs) def _add_service(self, url, as_token, id): - as_yaml = dict(url=url, as_token=as_token, hs_token="something", - id=id, sender_localpart="a_sender", namespaces={}) + as_yaml = dict( + url=url, + as_token=as_token, + hs_token="something", + id=id, + sender_localpart="a_sender", + namespaces={}, + ) # use the token as the filename with open(as_token, 'w') as outfile: outfile.write(yaml.dump(as_yaml)) @@ -168,21 +143,21 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): return self.db_pool.runQuery( "INSERT INTO application_services_state(as_id, state, last_txn) " "VALUES(?,?,?)", - (id, state, txn) + (id, state, txn), ) def _insert_txn(self, as_id, txn_id, events): return self.db_pool.runQuery( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "VALUES(?,?,?)", - (as_id, txn_id, json.dumps([e.event_id for e in events])) + (as_id, txn_id, json.dumps([e.event_id for e in events])), ) def _set_last_txn(self, as_id, txn_id): return self.db_pool.runQuery( "INSERT INTO application_services_state(as_id, last_txn, state) " "VALUES(?,?,?)", - (as_id, txn_id, ApplicationServiceState.UP) + (as_id, txn_id, ApplicationServiceState.UP), ) @defer.inlineCallbacks @@ -193,24 +168,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_appservice_state_up(self): - yield self._set_state( - self.as_list[0]["id"], ApplicationServiceState.UP - ) + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) state = yield self.store.get_appservice_state(service) self.assertEquals(ApplicationServiceState.UP, state) @defer.inlineCallbacks def test_get_appservice_state_down(self): - yield self._set_state( - self.as_list[0]["id"], ApplicationServiceState.UP - ) - yield self._set_state( - self.as_list[1]["id"], ApplicationServiceState.DOWN - ) - yield self._set_state( - self.as_list[2]["id"], ApplicationServiceState.DOWN - ) + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) + yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) service = Mock(id=self.as_list[1]["id"]) state = yield self.store.get_appservice_state(service) self.assertEquals(ApplicationServiceState.DOWN, state) @@ -225,34 +192,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_down(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state( - service, - ApplicationServiceState.DOWN - ) + yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) rows = yield self.db_pool.runQuery( "SELECT as_id FROM application_services_state WHERE state=?", - (ApplicationServiceState.DOWN,) + (ApplicationServiceState.DOWN,), ) self.assertEquals(service.id, rows[0][0]) @defer.inlineCallbacks def test_set_appservices_state_multiple_up(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state( - service, - ApplicationServiceState.UP - ) - yield self.store.set_appservice_state( - service, - ApplicationServiceState.DOWN - ) - yield self.store.set_appservice_state( - service, - ApplicationServiceState.UP - ) + yield self.store.set_appservice_state(service, ApplicationServiceState.UP) + yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + yield self.store.set_appservice_state(service, ApplicationServiceState.UP) rows = yield self.db_pool.runQuery( "SELECT as_id FROM application_services_state WHERE state=?", - (ApplicationServiceState.UP,) + (ApplicationServiceState.UP,), ) self.assertEquals(service.id, rows[0][0]) @@ -319,14 +274,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): res = yield self.db_pool.runQuery( "SELECT last_txn FROM application_services_state WHERE as_id=?", - (service.id,) + (service.id,), ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) res = yield self.db_pool.runQuery( - "SELECT * FROM application_services_txns WHERE txn_id=?", - (txn_id,) + "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) ) self.assertEquals(0, len(res)) @@ -340,17 +294,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) res = yield self.db_pool.runQuery( - "SELECT last_txn, state FROM application_services_state WHERE " - "as_id=?", - (service.id,) + "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?", + (service.id,), ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) self.assertEquals(ApplicationServiceState.UP, res[0][1]) res = yield self.db_pool.runQuery( - "SELECT * FROM application_services_txns WHERE txn_id=?", - (txn_id,) + "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) ) self.assertEquals(0, len(res)) @@ -382,12 +334,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_appservices_by_state_single(self): - yield self._set_state( - self.as_list[0]["id"], ApplicationServiceState.DOWN - ) - yield self._set_state( - self.as_list[1]["id"], ApplicationServiceState.UP - ) + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) services = yield self.store.get_appservices_by_state( ApplicationServiceState.DOWN @@ -397,18 +345,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_appservices_by_state_multiple(self): - yield self._set_state( - self.as_list[0]["id"], ApplicationServiceState.DOWN - ) - yield self._set_state( - self.as_list[1]["id"], ApplicationServiceState.UP - ) - yield self._set_state( - self.as_list[2]["id"], ApplicationServiceState.DOWN - ) - yield self._set_state( - self.as_list[3]["id"], ApplicationServiceState.UP - ) + yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) + yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) + yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) services = yield self.store.get_appservices_by_state( ApplicationServiceState.DOWN @@ -416,20 +356,17 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(2, len(services)) self.assertEquals( set([self.as_list[2]["id"], self.as_list[0]["id"]]), - set([services[0].id, services[1].id]) + set([services[0].id, services[1].id]), ) # required for ApplicationServiceTransactionStoreTestCase tests -class TestTransactionStore(ApplicationServiceTransactionStore, - ApplicationServiceStore): - +class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): def __init__(self, db_conn, hs): super(TestTransactionStore, self).__init__(db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): - def _write_config(self, suffix, **kwargs): vals = { "id": "id" + suffix, @@ -452,8 +389,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f2 = self._write_config(suffix="2") config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, - password_providers=[] + app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( config=config, @@ -470,8 +406,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f2 = self._write_config(id="id", suffix="2") config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, - password_providers=[] + app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( config=config, @@ -494,8 +429,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f2 = self._write_config(as_token="as_token", suffix="2") config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, - password_providers=[] + app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( config=config, diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index ab1f310572..b4f6baf441 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -7,7 +7,6 @@ from tests.utils import setup_test_homeserver class BackgroundUpdateTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver() # type: synapse.server.HomeServer @@ -51,9 +50,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): yield self.store.start_background_update("test_update", {"my_key": 1}) self.update_handler.reset_mock() - result = yield self.store.do_next_background_update( - duration_ms * desired_count - ) + result = yield self.store.do_next_background_update(duration_ms * desired_count) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE @@ -67,18 +64,12 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.do_next_background_update( - duration_ms * desired_count - ) + result = yield self.store.do_next_background_update(duration_ms * desired_count) self.assertIsNotNone(result) - self.update_handler.assert_called_once_with( - {"my_key": 2}, desired_count - ) + self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_next_background_update( - duration_ms * desired_count - ) + result = yield self.store.do_next_background_update(duration_ms * desired_count) self.assertIsNone(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 1d1234ee39..7cb5f0e4cf 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -40,10 +40,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def runInteraction(func, *args, **kwargs): return defer.succeed(func(self.mock_txn, *args, **kwargs)) + self.db_pool.runInteraction = runInteraction def runWithConnection(func, *args, **kwargs): return defer.succeed(func(self.mock_conn, *args, **kwargs)) + self.db_pool.runWithConnection = runWithConnection config = Mock() @@ -63,8 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_insert( - table="tablename", - values={"columname": "Value"} + table="tablename", values={"columname": "Value"} ) self.mock_txn.execute.assert_called_with( @@ -78,12 +79,11 @@ class SQLBaseStoreTestCase(unittest.TestCase): yield self.datastore._simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated - values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]) + values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), ) self.mock_txn.execute.assert_called_with( - "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", - (1, 2, 3,) + "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", (1, 2, 3) ) @defer.inlineCallbacks @@ -92,9 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) value = yield self.datastore._simple_select_one_onecol( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcol="retcol" + table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) self.assertEquals("Value", value) @@ -110,13 +108,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): ret = yield self.datastore._simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, - retcols=["colA", "colB", "colC"] + retcols=["colA", "colB", "colC"], ) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) self.mock_txn.execute.assert_called_with( - "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", - ["TheKey"] + "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @defer.inlineCallbacks @@ -128,7 +125,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], - allow_none=True + allow_none=True, ) self.assertFalse(ret) @@ -137,20 +134,15 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_select_list(self): self.mock_txn.rowcount = 3 self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) - self.mock_txn.description = ( - ("colA", None, None, None, None, None, None), - ) + self.mock_txn.description = (("colA", None, None, None, None, None, None),) ret = yield self.datastore._simple_select_list( - table="tablename", - keyvalues={"keycol": "A set"}, - retcols=["colA"], + table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.mock_txn.execute.assert_called_with( - "SELECT colA FROM tablename WHERE keycol = ?", - ["A set"] + "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) @defer.inlineCallbacks @@ -160,12 +152,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): yield self.datastore._simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, - updatevalues={"columnname": "New Value"} + updatevalues={"columnname": "New Value"}, ) self.mock_txn.execute.assert_called_with( "UPDATE tablename SET columnname = ? WHERE keycol = ?", - ["New Value", "TheKey"] + ["New Value", "TheKey"], ) @defer.inlineCallbacks @@ -175,13 +167,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): yield self.datastore._simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), - updatevalues=OrderedDict([("colC", 3), ("colD", 4)]) + updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), ) self.mock_txn.execute.assert_called_with( - "UPDATE tablename SET colC = ?, colD = ? WHERE" - " colA = ? AND colB = ?", - [3, 4, 1, 2] + "UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?", + [3, 4, 1, 2], ) @defer.inlineCallbacks @@ -189,8 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 yield self.datastore._simple_delete_one( - table="tablename", - keyvalues={"keycol": "Go away"}, + table="tablename", keyvalues={"keycol": "Go away"} ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 7a58c6eb24..ea00bbe84c 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -37,8 +37,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.clock.now = 12345678 user_id = "@user:id" yield self.store.insert_client_ip( - user_id, - "access_token", "ip", "user_agent", "device_id", + user_id, "access_token", "ip", "user_agent", "device_id" ) result = yield self.store.get_last_client_ip_by_device(user_id, "device_id") @@ -53,7 +52,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): "user_agent": "user_agent", "last_seen": 12345678000, }, - r + r, ) @defer.inlineCallbacks @@ -62,7 +61,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.hs.config.max_mau_value = 50 user_id = "@user:server" yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id", + user_id, "access_token", "ip", "user_agent", "device_id" ) active = yield self.store._user_last_seen_monthly_active(user_id) self.assertFalse(active) @@ -78,7 +77,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): return_value=defer.succeed(lots_of_users) ) yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id", + user_id, "access_token", "ip", "user_agent", "device_id" ) active = yield self.store._user_last_seen_monthly_active(user_id) self.assertFalse(active) @@ -92,7 +91,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.assertFalse(active) yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id", + user_id, "access_token", "ip", "user_agent", "device_id" ) active = yield self.store._user_last_seen_monthly_active(user_id) self.assertTrue(active) @@ -107,10 +106,10 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.assertFalse(active) yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id", + user_id, "access_token", "ip", "user_agent", "device_id" ) yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id", + user_id, "access_token", "ip", "user_agent", "device_id" ) active = yield self.store._user_last_seen_monthly_active(user_id) self.assertTrue(active) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index a54cc6bc32..63bc42d9e0 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -34,62 +34,58 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_store_new_device(self): - yield self.store.store_device( - "user_id", "device_id", "display_name" - ) + yield self.store.store_device("user_id", "device_id", "display_name") res = yield self.store.get_device("user_id", "device_id") - self.assertDictContainsSubset({ - "user_id": "user_id", - "device_id": "device_id", - "display_name": "display_name", - }, res) + self.assertDictContainsSubset( + { + "user_id": "user_id", + "device_id": "device_id", + "display_name": "display_name", + }, + res, + ) @defer.inlineCallbacks def test_get_devices_by_user(self): - yield self.store.store_device( - "user_id", "device1", "display_name 1" - ) - yield self.store.store_device( - "user_id", "device2", "display_name 2" - ) - yield self.store.store_device( - "user_id2", "device3", "display_name 3" - ) + yield self.store.store_device("user_id", "device1", "display_name 1") + yield self.store.store_device("user_id", "device2", "display_name 2") + yield self.store.store_device("user_id2", "device3", "display_name 3") res = yield self.store.get_devices_by_user("user_id") self.assertEqual(2, len(res.keys())) - self.assertDictContainsSubset({ - "user_id": "user_id", - "device_id": "device1", - "display_name": "display_name 1", - }, res["device1"]) - self.assertDictContainsSubset({ - "user_id": "user_id", - "device_id": "device2", - "display_name": "display_name 2", - }, res["device2"]) + self.assertDictContainsSubset( + { + "user_id": "user_id", + "device_id": "device1", + "display_name": "display_name 1", + }, + res["device1"], + ) + self.assertDictContainsSubset( + { + "user_id": "user_id", + "device_id": "device2", + "display_name": "display_name 2", + }, + res["device2"], + ) @defer.inlineCallbacks def test_update_device(self): - yield self.store.store_device( - "user_id", "device_id", "display_name 1" - ) + yield self.store.store_device("user_id", "device_id", "display_name 1") res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do a no-op first - yield self.store.update_device( - "user_id", "device_id", - ) + yield self.store.update_device("user_id", "device_id") res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do the update yield self.store.update_device( - "user_id", "device_id", - new_display_name="display_name 2", + "user_id", "device_id", new_display_name="display_name 2" ) # check it worked @@ -100,7 +96,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase): def test_update_unknown_device(self): with self.assertRaises(synapse.api.errors.StoreError) as cm: yield self.store.update_device( - "user_id", "unknown_device_id", - new_display_name="display_name 2", + "user_id", "unknown_device_id", new_display_name="display_name 2" ) self.assertEqual(404, cm.exception.code) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 129ebaf343..9a8ba2fcfe 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -24,7 +24,6 @@ from tests.utils import setup_test_homeserver class DirectoryStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver() @@ -37,38 +36,29 @@ class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_room_to_alias(self): yield self.store.create_room_alias_association( - room_alias=self.alias, - room_id=self.room.to_string(), - servers=["test"], + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] ) self.assertEquals( ["#my-room:test"], - (yield self.store.get_aliases_for_room(self.room.to_string())) + (yield self.store.get_aliases_for_room(self.room.to_string())), ) @defer.inlineCallbacks def test_alias_to_room(self): yield self.store.create_room_alias_association( - room_alias=self.alias, - room_id=self.room.to_string(), - servers=["test"], + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] ) self.assertObjectHasAttributes( - { - "room_id": self.room.to_string(), - "servers": ["test"], - }, - (yield self.store.get_association_from_room_alias(self.alias)) + {"room_id": self.room.to_string(), "servers": ["test"]}, + (yield self.store.get_association_from_room_alias(self.alias)), ) @defer.inlineCallbacks def test_delete_alias(self): yield self.store.create_room_alias_association( - room_alias=self.alias, - room_id=self.room.to_string(), - servers=["test"], + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] ) room_id = yield self.store.delete_room_alias(self.alias) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 84ce492a2c..d45c775c2d 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -35,70 +35,49 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device( - "user", "device", None - ) + yield self.store.store_device("user", "device", None) - yield self.store.set_e2e_device_keys( - "user", "device", now, json) + yield self.store.set_e2e_device_keys("user", "device", now, json) res = yield self.store.get_e2e_device_keys((("user", "device"),)) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset({ - "keys": json, - "device_display_name": None, - }, dev) + self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev) @defer.inlineCallbacks def test_get_key_with_device_name(self): now = 1470174257070 json = {"key": "value"} - yield self.store.set_e2e_device_keys( - "user", "device", now, json) - yield self.store.store_device( - "user", "device", "display_name" - ) + yield self.store.set_e2e_device_keys("user", "device", now, json) + yield self.store.store_device("user", "device", "display_name") res = yield self.store.get_e2e_device_keys((("user", "device"),)) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset({ - "keys": json, - "device_display_name": "display_name", - }, dev) + self.assertDictContainsSubset( + {"keys": json, "device_display_name": "display_name"}, dev + ) @defer.inlineCallbacks def test_multiple_devices(self): now = 1470174257070 - yield self.store.store_device( - "user1", "device1", None - ) - yield self.store.store_device( - "user1", "device2", None - ) - yield self.store.store_device( - "user2", "device1", None - ) - yield self.store.store_device( - "user2", "device2", None - ) + yield self.store.store_device("user1", "device1", None) + yield self.store.store_device("user1", "device2", None) + yield self.store.store_device("user2", "device1", None) + yield self.store.store_device("user2", "device2", None) - yield self.store.set_e2e_device_keys( - "user1", "device1", now, 'json11') - yield self.store.set_e2e_device_keys( - "user1", "device2", now, 'json12') - yield self.store.set_e2e_device_keys( - "user2", "device1", now, 'json21') - yield self.store.set_e2e_device_keys( - "user2", "device2", now, 'json22') - - res = yield self.store.get_e2e_device_keys((("user1", "device1"), - ("user2", "device2"))) + yield self.store.set_e2e_device_keys("user1", "device1", now, 'json11') + yield self.store.set_e2e_device_keys("user1", "device2", now, 'json12') + yield self.store.set_e2e_device_keys("user2", "device1", now, 'json21') + yield self.store.set_e2e_device_keys("user2", "device2", now, 'json22') + + res = yield self.store.get_e2e_device_keys( + (("user1", "device1"), ("user2", "device2")) + ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) self.assertNotIn("device2", res["user1"]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 69412c5aad..66eb119581 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -33,23 +33,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): def insert_event(txn, i): event_id = '$event_%i:local' % i - txn.execute(( - "INSERT INTO events (" - " room_id, event_id, type, depth, topological_ordering," - " content, processed, outlier) " - "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" - ), (room_id, event_id, i, i, True, False)) + txn.execute( + ( + "INSERT INTO events (" + " room_id, event_id, type, depth, topological_ordering," + " content, processed, outlier) " + "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" + ), + (room_id, event_id, i, i, True, False), + ) - txn.execute(( - 'INSERT INTO event_forward_extremities (room_id, event_id) ' - 'VALUES (?, ?)' - ), (room_id, event_id)) + txn.execute( + ( + 'INSERT INTO event_forward_extremities (room_id, event_id) ' + 'VALUES (?, ?)' + ), + (room_id, event_id), + ) - txn.execute(( - 'INSERT INTO event_reference_hashes ' - '(event_id, algorithm, hash) ' - "VALUES (?, 'sha256', ?)" - ), (event_id, b'ffff')) + txn.execute( + ( + 'INSERT INTO event_reference_hashes ' + '(event_id, algorithm, hash) ' + "VALUES (?, 'sha256', ?)" + ), + (event_id, b'ffff'), + ) for i in range(0, 11): yield self.store.runInteraction("insert", insert_event, i) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 8430fc7ba6..5e87b4530d 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -24,12 +24,13 @@ USER_ID = "@user:example.com" PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}] HIGHLIGHT = [ - "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"} + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, ] class EventPushActionsStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield tests.utils.setup_test_homeserver() @@ -55,12 +56,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): counts = yield self.store.runInteraction( - "", self.store._get_unread_counts_by_pos_txn, - room_id, user_id, 0 + "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( counts, - {"notify_count": noitf_count, "highlight_count": highlight_count} + {"notify_count": noitf_count, "highlight_count": highlight_count}, ) @defer.inlineCallbacks @@ -72,11 +72,13 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.depth = stream yield self.store.add_push_actions_to_staging( - event.event_id, {user_id: action}, + event.event_id, {user_id: action} ) yield self.store.runInteraction( - "", self.store._set_push_actions_for_event_and_users_txn, - [(event, None)], [(event, None)], + "", + self.store._set_push_actions_for_event_and_users_txn, + [(event, None)], + [(event, None)], ) def _rotate(stream): @@ -86,8 +88,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def _mark_read(stream, depth): return self.store.runInteraction( - "", self.store._remove_old_push_actions_before_txn, - room_id, user_id, stream + "", + self.store._remove_old_push_actions_before_txn, + room_id, + user_id, + stream, ) yield _assert_counts(0, 0) @@ -112,9 +117,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _rotate(7) yield self.store._simple_delete( - table="event_push_actions", - keyvalues={"1": 1}, - desc="", + table="event_push_actions", keyvalues={"1": 1}, desc="" ) yield _assert_counts(1, 0) @@ -132,18 +135,21 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store._simple_insert("events", { - "stream_ordering": so, - "received_ts": ts, - "event_id": "event%i" % so, - "type": "", - "room_id": "", - "content": "", - "processed": True, - "outlier": False, - "topological_ordering": 0, - "depth": 0, - }) + return self.store._simple_insert( + "events", + { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }, + ) # start with the base case where there are no events in the table r = yield self.store.find_first_stream_ordering_after_ts(11) @@ -160,31 +166,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): # add a bunch of dummy events to the events table for (stream_ordering, ts) in ( - (3, 110), - (4, 120), - (5, 120), - (10, 130), - (20, 140), + (3, 110), + (4, 120), + (5, 120), + (10, 130), + (20, 140), ): yield add_event(stream_ordering, ts) r = yield self.store.find_first_stream_ordering_after_ts(110) - self.assertEqual(r, 3, - "First event after 110ms should be 3, was %i" % r) + self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r) # 4 and 5 are both after 120: we want 4 rather than 5 r = yield self.store.find_first_stream_ordering_after_ts(120) - self.assertEqual(r, 4, - "First event after 120ms should be 4, was %i" % r) + self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r) r = yield self.store.find_first_stream_ordering_after_ts(129) - self.assertEqual(r, 10, - "First event after 129ms should be 10, was %i" % r) + self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r) # check we can get the last event r = yield self.store.find_first_stream_ordering_after_ts(140) - self.assertEqual(r, 20, - "First event after 14ms should be 20, was %i" % r) + self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r) # off the end r = yield self.store.find_first_stream_ordering_after_ts(160) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 3a3d002782..ad0a55b324 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -39,15 +39,12 @@ class KeyStoreTestCase(tests.unittest.TestCase): key2 = signedjson.key.decode_verify_key_base64( "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" ) - yield self.store.store_server_verify_key( - "server1", "from_server", 0, key1 - ) - yield self.store.store_server_verify_key( - "server1", "from_server", 0, key2 - ) + yield self.store.store_server_verify_key("server1", "from_server", 0, key1) + yield self.store.store_server_verify_key("server1", "from_server", 0, key2) res = yield self.store.get_server_verify_keys( - "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]) + "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"] + ) self.assertEqual(len(res.keys()), 2) self.assertEqual(res["ed25519:key1"].version, "key1") diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index cbd480cd42..22b1072d9f 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -40,19 +40,13 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): user2_email = "user2@matrix.org" threepids = [ {'medium': 'email', 'address': user1_email}, - {'medium': 'email', 'address': user2_email} + {'medium': 'email', 'address': user2_email}, ] user_num = len(threepids) - yield self.store.register( - user_id=user1, - token="123", - password_hash=None) + yield self.store.register(user_id=user1, token="123", password_hash=None) - yield self.store.register( - user_id=user2, - token="456", - password_hash=None) + yield self.store.register(user_id=user2, token="456", password_hash=None) now = int(self.hs.get_clock().time_msec()) yield self.store.user_add_threepid(user1, "email", user1_email, now, now) diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index 3276b39504..12c540dfab 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -24,7 +24,6 @@ from tests.utils import MockClock, setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver(clock=MockClock()) @@ -38,16 +37,19 @@ class PresenceStoreTestCase(unittest.TestCase): def test_presence_list(self): self.assertEquals( [], - (yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, - )) + ( + yield self.store.get_presence_list( + observer_localpart=self.u_apple.localpart + ) + ), ) self.assertEquals( [], - (yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, - accepted=True, - )) + ( + yield self.store.get_presence_list( + observer_localpart=self.u_apple.localpart, accepted=True + ) + ), ) yield self.store.add_presence_list_pending( @@ -57,16 +59,19 @@ class PresenceStoreTestCase(unittest.TestCase): self.assertEquals( [{"observed_user_id": "@banana:test", "accepted": 0}], - (yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, - )) + ( + yield self.store.get_presence_list( + observer_localpart=self.u_apple.localpart + ) + ), ) self.assertEquals( [], - (yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, - accepted=True, - )) + ( + yield self.store.get_presence_list( + observer_localpart=self.u_apple.localpart, accepted=True + ) + ), ) yield self.store.set_presence_list_accepted( @@ -76,16 +81,19 @@ class PresenceStoreTestCase(unittest.TestCase): self.assertEquals( [{"observed_user_id": "@banana:test", "accepted": 1}], - (yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, - )) + ( + yield self.store.get_presence_list( + observer_localpart=self.u_apple.localpart + ) + ), ) self.assertEquals( [{"observed_user_id": "@banana:test", "accepted": 1}], - (yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, - accepted=True, - )) + ( + yield self.store.get_presence_list( + observer_localpart=self.u_apple.localpart, accepted=True + ) + ), ) yield self.store.del_presence_list( @@ -95,14 +103,17 @@ class PresenceStoreTestCase(unittest.TestCase): self.assertEquals( [], - (yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, - )) + ( + yield self.store.get_presence_list( + observer_localpart=self.u_apple.localpart + ) + ), ) self.assertEquals( [], - (yield self.store.get_presence_list( - observer_localpart=self.u_apple.localpart, - accepted=True, - )) + ( + yield self.store.get_presence_list( + observer_localpart=self.u_apple.localpart, accepted=True + ) + ), ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 2c95e5e95a..5acbc8be0c 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -24,7 +24,6 @@ from tests.utils import setup_test_homeserver class ProfileStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver() @@ -35,24 +34,17 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile( - self.u_frank.localpart - ) + yield self.store.create_profile(self.u_frank.localpart) - yield self.store.set_profile_displayname( - self.u_frank.localpart, "Frank" - ) + yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") self.assertEquals( - "Frank", - (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) ) @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile( - self.u_frank.localpart - ) + yield self.store.create_profile(self.u_frank.localpart) yield self.store.set_profile_avatar_url( self.u_frank.localpart, "http://my.site/here" @@ -60,5 +52,5 @@ class ProfileStoreTestCase(unittest.TestCase): self.assertEquals( "http://my.site/here", - (yield self.store.get_profile_avatar_url(self.u_frank.localpart)) + (yield self.store.get_profile_avatar_url(self.u_frank.localpart)), ) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 475ec900c4..85ce61e841 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -26,12 +26,10 @@ from tests.utils import setup_test_homeserver class RedactionTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), - http_client=None, + resource_for_federation=Mock(), http_client=None ) self.store = hs.get_datastore() @@ -46,17 +44,20 @@ class RedactionTestCase(unittest.TestCase): self.depth = 1 @defer.inlineCallbacks - def inject_room_member(self, room, user, membership, replaces_state=None, - extra_content={}): + def inject_room_member( + self, room, user, membership, replaces_state=None, extra_content={} + ): content = {"membership": membership} content.update(extra_content) - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": content, - }) + builder = self.event_builder_factory.new( + { + "type": EventTypes.Member, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": content, + } + ) event, context = yield self.event_creation_handler.create_new_client_event( builder @@ -70,13 +71,15 @@ class RedactionTestCase(unittest.TestCase): def inject_message(self, room, user, body): self.depth += 1 - builder = self.event_builder_factory.new({ - "type": EventTypes.Message, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, - }) + builder = self.event_builder_factory.new( + { + "type": EventTypes.Message, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"body": body, "msgtype": u"message"}, + } + ) event, context = yield self.event_creation_handler.create_new_client_event( builder @@ -88,14 +91,16 @@ class RedactionTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_redaction(self, room, event_id, user, reason): - builder = self.event_builder_factory.new({ - "type": EventTypes.Redaction, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"reason": reason}, - "redacts": event_id, - }) + builder = self.event_builder_factory.new( + { + "type": EventTypes.Redaction, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"reason": reason}, + "redacts": event_id, + } + ) event, context = yield self.event_creation_handler.create_new_client_event( builder @@ -105,9 +110,7 @@ class RedactionTestCase(unittest.TestCase): @defer.inlineCallbacks def test_redact(self): - yield self.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) + yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) msg_event = yield self.inject_message(self.room1, self.u_alice, u"t") @@ -157,13 +160,10 @@ class RedactionTestCase(unittest.TestCase): @defer.inlineCallbacks def test_redact_join(self): - yield self.inject_room_member( - self.room1, self.u_alice, Membership.JOIN - ) + yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) msg_event = yield self.inject_room_member( - self.room1, self.u_bob, Membership.JOIN, - extra_content={"blue": "red"}, + self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"} ) event = yield self.store.get_event(msg_event.event_id) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 7821ea3fa3..bd96896bb3 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -21,7 +21,6 @@ from tests.utils import setup_test_homeserver class RegistrationStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver() @@ -30,10 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.store = hs.get_datastore() self.user_id = "@my-user:test" - self.tokens = [ - "AbCdEfGhIjKlMnOpQrStUvWxYz", - "BcDeFgHiJkLmNoPqRsTuVwXyZa" - ] + self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"] self.pwhash = "{xx1}123456789" self.device_id = "akgjhdjklgshg" @@ -51,34 +47,26 @@ class RegistrationStoreTestCase(unittest.TestCase): "consent_server_notice_sent": None, "appservice_id": None, }, - (yield self.store.get_user_by_id(self.user_id)) + (yield self.store.get_user_by_id(self.user_id)), ) result = yield self.store.get_user_by_access_token(self.tokens[0]) - self.assertDictContainsSubset( - { - "name": self.user_id, - }, - result - ) + self.assertDictContainsSubset({"name": self.user_id}, result) self.assertTrue("token_id" in result) @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) - yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], - self.device_id) + yield self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id + ) result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( - { - "name": self.user_id, - "device_id": self.device_id, - }, - result + {"name": self.user_id, "device_id": self.device_id}, result ) self.assertTrue("token_id" in result) @@ -87,12 +75,13 @@ class RegistrationStoreTestCase(unittest.TestCase): def test_user_delete_access_tokens(self): # add some tokens yield self.store.register(self.user_id, self.tokens[0], self.pwhash) - yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], - self.device_id) + yield self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id + ) # now delete some yield self.store.user_delete_access_tokens( - self.user_id, device_id=self.device_id, + self.user_id, device_id=self.device_id ) # check they were deleted @@ -107,8 +96,7 @@ class RegistrationStoreTestCase(unittest.TestCase): yield self.store.user_delete_access_tokens(self.user_id) user = yield self.store.get_user_by_access_token(self.tokens[0]) - self.assertIsNone(user, - "access token was not deleted without device_id") + self.assertIsNone(user, "access token was not deleted without device_id") class TokenGenerator: @@ -117,4 +105,4 @@ class TokenGenerator: def generate(self, user_id): self._last_issued_token += 1 - return u"%s-%d" % (user_id, self._last_issued_token,) + return u"%s-%d" % (user_id, self._last_issued_token) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ae8ae94b6d..84d49b55c1 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -24,7 +24,6 @@ from tests.utils import setup_test_homeserver class RoomStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver() @@ -40,7 +39,7 @@ class RoomStoreTestCase(unittest.TestCase): yield self.store.store_room( self.room.to_string(), room_creator_user_id=self.u_creator.to_string(), - is_public=True + is_public=True, ) @defer.inlineCallbacks @@ -49,14 +48,13 @@ class RoomStoreTestCase(unittest.TestCase): { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), - "is_public": True + "is_public": True, }, - (yield self.store.get_room(self.room.to_string())) + (yield self.store.get_room(self.room.to_string())), ) class RoomEventsStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = setup_test_homeserver() @@ -69,18 +67,13 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True + self.room.to_string(), room_creator_user_id="@creator:text", is_public=True ) @defer.inlineCallbacks def inject_room_event(self, **kwargs): yield self.store.persist_event( - self.event_factory.create_event( - room_id=self.room.to_string(), - **kwargs - ) + self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) @defer.inlineCallbacks @@ -88,22 +81,15 @@ class RoomEventsStoreTestCase(unittest.TestCase): name = u"A-Room-Name" yield self.inject_room_event( - etype=EventTypes.Name, - name=name, - content={"name": name}, - depth=1, + etype=EventTypes.Name, name=name, content={"name": name}, depth=1 ) - state = yield self.store.get_current_state( - room_id=self.room.to_string() - ) + state = yield self.store.get_current_state(room_id=self.room.to_string()) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( - {"type": "m.room.name", - "room_id": self.room.to_string(), - "name": name}, - state[0] + {"type": "m.room.name", "room_id": self.room.to_string(), "name": name}, + state[0], ) @defer.inlineCallbacks @@ -111,22 +97,15 @@ class RoomEventsStoreTestCase(unittest.TestCase): topic = u"A place for things" yield self.inject_room_event( - etype=EventTypes.Topic, - topic=topic, - content={"topic": topic}, - depth=1, + etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 ) - state = yield self.store.get_current_state( - room_id=self.room.to_string() - ) + state = yield self.store.get_current_state(room_id=self.room.to_string()) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( - {"type": "m.room.topic", - "room_id": self.room.to_string(), - "topic": topic}, - state[0] + {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic}, + state[0], ) # Not testing the various 'level' methods for now because there's lots diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index c5fd54f67e..0d9908926a 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -26,12 +26,10 @@ from tests.utils import setup_test_homeserver class RoomMemberStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), - http_client=None, + resource_for_federation=Mock(), http_client=None ) # We can't test the RoomMemberStore on its own without the other event # storage logic @@ -49,13 +47,15 @@ class RoomMemberStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }) + builder = self.event_builder_factory.new( + { + "type": EventTypes.Member, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"membership": membership}, + } + ) event, context = yield self.event_creation_handler.create_new_client_event( builder @@ -71,9 +71,12 @@ class RoomMemberStoreTestCase(unittest.TestCase): self.assertEquals( [self.room.to_string()], - [m.room_id for m in ( - yield self.store.get_rooms_for_user_where_membership_is( - self.u_alice.to_string(), [Membership.JOIN] + [ + m.room_id + for m in ( + yield self.store.get_rooms_for_user_where_membership_is( + self.u_alice.to_string(), [Membership.JOIN] + ) ) - )] + ], ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index f7871cd426..ed5b41644a 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -45,20 +45,20 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True + self.room.to_string(), room_creator_user_id="@creator:text", is_public=True ) @defer.inlineCallbacks def inject_state_event(self, room, sender, typ, state_key, content): - builder = self.event_builder_factory.new({ - "type": typ, - "sender": sender.to_string(), - "state_key": state_key, - "room_id": room.to_string(), - "content": content, - }) + builder = self.event_builder_factory.new( + { + "type": typ, + "sender": sender.to_string(), + "state_key": state_key, + "room_id": room.to_string(), + "content": content, + } + ) event, context = yield self.event_creation_handler.create_new_client_event( builder @@ -80,27 +80,31 @@ class StateStoreTestCase(tests.unittest.TestCase): # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {}, + self.room, self.u_alice, EventTypes.Create, '', {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', { - "name": "test room" - }, + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} ) e3 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), { - "membership": Membership.JOIN - }, + self.room, + self.u_alice, + EventTypes.Member, + self.u_alice.to_string(), + {"membership": Membership.JOIN}, ) e4 = yield self.inject_state_event( - self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), { - "membership": Membership.JOIN - }, + self.room, + self.u_bob, + EventTypes.Member, + self.u_bob.to_string(), + {"membership": Membership.JOIN}, ) e5 = yield self.inject_state_event( - self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), { - "membership": Membership.LEAVE - }, + self.room, + self.u_bob, + EventTypes.Member, + self.u_bob.to_string(), + {"membership": Membership.LEAVE}, ) # check we get the full state as of the final event @@ -110,65 +114,66 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertIsNotNone(e4) - self.assertStateMapEqual({ - (e1.type, e1.state_key): e1, - (e2.type, e2.state_key): e2, - (e3.type, e3.state_key): e3, - # e4 is overwritten by e5 - (e5.type, e5.state_key): e5, - }, state) + self.assertStateMapEqual( + { + (e1.type, e1.state_key): e1, + (e2.type, e2.state_key): e2, + (e3.type, e3.state_key): e3, + # e4 is overwritten by e5 + (e5.type, e5.state_key): e5, + }, + state, + ) # check we can filter to the m.room.name event (with a '' state key) state = yield self.store.get_state_for_event( e5.event_id, [(EventTypes.Name, '')], filtered_types=None ) - self.assertStateMapEqual({ - (e2.type, e2.state_key): e2, - }, state) + self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) state = yield self.store.get_state_for_event( e5.event_id, [(EventTypes.Name, None)], filtered_types=None ) - self.assertStateMapEqual({ - (e2.type, e2.state_key): e2, - }, state) + self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) state = yield self.store.get_state_for_event( e5.event_id, [(EventTypes.Member, None)], filtered_types=None ) - self.assertStateMapEqual({ - (e3.type, e3.state_key): e3, - (e5.type, e5.state_key): e5, - }, state) + self.assertStateMapEqual( + {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state + ) # check we can use filter_types to grab a specific room member # without filtering out the other event types state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Member, self.u_alice.to_string())], + e5.event_id, + [(EventTypes.Member, self.u_alice.to_string())], filtered_types=[EventTypes.Member], ) - self.assertStateMapEqual({ - (e1.type, e1.state_key): e1, - (e2.type, e2.state_key): e2, - (e3.type, e3.state_key): e3, - }, state) + self.assertStateMapEqual( + { + (e1.type, e1.state_key): e1, + (e2.type, e2.state_key): e2, + (e3.type, e3.state_key): e3, + }, + state, + ) # check that types=[], filtered_types=[EventTypes.Member] # doesn't return all members state = yield self.store.get_state_for_event( - e5.event_id, [], filtered_types=[EventTypes.Member], + e5.event_id, [], filtered_types=[EventTypes.Member] ) - self.assertStateMapEqual({ - (e1.type, e1.state_key): e1, - (e2.type, e2.state_key): e2, - }, state) + self.assertStateMapEqual( + {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state + ) ####################################################### # _get_some_state_from_cache tests against a full cache @@ -184,10 +189,13 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, True) - self.assertDictEqual({ - (e1.type, e1.state_key): e1.event_id, - (e2.type, e2.state_key): e2.event_id, - }, state_dict) + self.assertDictEqual( + { + (e1.type, e1.state_key): e1.event_id, + (e2.type, e2.state_key): e2.event_id, + }, + state_dict, + ) # test _get_some_state_from_cache correctly filters in members with wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( @@ -195,25 +203,33 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, True) - self.assertDictEqual({ - (e1.type, e1.state_key): e1.event_id, - (e2.type, e2.state_key): e2.event_id, - (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 - (e5.type, e5.state_key): e5.event_id, - }, state_dict) + self.assertDictEqual( + { + (e1.type, e1.state_key): e1.event_id, + (e2.type, e2.state_key): e2.event_id, + (e3.type, e3.state_key): e3.event_id, + # e4 is overwritten by e5 + (e5.type, e5.state_key): e5.event_id, + }, + state_dict, + ) # test _get_some_state_from_cache correctly filters in members with specific types (state_dict, is_all) = yield self.store._get_some_state_from_cache( - group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member] + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=[EventTypes.Member], ) self.assertEqual(is_all, True) - self.assertDictEqual({ - (e1.type, e1.state_key): e1.event_id, - (e2.type, e2.state_key): e2.event_id, - (e5.type, e5.state_key): e5.event_id, - }, state_dict) + self.assertDictEqual( + { + (e1.type, e1.state_key): e1.event_id, + (e2.type, e2.state_key): e2.event_id, + (e5.type, e5.state_key): e5.event_id, + }, + state_dict, + ) # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types @@ -222,24 +238,27 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, True) - self.assertDictEqual({ - (e5.type, e5.state_key): e5.event_id, - }, state_dict) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) ####################################################### # deliberately remove e2 (room name) from the _state_group_cache - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group) + (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( + group + ) self.assertEqual(is_all, True) self.assertEqual(known_absent, set()) - self.assertDictEqual(state_dict_ids, { - (e1.type, e1.state_key): e1.event_id, - (e2.type, e2.state_key): e2.event_id, - (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 - (e5.type, e5.state_key): e5.event_id, - }) + self.assertDictEqual( + state_dict_ids, + { + (e1.type, e1.state_key): e1.event_id, + (e2.type, e2.state_key): e2.event_id, + (e3.type, e3.state_key): e3.event_id, + # e4 is overwritten by e5 + (e5.type, e5.state_key): e5.event_id, + }, + ) state_dict_ids.pop((e2.type, e2.state_key)) self.store._state_group_cache.invalidate(group) @@ -252,22 +271,32 @@ class StateStoreTestCase(tests.unittest.TestCase): (e1.type, e1.state_key), (e3.type, e3.state_key), (e5.type, e5.state_key), - ) + ), ) - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group) + (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( + group + ) self.assertEqual(is_all, False) - self.assertEqual(known_absent, set([ - (e1.type, e1.state_key), - (e3.type, e3.state_key), - (e5.type, e5.state_key), - ])) - self.assertDictEqual(state_dict_ids, { - (e1.type, e1.state_key): e1.event_id, - (e3.type, e3.state_key): e3.event_id, - (e5.type, e5.state_key): e5.event_id, - }) + self.assertEqual( + known_absent, + set( + [ + (e1.type, e1.state_key), + (e3.type, e3.state_key), + (e5.type, e5.state_key), + ] + ), + ) + self.assertDictEqual( + state_dict_ids, + { + (e1.type, e1.state_key): e1.event_id, + (e3.type, e3.state_key): e3.event_id, + (e5.type, e5.state_key): e5.event_id, + }, + ) ############################################ # test that things work with a partial cache @@ -279,9 +308,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, False) - self.assertDictEqual({ - (e1.type, e1.state_key): e1.event_id, - }, state_dict) + self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) # test _get_some_state_from_cache correctly filters in members wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( @@ -289,23 +316,31 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, False) - self.assertDictEqual({ - (e1.type, e1.state_key): e1.event_id, - (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 - (e5.type, e5.state_key): e5.event_id, - }, state_dict) + self.assertDictEqual( + { + (e1.type, e1.state_key): e1.event_id, + (e3.type, e3.state_key): e3.event_id, + # e4 is overwritten by e5 + (e5.type, e5.state_key): e5.event_id, + }, + state_dict, + ) # test _get_some_state_from_cache correctly filters in members with specific types (state_dict, is_all) = yield self.store._get_some_state_from_cache( - group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member] + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=[EventTypes.Member], ) self.assertEqual(is_all, False) - self.assertDictEqual({ - (e1.type, e1.state_key): e1.event_id, - (e5.type, e5.state_key): e5.event_id, - }, state_dict) + self.assertDictEqual( + { + (e1.type, e1.state_key): e1.event_id, + (e5.type, e5.state_key): e5.event_id, + }, + state_dict, + ) # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types @@ -314,6 +349,4 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, True) - self.assertDictEqual({ - (e5.type, e5.state_key): e5.event_id, - }, state_dict) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 23fad12bca..7a273eab48 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -39,20 +39,12 @@ class UserDirectoryStoreTestCase(unittest.TestCase): { ALICE: ProfileInfo(None, "alice"), BOB: ProfileInfo(None, "bob"), - BOBBY: ProfileInfo(None, "bobby") + BOBBY: ProfileInfo(None, "bobby"), }, ) - yield self.store.add_users_to_public_room( - "!room:id", - [ALICE, BOB], - ) + yield self.store.add_users_to_public_room("!room:id", [ALICE, BOB]) yield self.store.add_users_who_share_room( - "!room:id", - False, - ( - (ALICE, BOB), - (BOB, ALICE), - ), + "!room:id", False, ((ALICE, BOB), (BOB, ALICE)) ) @defer.inlineCallbacks @@ -62,11 +54,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase): r = yield self.store.search_user_dir(ALICE, "bob", 10) self.assertFalse(r["limited"]) self.assertEqual(1, len(r["results"])) - self.assertDictEqual(r["results"][0], { - "user_id": BOB, - "display_name": "bob", - "avatar_url": None, - }) + self.assertDictEqual( + r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None} + ) @defer.inlineCallbacks def test_search_user_dir_all_users(self): @@ -75,15 +65,13 @@ class UserDirectoryStoreTestCase(unittest.TestCase): r = yield self.store.search_user_dir(ALICE, "bob", 10) self.assertFalse(r["limited"]) self.assertEqual(2, len(r["results"])) - self.assertDictEqual(r["results"][0], { - "user_id": BOB, - "display_name": "bob", - "avatar_url": None, - }) - self.assertDictEqual(r["results"][1], { - "user_id": BOBBY, - "display_name": "bobby", - "avatar_url": None, - }) + self.assertDictEqual( + r["results"][0], + {"user_id": BOB, "display_name": "bob", "avatar_url": None}, + ) + self.assertDictEqual( + r["results"][1], + {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None}, + ) finally: self.hs.config.user_directory_search_all_users = False diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 71d11cda77..b57f36e6ac 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -22,7 +22,6 @@ from . import unittest class DistributorTestCase(unittest.TestCase): - def setUp(self): self.dist = Distributor() @@ -44,18 +43,14 @@ class DistributorTestCase(unittest.TestCase): observers[0].side_effect = Exception("Awoogah!") - with patch( - "synapse.util.distributor.logger", spec=["warning"] - ) as mock_logger: + with patch("synapse.util.distributor.logger", spec=["warning"]) as mock_logger: self.dist.fire("alarm", "Go") observers[0].assert_called_once_with("Go") observers[1].assert_called_once_with("Go") self.assertEquals(mock_logger.warning.call_count, 1) - self.assertIsInstance( - mock_logger.warning.call_args[0][0], str - ) + self.assertIsInstance(mock_logger.warning.call_args[0][0], str) def test_signal_prereg(self): observer = Mock() @@ -69,4 +64,5 @@ class DistributorTestCase(unittest.TestCase): def test_signal_undeclared(self): def code(): self.dist.fire("notification") + self.assertRaises(KeyError, code) diff --git a/tests/test_dns.py b/tests/test_dns.py index b647d92697..90bd34be34 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -27,7 +27,6 @@ from . import unittest @unittest.DEBUG class DnsTestCase(unittest.TestCase): - @defer.inlineCallbacks def test_resolve(self): dns_client_mock = Mock() @@ -36,14 +35,11 @@ class DnsTestCase(unittest.TestCase): host_name = "example.com" answer_srv = dns.RRHeader( - type=dns.SRV, - payload=dns.Record_SRV( - target=host_name, - ) + type=dns.SRV, payload=dns.Record_SRV(target=host_name) ) dns_client_mock.lookupService.return_value = defer.succeed( - ([answer_srv], None, None), + ([answer_srv], None, None) ) cache = {} @@ -68,9 +64,7 @@ class DnsTestCase(unittest.TestCase): entry = Mock(spec_set=["expires"]) entry.expires = 0 - cache = { - service_name: [entry] - } + cache = {service_name: [entry]} servers = yield resolve_service( service_name, dns_client=dns_client_mock, cache=cache @@ -93,12 +87,10 @@ class DnsTestCase(unittest.TestCase): entry = Mock(spec_set=["expires"]) entry.expires = 999999999 - cache = { - service_name: [entry] - } + cache = {service_name: [entry]} servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache, clock=clock, + service_name, dns_client=dns_client_mock, cache=cache, clock=clock ) self.assertFalse(dns_client_mock.lookupService.called) @@ -117,9 +109,7 @@ class DnsTestCase(unittest.TestCase): cache = {} with self.assertRaises(error.DNSServerError): - yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) + yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache) @defer.inlineCallbacks def test_name_error(self): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 06112430e5..411b4a9f86 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -35,10 +35,7 @@ class EventAuthTestCase(unittest.TestCase): } # creator should be able to send state - event_auth.check( - _random_state_event(creator), auth_events, - do_sig_check=False, - ) + event_auth.check(_random_state_event(creator), auth_events, do_sig_check=False) # joiner should not be able to send state self.assertRaises( @@ -61,13 +58,9 @@ class EventAuthTestCase(unittest.TestCase): auth_events = { ("m.room.create", ""): _create_event(creator), ("m.room.member", creator): _join_event(creator), - ("m.room.power_levels", ""): _power_levels_event(creator, { - "state_default": "30", - "users": { - pleb: "29", - king: "30", - }, - }), + ("m.room.power_levels", ""): _power_levels_event( + creator, {"state_default": "30", "users": {pleb: "29", king: "30"}} + ), ("m.room.member", pleb): _join_event(pleb), ("m.room.member", king): _join_event(king), } @@ -82,10 +75,7 @@ class EventAuthTestCase(unittest.TestCase): ), # king should be able to send state - event_auth.check( - _random_state_event(king), auth_events, - do_sig_check=False, - ) + event_auth.check(_random_state_event(king), auth_events, do_sig_check=False) # helpers for making events @@ -94,52 +84,54 @@ TEST_ROOM_ID = "!test:room" def _create_event(user_id): - return FrozenEvent({ - "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), - "type": "m.room.create", - "sender": user_id, - "content": { - "creator": user_id, - }, - }) + return FrozenEvent( + { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.create", + "sender": user_id, + "content": {"creator": user_id}, + } + ) def _join_event(user_id): - return FrozenEvent({ - "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), - "type": "m.room.member", - "sender": user_id, - "state_key": user_id, - "content": { - "membership": "join", - }, - }) + return FrozenEvent( + { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "content": {"membership": "join"}, + } + ) def _power_levels_event(sender, content): - return FrozenEvent({ - "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), - "type": "m.room.power_levels", - "sender": sender, - "state_key": "", - "content": content, - }) + return FrozenEvent( + { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.power_levels", + "sender": sender, + "state_key": "", + "content": content, + } + ) def _random_state_event(sender): - return FrozenEvent({ - "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), - "type": "test.state", - "sender": sender, - "state_key": "", - "content": { - "membership": "join", - }, - }) + return FrozenEvent( + { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "test.state", + "sender": sender, + "state_key": "", + "content": {"membership": "join"}, + } + ) event_count = 0 @@ -149,4 +141,4 @@ def _get_event_id(): global event_count c = event_count event_count += 1 - return "!%i:example.com" % (c, ) + return "!%i:example.com" % (c,) diff --git a/tests/test_preview.py b/tests/test_preview.py index 446843367e..84ef5e5ba4 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -22,7 +22,6 @@ from . import unittest class PreviewTestCase(unittest.TestCase): - def test_long_summarize(self): example_paras = [ u"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: @@ -32,7 +31,6 @@ class PreviewTestCase(unittest.TestCase): alternative spellings of the city.Tromsø is considered the northernmost city in the world with a population above 50,000. The most populous town north of it is Alta, Norway, with a population of 14,272 (2013).""", - u"""Tromsø lies in Northern Norway. The municipality has a population of (2015) 72,066, but with an annual influx of students it has over 75,000 most of the year. It is the largest urban area in Northern Norway and the @@ -46,7 +44,6 @@ class PreviewTestCase(unittest.TestCase): Sandnessund Bridge. Tromsø Airport connects the city to many destinations in Europe. The city is warmer than most other places located on the same latitude, due to the warming effect of the Gulf Stream.""", - u"""The city centre of Tromsø contains the highest number of old wooden houses in Northern Norway, the oldest house dating from 1789. The Arctic Cathedral, a modern church from 1965, is probably the most famous landmark @@ -67,7 +64,7 @@ class PreviewTestCase(unittest.TestCase): u" the city of Tromsø. Outside of Norway, Tromso and Tromsö are" u" alternative spellings of the city.Tromsø is considered the northernmost" u" city in the world with a population above 50,000. The most populous town" - u" north of it is Alta, Norway, with a population of 14,272 (2013)." + u" north of it is Alta, Norway, with a population of 14,272 (2013).", ) desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) @@ -80,7 +77,7 @@ class PreviewTestCase(unittest.TestCase): u" third largest north of the Arctic Circle (following Murmansk and Norilsk)." u" Most of Tromsø, including the city centre, is located on the island of" u" Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," - u" Tromsøya had a population of 36,088. Substantial parts of the urban…" + u" Tromsøya had a population of 36,088. Substantial parts of the urban…", ) def test_short_summarize(self): @@ -88,11 +85,9 @@ class PreviewTestCase(unittest.TestCase): u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" u" Troms county, Norway.", - u"Tromsø lies in Northern Norway. The municipality has a population of" u" (2015) 72,066, but with an annual influx of students it has over 75,000" u" most of the year.", - u"The city centre of Tromsø contains the highest number of old wooden" u" houses in Northern Norway, the oldest house dating from 1789. The Arctic" u" Cathedral, a modern church from 1965, is probably the most famous landmark" @@ -109,7 +104,7 @@ class PreviewTestCase(unittest.TestCase): u"\n" u"Tromsø lies in Northern Norway. The municipality has a population of" u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year." + u" most of the year.", ) def test_small_then_large_summarize(self): @@ -117,7 +112,6 @@ class PreviewTestCase(unittest.TestCase): u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" u" Troms county, Norway.", - u"Tromsø lies in Northern Norway. The municipality has a population of" u" (2015) 72,066, but with an annual influx of students it has over 75,000" u" most of the year." @@ -138,7 +132,7 @@ class PreviewTestCase(unittest.TestCase): u" (2015) 72,066, but with an annual influx of students it has over 75,000" u" most of the year. The city centre of Tromsø contains the highest number" u" of old wooden houses in Northern Norway, the oldest house dating from" - u" 1789. The Arctic Cathedral, a modern church from…" + u" 1789. The Arctic Cathedral, a modern church from…", ) @@ -155,10 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, { - u"og:title": u"Foo", - u"og:description": u"Some text." - }) + self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."}) def test_comment(self): html = u""" @@ -173,10 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, { - u"og:title": u"Foo", - u"og:description": u"Some text." - }) + self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."}) def test_comment2(self): html = u""" @@ -194,10 +182,13 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, { - u"og:title": u"Foo", - u"og:description": u"Some text.\n\nSome more text.\n\nText\n\nMore text" - }) + self.assertEquals( + og, + { + u"og:title": u"Foo", + u"og:description": u"Some text.\n\nSome more text.\n\nText\n\nMore text", + }, + ) def test_script(self): html = u""" @@ -212,10 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, { - u"og:title": u"Foo", - u"og:description": u"Some text." - }) + self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."}) def test_missing_title(self): html = u""" @@ -228,10 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, { - u"og:title": None, - u"og:description": u"Some text." - }) + self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."}) def test_h1_as_title(self): html = u""" @@ -245,10 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, { - u"og:title": u"Title", - u"og:description": u"Some text." - }) + self.assertEquals(og, {u"og:title": u"Title", u"og:description": u"Some text."}) def test_missing_title_and_broken_h1(self): html = u""" @@ -262,7 +244,4 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, { - u"og:title": None, - u"og:description": u"Some text." - }) + self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."}) diff --git a/tests/test_state.py b/tests/test_state.py index 429a18cbf7..96fdb8636c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -29,8 +29,15 @@ from .utils import MockClock _next_event_id = 1000 -def create_event(name=None, type=None, state_key=None, depth=2, event_id=None, - prev_events=[], **kwargs): +def create_event( + name=None, + type=None, + state_key=None, + depth=2, + event_id=None, + prev_events=[], + **kwargs +): global _next_event_id if not event_id: @@ -39,9 +46,9 @@ def create_event(name=None, type=None, state_key=None, depth=2, event_id=None, if not name: if state_key is not None: - name = "<%s-%s, %s>" % (type, state_key, event_id,) + name = "<%s-%s, %s>" % (type, state_key, event_id) else: - name = "<%s, %s>" % (type, event_id,) + name = "<%s, %s>" % (type, event_id) d = { "event_id": event_id, @@ -80,8 +87,9 @@ class StateGroupStore(object): return defer.succeed(groups) - def store_state_group(self, event_id, room_id, prev_group, delta_ids, - current_state_ids): + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): state_group = self._next_group self._next_group += 1 @@ -91,7 +99,8 @@ class StateGroupStore(object): def get_events(self, event_ids, **kwargs): return { - e_id: self._event_id_to_event[e_id] for e_id in event_ids + e_id: self._event_id_to_event[e_id] + for e_id in event_ids if e_id in self._event_id_to_event } @@ -129,9 +138,7 @@ class Graph(object): prev_events = [] events[event_id] = create_event( - event_id=event_id, - prev_events=prev_events, - **fields + event_id=event_id, prev_events=prev_events, **fields ) self._leaves = clobbered @@ -147,10 +154,15 @@ class Graph(object): class StateTestCase(unittest.TestCase): def setUp(self): self.store = StateGroupStore() - hs = Mock(spec_set=[ - "get_datastore", "get_auth", "get_state_handler", "get_clock", - "get_state_resolution_handler", - ]) + hs = Mock( + spec_set=[ + "get_datastore", + "get_auth", + "get_state_handler", + "get_clock", + "get_state_resolution_handler", + ] + ) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() @@ -164,35 +176,13 @@ class StateTestCase(unittest.TestCase): def test_branch_no_conflict(self): graph = Graph( nodes={ - "START": DictObj( - type=EventTypes.Create, - state_key="", - depth=1, - ), - "A": DictObj( - type=EventTypes.Message, - depth=2, - ), - "B": DictObj( - type=EventTypes.Message, - depth=3, - ), - "C": DictObj( - type=EventTypes.Name, - state_key="", - depth=3, - ), - "D": DictObj( - type=EventTypes.Message, - depth=4, - ), + "START": DictObj(type=EventTypes.Create, state_key="", depth=1), + "A": DictObj(type=EventTypes.Message, depth=2), + "B": DictObj(type=EventTypes.Message, depth=3), + "C": DictObj(type=EventTypes.Name, state_key="", depth=3), + "D": DictObj(type=EventTypes.Message, depth=4), }, - edges={ - "A": ["START"], - "B": ["A"], - "C": ["A"], - "D": ["B", "C"] - } + edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) self.store.register_events(graph.walk()) @@ -224,27 +214,11 @@ class StateTestCase(unittest.TestCase): membership=Membership.JOIN, depth=2, ), - "B": DictObj( - type=EventTypes.Name, - state_key="", - depth=3, - ), - "C": DictObj( - type=EventTypes.Name, - state_key="", - depth=4, - ), - "D": DictObj( - type=EventTypes.Message, - depth=5, - ), + "B": DictObj(type=EventTypes.Name, state_key="", depth=3), + "C": DictObj(type=EventTypes.Name, state_key="", depth=4), + "D": DictObj(type=EventTypes.Message, depth=5), }, - edges={ - "A": ["START"], - "B": ["A"], - "C": ["A"], - "D": ["B", "C"] - } + edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) self.store.register_events(graph.walk()) @@ -259,8 +233,7 @@ class StateTestCase(unittest.TestCase): prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) self.assertSetEqual( - {"START", "A", "C"}, - {e_id for e_id in prev_state_ids.values()} + {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} ) @defer.inlineCallbacks @@ -280,11 +253,7 @@ class StateTestCase(unittest.TestCase): membership=Membership.JOIN, depth=2, ), - "B": DictObj( - type=EventTypes.Name, - state_key="", - depth=3, - ), + "B": DictObj(type=EventTypes.Name, state_key="", depth=3), "C": DictObj( type=EventTypes.Member, state_key="@user_id_2:example.com", @@ -298,18 +267,9 @@ class StateTestCase(unittest.TestCase): depth=4, sender="@user_id_2:example.com", ), - "E": DictObj( - type=EventTypes.Message, - depth=5, - ), + "E": DictObj(type=EventTypes.Message, depth=5), }, - edges={ - "A": ["START"], - "B": ["A"], - "C": ["B"], - "D": ["B"], - "E": ["C", "D"] - } + edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]}, ) self.store.register_events(graph.walk()) @@ -324,8 +284,7 @@ class StateTestCase(unittest.TestCase): prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store) self.assertSetEqual( - {"START", "A", "B", "C"}, - {e for e in prev_state_ids.values()} + {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} ) @defer.inlineCallbacks @@ -357,30 +316,17 @@ class StateTestCase(unittest.TestCase): state_key="", content={ "events": {"m.room.name": 50}, - "users": {userid1: 100, - userid2: 60}, + "users": {userid1: 100, userid2: 60}, }, ), - "A5": DictObj( - type=EventTypes.Name, - state_key="", - ), + "A5": DictObj(type=EventTypes.Name, state_key=""), "B": DictObj( type=EventTypes.PowerLevels, state_key="", - content={ - "events": {"m.room.name": 50}, - "users": {userid2: 30}, - }, - ), - "C": DictObj( - type=EventTypes.Name, - state_key="", - sender=userid2, - ), - "D": DictObj( - type=EventTypes.Message, + content={"events": {"m.room.name": 50}, "users": {userid2: 30}}, ), + "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2), + "D": DictObj(type=EventTypes.Message), } edges = { "A2": ["A1"], @@ -389,7 +335,7 @@ class StateTestCase(unittest.TestCase): "A5": ["A4"], "B": ["A5"], "C": ["A5"], - "D": ["B", "C"] + "D": ["B", "C"], } self._add_depths(nodes, edges) graph = Graph(nodes, edges) @@ -406,8 +352,7 @@ class StateTestCase(unittest.TestCase): prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) self.assertSetEqual( - {"A1", "A2", "A3", "A5", "B"}, - {e for e in prev_state_ids.values()} + {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -432,9 +377,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context( - event, old_state=old_state - ) + context = yield self.state.compute_event_context(event, old_state=old_state) current_state_ids = yield context.get_current_state_ids(self.store) @@ -454,9 +397,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context( - event, old_state=old_state - ) + context = yield self.state.compute_event_context(event, old_state=old_state) prev_state_ids = yield context.get_prev_state_ids(self.store) @@ -468,8 +409,7 @@ class StateTestCase(unittest.TestCase): def test_trivial_annotate_message(self): prev_event_id = "prev_event_id" event = create_event( - type="test_message", name="event2", - prev_events=[(prev_event_id, {})], + type="test_message", name="event2", prev_events=[(prev_event_id, {})] ) old_state = [ @@ -479,7 +419,10 @@ class StateTestCase(unittest.TestCase): ] group_name = self.store.store_state_group( - prev_event_id, event.room_id, None, None, + prev_event_id, + event.room_id, + None, + None, {(e.type, e.state_key): e.event_id for e in old_state}, ) self.store.register_event_id_state_group(prev_event_id, group_name) @@ -489,8 +432,7 @@ class StateTestCase(unittest.TestCase): current_state_ids = yield context.get_current_state_ids(self.store) self.assertEqual( - set([e.event_id for e in old_state]), - set(current_state_ids.values()) + set([e.event_id for e in old_state]), set(current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -499,8 +441,7 @@ class StateTestCase(unittest.TestCase): def test_trivial_annotate_state(self): prev_event_id = "prev_event_id" event = create_event( - type="state", state_key="", name="event2", - prev_events=[(prev_event_id, {})], + type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})] ) old_state = [ @@ -510,7 +451,10 @@ class StateTestCase(unittest.TestCase): ] group_name = self.store.store_state_group( - prev_event_id, event.room_id, None, None, + prev_event_id, + event.room_id, + None, + None, {(e.type, e.state_key): e.event_id for e in old_state}, ) self.store.register_event_id_state_group(prev_event_id, group_name) @@ -520,8 +464,7 @@ class StateTestCase(unittest.TestCase): prev_state_ids = yield context.get_prev_state_ids(self.store) self.assertEqual( - set([e.event_id for e in old_state]), - set(prev_state_ids.values()) + set([e.event_id for e in old_state]), set(prev_state_ids.values()) ) self.assertIsNotNone(context.state_group) @@ -531,13 +474,12 @@ class StateTestCase(unittest.TestCase): prev_event_id1 = "event_id1" prev_event_id2 = "event_id2" event = create_event( - type="test_message", name="event3", + type="test_message", + name="event3", prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], ) - creation = create_event( - type=EventTypes.Create, state_key="" - ) + creation = create_event(type=EventTypes.Create, state_key="") old_state_1 = [ creation, @@ -557,7 +499,7 @@ class StateTestCase(unittest.TestCase): self.store.register_events(old_state_2) context = yield self._get_context( - event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) current_state_ids = yield context.get_current_state_ids(self.store) @@ -571,13 +513,13 @@ class StateTestCase(unittest.TestCase): prev_event_id1 = "event_id1" prev_event_id2 = "event_id2" event = create_event( - type="test4", state_key="", name="event", + type="test4", + state_key="", + name="event", prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], ) - creation = create_event( - type=EventTypes.Create, state_key="" - ) + creation = create_event(type=EventTypes.Create, state_key="") old_state_1 = [ creation, @@ -599,7 +541,7 @@ class StateTestCase(unittest.TestCase): self.store.get_events = store.get_events context = yield self._get_context( - event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) current_state_ids = yield context.get_current_state_ids(self.store) @@ -613,29 +555,25 @@ class StateTestCase(unittest.TestCase): prev_event_id1 = "event_id1" prev_event_id2 = "event_id2" event = create_event( - type="test4", name="event", + type="test4", + name="event", prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], ) member_event = create_event( type=EventTypes.Member, state_key="@user_id:example.com", - content={ - "membership": Membership.JOIN, - } + content={"membership": Membership.JOIN}, ) power_levels = create_event( - type=EventTypes.PowerLevels, state_key="", - content={"users": { - "@foo:bar": "100", - "@user_id:example.com": "100", - }} + type=EventTypes.PowerLevels, + state_key="", + content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}}, ) creation = create_event( - type=EventTypes.Create, state_key="", - content={"creator": "@foo:bar"} + type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"} ) old_state_1 = [ @@ -658,14 +596,12 @@ class StateTestCase(unittest.TestCase): self.store.get_events = store.get_events context = yield self._get_context( - event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) current_state_ids = yield context.get_current_state_ids(self.store) - self.assertEqual( - old_state_2[3].event_id, current_state_ids[("test1", "1")] - ) + self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) # Reverse the depth to make sure we are actually using the depths # during state resolution. @@ -688,25 +624,30 @@ class StateTestCase(unittest.TestCase): store.register_events(old_state_2) context = yield self._get_context( - event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) current_state_ids = yield context.get_current_state_ids(self.store) - self.assertEqual( - old_state_1[3].event_id, current_state_ids[("test1", "1")] - ) + self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) - def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, - old_state_2): + def _get_context( + self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 + ): sg1 = self.store.store_state_group( - prev_event_id_1, event.room_id, None, None, + prev_event_id_1, + event.room_id, + None, + None, {(e.type, e.state_key): e.event_id for e in old_state_1}, ) self.store.register_event_id_state_group(prev_event_id_1, sg1) sg2 = self.store.store_state_group( - prev_event_id_2, event.room_id, None, None, + prev_event_id_2, + event.room_id, + None, + None, {(e.type, e.state_key): e.event_id for e in old_state_2}, ) self.store.register_event_id_state_group(prev_event_id_2, sg2) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index bc97c12245..b921ac52c0 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -18,7 +18,6 @@ from tests.utils import MockClock class MockClockTestCase(unittest.TestCase): - def setUp(self): self.clock = MockClock() @@ -34,10 +33,12 @@ class MockClockTestCase(unittest.TestCase): def _cb0(): invoked[0] = 1 + self.clock.call_later(10, _cb0) def _cb1(): invoked[1] = 1 + self.clock.call_later(20, _cb1) self.assertFalse(invoked[0]) @@ -56,10 +57,12 @@ class MockClockTestCase(unittest.TestCase): def _cb0(): invoked[0] = 1 + t0 = self.clock.call_later(10, _cb0) def _cb1(): invoked[1] = 1 + self.clock.call_later(20, _cb1) self.clock.cancel_call_later(t0) diff --git a/tests/test_types.py b/tests/test_types.py index 729bd676c1..be072d402b 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -69,10 +69,7 @@ class GroupIDTestCase(unittest.TestCase): self.assertEqual("my.domain", group_id.domain) def test_validate(self): - bad_ids = [ - "$badsigil:domain", - "+:empty", - ] + [ + bad_ids = ["$badsigil:domain", "+:empty"] + [ "+group" + c + ":domain" for c in "A%?æ£" ] for id_string in bad_ids: diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 0dc1a924d3..8643d63125 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -54,14 +54,12 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter = [] for i in range(0, 10): - user = "@user%i:%s" % ( - i, "test_server" if i == 5 else "other_server" - ) + user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") evt = yield self.inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter, + self.store, "test_server", events_to_filter ) # the result should be 5 redacted events, and 5 unredacted events. @@ -100,19 +98,21 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # ... and the filtering happens. filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter, + self.store, "test_server", events_to_filter ) for i in range(0, len(events_to_filter)): self.assertEqual( - events_to_filter[i].event_id, filtered[i].event_id, - "Unexpected event at result position %i" % (i, ) + events_to_filter[i].event_id, + filtered[i].event_id, + "Unexpected event at result position %i" % (i,), ) for i in (0, 3): self.assertEqual( - events_to_filter[i].content["body"], filtered[i].content["body"], - "Unexpected event content at result position %i" % (i,) + events_to_filter[i].content["body"], + filtered[i].content["body"], + "Unexpected event content at result position %i" % (i,), ) for i in (1, 4): @@ -121,13 +121,15 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def inject_visibility(self, user_id, visibility): content = {"history_visibility": visibility} - builder = self.event_builder_factory.new({ - "type": "m.room.history_visibility", - "sender": user_id, - "state_key": "", - "room_id": TEST_ROOM_ID, - "content": content, - }) + builder = self.event_builder_factory.new( + { + "type": "m.room.history_visibility", + "sender": user_id, + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": content, + } + ) event, context = yield self.event_creation_handler.create_new_client_event( builder @@ -139,13 +141,15 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): def inject_room_member(self, user_id, membership="join", extra_content={}): content = {"membership": membership} content.update(extra_content) - builder = self.event_builder_factory.new({ - "type": "m.room.member", - "sender": user_id, - "state_key": user_id, - "room_id": TEST_ROOM_ID, - "content": content, - }) + builder = self.event_builder_factory.new( + { + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "room_id": TEST_ROOM_ID, + "content": content, + } + ) event, context = yield self.event_creation_handler.create_new_client_event( builder @@ -158,12 +162,14 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): def inject_message(self, user_id, content=None): if content is None: content = {"body": "testytest"} - builder = self.event_builder_factory.new({ - "type": "m.room.message", - "sender": user_id, - "room_id": TEST_ROOM_ID, - "content": content, - }) + builder = self.event_builder_factory.new( + { + "type": "m.room.message", + "sender": user_id, + "room_id": TEST_ROOM_ID, + "content": content, + } + ) event, context = yield self.event_creation_handler.create_new_client_event( builder @@ -192,56 +198,54 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # history_visibility event. room_state = [] - history_visibility_evt = FrozenEvent({ - "event_id": "$history_vis", - "type": "m.room.history_visibility", - "sender": "@resident_user_0:test.com", - "state_key": "", - "room_id": TEST_ROOM_ID, - "content": {"history_visibility": "joined"}, - }) + history_visibility_evt = FrozenEvent( + { + "event_id": "$history_vis", + "type": "m.room.history_visibility", + "sender": "@resident_user_0:test.com", + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": {"history_visibility": "joined"}, + } + ) room_state.append(history_visibility_evt) test_store.add_event(history_visibility_evt) for i in range(0, 100000): - user = "@resident_user_%i:test.com" % (i, ) - evt = FrozenEvent({ - "event_id": "$res_event_%i" % (i, ), - "type": "m.room.member", - "state_key": user, - "sender": user, - "room_id": TEST_ROOM_ID, - "content": { - "membership": "join", - "extra": "zzz," - }, - }) + user = "@resident_user_%i:test.com" % (i,) + evt = FrozenEvent( + { + "event_id": "$res_event_%i" % (i,), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": {"membership": "join", "extra": "zzz,"}, + } + ) room_state.append(evt) test_store.add_event(evt) events_to_filter = [] for i in range(0, 10): - user = "@user%i:%s" % ( - i, "test_server" if i == 5 else "other_server" + user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") + evt = FrozenEvent( + { + "event_id": "$evt%i" % (i,), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": {"membership": "join", "extra": "zzz"}, + } ) - evt = FrozenEvent({ - "event_id": "$evt%i" % (i, ), - "type": "m.room.member", - "state_key": user, - "sender": user, - "room_id": TEST_ROOM_ID, - "content": { - "membership": "join", - "extra": "zzz", - }, - }) events_to_filter.append(evt) room_state.append(evt) test_store.add_event(evt) - test_store.set_state_ids_for_event(evt, { - (e.type, e.state_key): e.event_id for e in room_state - }) + test_store.set_state_ids_for_event( + evt, {(e.type, e.state_key): e.event_id for e in room_state} + ) pr = cProfile.Profile() pr.enable() @@ -249,7 +253,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): logger.info("Starting filtering") start = time.time() filtered = yield filter_events_for_server( - test_store, "test_server", events_to_filter, + test_store, "test_server", events_to_filter ) logger.info("Filtering took %f seconds", time.time() - start) @@ -275,6 +279,7 @@ class _TestStore(object): filter_events_for_server """ + def __init__(self): # data for get_events: a map from event_id to event self.events = {} @@ -298,8 +303,8 @@ class _TestStore(object): continue if type != "m.room.member" or state_key is not None: raise RuntimeError( - "Unimplemented: get_state_ids with type (%s, %s)" % - (type, state_key), + "Unimplemented: get_state_ids with type (%s, %s)" + % (type, state_key) ) include_memberships = True @@ -316,9 +321,7 @@ class _TestStore(object): return succeed(res) def get_events(self, events): - return succeed({ - event_id: self.events[event_id] for event_id in events - }) + return succeed({event_id: self.events[event_id] for event_id in events}) def are_users_erased(self, users): return succeed({u: False for u in users}) diff --git a/tests/unittest.py b/tests/unittest.py index b15b06726b..f448a6dfbd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -56,6 +56,7 @@ def around(target): def method_name(orig, *args, **kwargs): return orig(*args, **kwargs) """ + def _around(code): name = code.__name__ orig = getattr(target, name) @@ -89,6 +90,7 @@ class TestCase(unittest.TestCase): old_level = logging.getLogger().level if old_level != level: + @around(self) def tearDown(orig): ret = orig() @@ -117,8 +119,9 @@ class TestCase(unittest.TestCase): actual (dict): The test result. Extra keys will not be checked. """ for key in required: - self.assertEquals(required[key], actual[key], - msg="%s mismatch. %s" % (key, actual)) + self.assertEquals( + required[key], actual[key], msg="%s mismatch. %s" % (key, actual) + ) def DEBUG(target): diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index ca8a7c907f..463a737efa 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -67,12 +67,8 @@ class CacheTestCase(unittest.TestCase): self.assertIsNone(cache.get("key2", None)) # both callbacks should have been callbacked - self.assertTrue( - callback_record[0], "Invalidation callback for key1 not called", - ) - self.assertTrue( - callback_record[1], "Invalidation callback for key2 not called", - ) + self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") + self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") # letting the other lookup complete should do nothing d1.callback("result1") @@ -168,8 +164,7 @@ class DescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.name = "c1" r = yield obj.fn(1) - self.assertEqual(logcontext.LoggingContext.current_context(), - c1) + self.assertEqual(logcontext.LoggingContext.current_context(), c1) defer.returnValue(r) def check_result(r): @@ -179,14 +174,18 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) d1.addCallback(check_result) # and another d2 = do_lookup() - self.assertEqual(logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) d2.addCallback(check_result) # let the lookup complete @@ -224,15 +223,16 @@ class DescriptorTestCase(unittest.TestCase): except SynapseError: pass - self.assertEqual(logcontext.LoggingContext.current_context(), - c1) + self.assertEqual(logcontext.LoggingContext.current_context(), c1) obj = Cls() # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(logcontext.LoggingContext.current_context(), - logcontext.LoggingContext.sentinel) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) return d1 @@ -288,14 +288,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList("fn", "args1", inlineCallbacks=True) def list_fn(self, args1, arg2): - assert ( - logcontext.LoggingContext.current_context().request == "c1" - ) + assert logcontext.LoggingContext.current_context().request == "c1" # we want this to behave like an asynchronous function yield run_on_reactor() - assert ( - logcontext.LoggingContext.current_context().request == "c1" - ) + assert logcontext.LoggingContext.current_context().request == "c1" defer.returnValue(self.mock(args1, arg2)) with logcontext.LoggingContext() as c1: @@ -308,10 +304,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): logcontext.LoggingContext.sentinel, ) r = yield d1 - self.assertEqual( - logcontext.LoggingContext.current_context(), - c1 - ) + self.assertEqual(logcontext.LoggingContext.current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) self.assertEqual(r, {10: 'fish', 20: 'chips'}) obj.mock.reset_mock() @@ -337,6 +330,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_invalidate(self): """Make sure that invalidation callbacks are called.""" + class Cls(object): def __init__(self): self.mock = mock.Mock() diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 26f2fa5800..34fdc9a43a 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -20,7 +20,6 @@ from tests import unittest class DictCacheTestCase(unittest.TestCase): - def setUp(self): self.cache = DictionaryCache("foobar") @@ -41,9 +40,7 @@ class DictCacheTestCase(unittest.TestCase): key = "test_simple_cache_hit_partial" seq = self.cache.sequence - test_value = { - "test": "test_simple_cache_hit_partial" - } + test_value = {"test": "test_simple_cache_hit_partial"} self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test"]) @@ -53,9 +50,7 @@ class DictCacheTestCase(unittest.TestCase): key = "test_simple_cache_miss_partial" seq = self.cache.sequence - test_value = { - "test": "test_simple_cache_miss_partial" - } + test_value = {"test": "test_simple_cache_miss_partial"} self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test2"]) @@ -79,15 +74,11 @@ class DictCacheTestCase(unittest.TestCase): key = "test_simple_cache_hit_miss_partial" seq = self.cache.sequence - test_value_1 = { - "test": "test_simple_cache_hit_miss_partial", - } + test_value_1 = {"test": "test_simple_cache_hit_miss_partial"} self.cache.update(seq, key, test_value_1, fetched_keys=set("test")) seq = self.cache.sequence - test_value_2 = { - "test2": "test_simple_cache_hit_miss_partial2", - } + test_value_2 = {"test2": "test_simple_cache_hit_miss_partial2"} self.cache.update(seq, key, test_value_2, fetched_keys=set("test2")) c = self.cache.get(key) @@ -96,5 +87,5 @@ class DictCacheTestCase(unittest.TestCase): "test": "test_simple_cache_hit_miss_partial", "test2": "test_simple_cache_hit_miss_partial2", }, - c.value + c.value, ) diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index d12b5e838b..5cbada4eda 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -22,7 +22,6 @@ from .. import unittest class ExpiringCacheTestCase(unittest.TestCase): - def test_get_set(self): clock = MockClock() cache = ExpiringCache("test", clock, max_len=1) diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index 7ce5f8c258..e90e08d1c0 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -27,7 +27,6 @@ from tests import unittest class FileConsumerTests(unittest.TestCase): - @defer.inlineCallbacks def test_pull_consumer(self): string_file = StringIO() @@ -87,7 +86,9 @@ class FileConsumerTests(unittest.TestCase): producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) resume_deferred = defer.Deferred() - producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None) + producer.resumeProducing.side_effect = lambda: resume_deferred.callback( + None + ) consumer.registerProducer(producer, True) diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 7f48a72de8..61a55b461b 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -26,7 +26,6 @@ from tests import unittest class LinearizerTestCase(unittest.TestCase): - @defer.inlineCallbacks def test_linearizer(self): linearizer = Linearizer() @@ -54,13 +53,11 @@ class LinearizerTestCase(unittest.TestCase): def func(i, sleep=False): with logcontext.LoggingContext("func(%s)" % i) as lc: with (yield linearizer.queue("")): - self.assertEqual( - logcontext.LoggingContext.current_context(), lc) + self.assertEqual(logcontext.LoggingContext.current_context(), lc) if sleep: yield Clock(reactor).sleep(0) - self.assertEqual( - logcontext.LoggingContext.current_context(), lc) + self.assertEqual(logcontext.LoggingContext.current_context(), lc) func(0, sleep=True) for i in range(1, 100): diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index c54001f7a4..4633db77b3 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -8,11 +8,8 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): - def _check_test_key(self, value): - self.assertEquals( - LoggingContext.current_context().request, value - ) + self.assertEquals(LoggingContext.current_context().request, value) def test_with_context(self): with LoggingContext() as context_one: @@ -50,6 +47,7 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") callback_completed[0] = True return res + d.addCallback(cb) return d @@ -74,8 +72,7 @@ class LoggingContextTestCase(unittest.TestCase): # make sure that the context was reset before it got thrown back # into the reactor try: - self.assertIs(LoggingContext.current_context(), - sentinel_context) + self.assertIs(LoggingContext.current_context(), sentinel_context) d2.callback(None) except BaseException: d2.errback(twisted.python.failure.Failure()) @@ -104,9 +101,7 @@ class LoggingContextTestCase(unittest.TestCase): # a function which returns a deferred which looks like it has been # called, but is actually paused def testfunc(): - return logcontext.make_deferred_yieldable( - _chained_deferred_function() - ) + return logcontext.make_deferred_yieldable(_chained_deferred_function()) return self._test_run_in_background(testfunc) @@ -175,5 +170,6 @@ def _chained_deferred_function(): d2 = defer.Deferred() reactor.callLater(0, d2.callback, res) return d2 + d.addCallback(cb) return d diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 9b36ef4482..786947375d 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -23,7 +23,6 @@ from .. import unittest class LruCacheTestCase(unittest.TestCase): - def test_get_set(self): cache = LruCache(1) cache["key"] = "value" @@ -235,7 +234,6 @@ class LruCacheCallbacksTestCase(unittest.TestCase): class LruCacheSizedTestCase(unittest.TestCase): - def test_evict(self): cache = LruCache(5, size_callback=len) cache["key1"] = [0] diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 7cd470be67..bd32e2cee7 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -20,7 +20,6 @@ from tests import unittest class ReadWriteLockTestCase(unittest.TestCase): - def _assert_called_before_not_after(self, lst, first_false): for i, d in enumerate(lst[:first_false]): self.assertTrue(d.called, msg="%d was unexpectedly false" % i) @@ -36,12 +35,12 @@ class ReadWriteLockTestCase(unittest.TestCase): key = object() ds = [ - rwlock.read(key), # 0 - rwlock.read(key), # 1 + rwlock.read(key), # 0 + rwlock.read(key), # 1 rwlock.write(key), # 2 rwlock.write(key), # 3 - rwlock.read(key), # 4 - rwlock.read(key), # 5 + rwlock.read(key), # 4 + rwlock.read(key), # 5 rwlock.write(key), # 6 ] diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py index 0f5b32fcc0..1a44f72425 100644 --- a/tests/util/test_snapshot_cache.py +++ b/tests/util/test_snapshot_cache.py @@ -22,7 +22,6 @@ from .. import unittest class SnapshotCacheTestCase(unittest.TestCase): - def setUp(self): self.cache = SnapshotCache() self.cache.DURATION_MS = 1 diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 65b0f2e6fb..f2be63706b 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -181,17 +181,8 @@ class StreamChangeCacheTests(unittest.TestCase): # Query a subset of the entries mid-way through the stream. We should # only get back the subset. self.assertEqual( - cache.get_entities_changed( - [ - "bar@baz.net", - ], - stream_pos=2, - ), - set( - [ - "bar@baz.net", - ] - ), + cache.get_entities_changed(["bar@baz.net"], stream_pos=2), + set(["bar@baz.net"]), ) def test_max_pos(self): diff --git a/tests/utils.py b/tests/utils.py index 8d73797971..8668b5478f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -38,8 +38,9 @@ USE_POSTGRES_FOR_TESTS = False @defer.inlineCallbacks -def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None, - **kargs): +def setup_test_homeserver( + name="test", datastore=None, config=None, reactor=None, **kargs +): """Setup a homeserver suitable for running tests against. Keyword arguments are passed to the Homeserver constructor. If no datastore is supplied a datastore backed by an in-memory sqlite db will be given to the HS. @@ -96,20 +97,12 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None if USE_POSTGRES_FOR_TESTS: config.database_config = { "name": "psycopg2", - "args": { - "database": "synapse_test", - "cp_min": 1, - "cp_max": 5, - }, + "args": {"database": "synapse_test", "cp_min": 1, "cp_max": 5}, } else: config.database_config = { "name": "sqlite3", - "args": { - "database": ":memory:", - "cp_min": 1, - "cp_max": 1, - }, + "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } db_engine = create_engine(config.database_config) @@ -121,7 +114,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None if datastore is None: hs = HomeServer( - name, config=config, + name, + config=config, db_config=config.database_config, version_string="Synapse/tests", database_engine=db_engine, @@ -143,7 +137,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None hs.setup() else: hs = HomeServer( - name, db_pool=None, datastore=datastore, config=config, + name, + db_pool=None, + datastore=datastore, + config=config, version_string="Synapse/tests", database_engine=db_engine, room_list_handler=object(), @@ -158,8 +155,9 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest() - hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5( - p.encode('utf8')).hexdigest() == h + hs.get_auth_handler().validate_hash = ( + lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h + ) fed = kargs.get("resource_for_federation", None) if fed: @@ -173,7 +171,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None sleep_limit=hs.config.federation_rc_sleep_limit, sleep_msec=hs.config.federation_rc_sleep_delay, reject_limit=hs.config.federation_rc_reject_limit, - concurrent_requests=hs.config.federation_rc_concurrent + concurrent_requests=hs.config.federation_rc_concurrent, ), ) @@ -199,7 +197,6 @@ def mock_getRawHeaders(headers=None): # This is a mock /resource/ not an entire server class MockHttpResource(HttpServer): - def __init__(self, prefix=""): self.callbacks = [] # 3-tuple of method/pattern/function self.prefix = prefix @@ -263,15 +260,9 @@ class MockHttpResource(HttpServer): matcher = pattern.match(path) if matcher: try: - args = [ - urlparse.unquote(u) - for u in matcher.groups() - ] - - (code, response) = yield func( - mock_request, - *args - ) + args = [urlparse.unquote(u) for u in matcher.groups()] + + (code, response) = yield func(mock_request, *args) defer.returnValue((code, response)) except CodeMessageException as e: defer.returnValue((e.code, cs_error(e.msg, code=e.errcode))) @@ -372,8 +363,7 @@ class MockClock(object): def _format_call(args, kwargs): return ", ".join( - ["%r" % (a) for a in args] + - ["%s=%r" % (k, v) for k, v in kwargs.items()] + ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()] ) @@ -391,8 +381,9 @@ class DeferredMockCallable(object): self.calls.append((args, kwargs)) if not self.expectations: - raise ValueError("%r has no pending calls to handle call(%s)" % ( - self, _format_call(args, kwargs)) + raise ValueError( + "%r has no pending calls to handle call(%s)" + % (self, _format_call(args, kwargs)) ) for (call, result, d) in self.expectations: @@ -400,9 +391,9 @@ class DeferredMockCallable(object): d.callback(None) return result - failure = AssertionError("Was not expecting call(%s)" % ( - _format_call(args, kwargs) - )) + failure = AssertionError( + "Was not expecting call(%s)" % (_format_call(args, kwargs)) + ) for _, _, d in self.expectations: try: @@ -418,17 +409,19 @@ class DeferredMockCallable(object): @defer.inlineCallbacks def await_calls(self, timeout=1000): deferred = defer.DeferredList( - [d for _, _, d in self.expectations], - fireOnOneErrback=True + [d for _, _, d in self.expectations], fireOnOneErrback=True ) timer = reactor.callLater( timeout / 1000, deferred.errback, - AssertionError("%d pending calls left: %s" % ( - len([e for e in self.expectations if not e[2].called]), - [e for e in self.expectations if not e[2].called] - )) + AssertionError( + "%d pending calls left: %s" + % ( + len([e for e in self.expectations if not e[2].called]), + [e for e in self.expectations if not e[2].called], + ) + ), ) yield deferred @@ -443,7 +436,6 @@ class DeferredMockCallable(object): self.calls = [] raise AssertionError( - "Expected not to received any calls, got:\n" + "\n".join([ - "call(%s)" % _format_call(c[0], c[1]) for c in calls - ]) + "Expected not to received any calls, got:\n" + + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls]) ) -- cgit 1.4.1 From 99dd975dae7baaaef2a3b0a92fa51965b121ae34 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Mon, 13 Aug 2018 16:47:46 +1000 Subject: Run tests under PostgreSQL (#3423) --- .travis.yml | 10 +++ changelog.d/3423.misc | 1 + synapse/handlers/presence.py | 5 ++ synapse/storage/client_ips.py | 5 ++ tests/__init__.py | 3 + tests/api/test_auth.py | 2 +- tests/api/test_filtering.py | 5 +- tests/crypto/test_keyring.py | 6 +- tests/handlers/test_auth.py | 2 +- tests/handlers/test_device.py | 2 +- tests/handlers/test_directory.py | 1 + tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_profile.py | 1 + tests/handlers/test_register.py | 1 + tests/handlers/test_typing.py | 1 + tests/replication/slave/storage/_base.py | 1 + tests/rest/client/v1/test_admin.py | 2 +- tests/rest/client/v1/test_events.py | 1 + tests/rest/client/v1/test_profile.py | 1 + tests/rest/client/v1/test_register.py | 2 +- tests/rest/client/v1/test_rooms.py | 1 + tests/rest/client/v1/test_typing.py | 1 + tests/rest/client/v2_alpha/test_filter.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 2 +- tests/rest/client/v2_alpha/test_sync.py | 2 +- tests/server.py | 10 ++- tests/storage/test_appservice.py | 13 ++- tests/storage/test_background_update.py | 4 +- tests/storage/test_client_ips.py | 2 +- tests/storage/test_devices.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_end_to_end_keys.py | 3 +- tests/storage/test_event_federation.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_keys.py | 2 +- tests/storage/test_monthly_active_users.py | 2 +- tests/storage/test_presence.py | 2 +- tests/storage/test_profile.py | 2 +- tests/storage/test_redaction.py | 2 +- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 4 +- tests/storage/test_roommember.py | 2 +- tests/storage/test_state.py | 2 +- tests/storage/test_user_directory.py | 2 +- tests/test_federation.py | 5 +- tests/test_server.py | 2 +- tests/test_visibility.py | 2 +- tests/utils.py | 133 ++++++++++++++++++++++++---- tox.ini | 20 ++++- 49 files changed, 227 insertions(+), 59 deletions(-) create mode 100644 changelog.d/3423.misc (limited to 'tests/handlers') diff --git a/.travis.yml b/.travis.yml index b34b17af75..318701c9f8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,6 +8,9 @@ before_script: - git remote set-branches --add origin develop - git fetch origin develop +services: + - postgresql + matrix: fast_finish: true include: @@ -20,6 +23,9 @@ matrix: - python: 2.7 env: TOX_ENV=py27 + - python: 2.7 + env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" + - python: 3.6 env: TOX_ENV=py36 @@ -29,6 +35,10 @@ matrix: - python: 3.6 env: TOX_ENV=check-newsfragment + allow_failures: + - python: 2.7 + env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" + install: - pip install tox diff --git a/changelog.d/3423.misc b/changelog.d/3423.misc new file mode 100644 index 0000000000..51768c6d14 --- /dev/null +++ b/changelog.d/3423.misc @@ -0,0 +1 @@ +The test suite now can run under PostgreSQL. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 20fc3b0323..3671d24f60 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -95,6 +95,7 @@ class PresenceHandler(object): Args: hs (synapse.server.HomeServer): """ + self.hs = hs self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.clock = hs.get_clock() @@ -230,6 +231,10 @@ class PresenceHandler(object): earlier than they should when synapse is restarted. This affect of this is some spurious presence changes that will self-correct. """ + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + logger.info( "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 2489527f2c..8fc678fa67 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -96,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) def _update_client_ips_batch(self): + + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + def update(): to_update = self._batch_row_update self._batch_row_update = {} diff --git a/tests/__init__.py b/tests/__init__.py index 24006c949e..9d9ca22829 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -15,4 +15,7 @@ from twisted.trial import util +from tests import utils + util.DEFAULT_TIMEOUT_DURATION = 10 +utils.setupdb() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index f8e28876bb..a65689ba89 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -39,7 +39,7 @@ class AuthTestCase(unittest.TestCase): self.state_handler = Mock() self.store = Mock() - self.hs = yield setup_test_homeserver(handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) self.hs.get_datastore = Mock(return_value=self.store) self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 1c2d71052c..48b2d3d663 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -46,7 +46,10 @@ class FilteringTestCase(unittest.TestCase): self.mock_http_client.put_json = DeferredMockCallable() hs = yield setup_test_homeserver( - handlers=None, http_client=self.mock_http_client, keyring=Mock() + self.addCleanup, + handlers=None, + http_client=self.mock_http_client, + keyring=Mock(), ) self.filtering = hs.get_filtering() diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 0c6f510d11..8299dc72c8 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -58,12 +58,10 @@ class KeyringTestCase(unittest.TestCase): self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() self.hs = yield utils.setup_test_homeserver( - handlers=None, http_client=self.http_client + self.addCleanup, handlers=None, http_client=self.http_client ) keys = self.mock_perspective_server.get_verify_keys() - self.hs.config.perspectives = { - self.mock_perspective_server.server_name: keys - } + self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys} def check_context(self, _, expected): self.assertEquals( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index ede01f8099..56c0f87fb7 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -35,7 +35,7 @@ class AuthHandlers(object): class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index d70d645504..56e7acd37c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -34,7 +34,7 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield utils.setup_test_homeserver() + hs = yield utils.setup_test_homeserver(self.addCleanup) self.handler = hs.get_device_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 06de9f5eca..ec7355688b 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -46,6 +46,7 @@ class DirectoryTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 57ab228455..8dccc6826e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - handlers=None, federation_client=mock.Mock() + self.addCleanup, handlers=None, federation_client=mock.Mock() ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 9268a6fe2b..62dc69003c 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -48,6 +48,7 @@ class ProfileTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, handlers=None, resource_for_federation=Mock(), diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index dbec81076f..d48d40c8dd 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -40,6 +40,7 @@ class RegistrationTestCase(unittest.TestCase): self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() self.hs = yield setup_test_homeserver( + self.addCleanup, handlers=None, http_client=None, expire_access_token=True, diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index becfa77bfa..ad58073a14 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -67,6 +67,7 @@ class TypingNotificationsTestCase(unittest.TestCase): self.state_handler = Mock() hs = yield setup_test_homeserver( + self.addCleanup, "test", auth=self.auth, clock=self.clock, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index c23b6e2cfd..65df116efc 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -54,6 +54,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver( + self.addCleanup, "blue", http_client=None, federation_client=Mock(), diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py index 67d9ab94e2..1a553fa3f9 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py @@ -51,7 +51,7 @@ class UserRegisterTestCase(unittest.TestCase): self.secrets = Mock() self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.config.registration_shared_secret = u"shared" diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 0316b74fa1..956f7fc4c4 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -41,6 +41,7 @@ class EventStreamPermissionsTestCase(RestTestCase): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 9ba0ffc19f..1eab9c3bdb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase): ) hs = yield setup_test_homeserver( + self.addCleanup, "test", http_client=None, resource_for_client=self.mock_resource, diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 6f15d69ecd..4be88b8a39 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -49,7 +49,7 @@ class CreateUserServletTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_handlers = Mock(return_value=handlers) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 00fc796787..9fe0760496 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -50,6 +50,7 @@ class RoomBase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( + self.addCleanup, "red", http_client=None, clock=self.hs_clock, diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 7f1a435e7b..677265edf6 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -44,6 +44,7 @@ class RoomTypingTestCase(RestTestCase): self.auth_user_id = self.user_id hs = yield setup_test_homeserver( + self.addCleanup, "red", clock=self.clock, http_client=None, diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index de33b10a5f..8260c130f8 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -43,7 +43,7 @@ class FilterTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9487babac3..b72bd0fb7f 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -47,7 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase): login_handler=self.login_handler, ) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index bafc0d1df0..2e1d06c509 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -40,7 +40,7 @@ class FilterTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() diff --git a/tests/server.py b/tests/server.py index 05708be8b9..beb24cf032 100644 --- a/tests/server.py +++ b/tests/server.py @@ -147,12 +147,15 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return d -def setup_test_homeserver(*args, **kwargs): +def setup_test_homeserver(cleanup_func, *args, **kwargs): """ Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(*args, **kwargs).result + d = _sth(cleanup_func, *args, **kwargs).result + + if isinstance(d, Failure): + d.raiseException() # Make the thread pool synchronous. clock = d.get_clock() @@ -189,6 +192,9 @@ def setup_test_homeserver(*args, **kwargs): def start(self): pass + def stop(self): + pass + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): def _(res): if isinstance(res, Failure): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index fbb25a8844..c893990454 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, federation_sender=Mock(), federation_client=Mock() + self.addCleanup, + config=config, + federation_sender=Mock(), + federation_client=Mock(), ) self.as_token = "token1" @@ -108,7 +111,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, federation_sender=Mock(), federation_client=Mock() + self.addCleanup, + config=config, + federation_sender=Mock(), + federation_client=Mock(), ) self.db_pool = hs.get_db_pool() @@ -392,6 +398,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), @@ -409,6 +416,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), @@ -432,6 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index b4f6baf441..81403727c5 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,7 +9,9 @@ from tests.utils import setup_test_homeserver class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() # type: synapse.server.HomeServer + hs = yield setup_test_homeserver( + self.addCleanup + ) # type: synapse.server.HomeServer self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index ea00bbe84c..fa60d949ba 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -28,7 +28,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield tests.utils.setup_test_homeserver() + self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 63bc42d9e0..aef4dfaf57 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -28,7 +28,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 9a8ba2fcfe..b4510c1c8d 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.store = DirectoryStore(None, hs) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index d45c775c2d..8f0aaece40 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -26,8 +26,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() - + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 66eb119581..2fdf34fdf6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -22,7 +22,7 @@ import tests.utils class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 5e87b4530d..b114c6fb1d 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -33,7 +33,7 @@ HIGHLIGHT = [ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index ad0a55b324..47f4a8ceac 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -28,7 +28,7 @@ class KeyStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 22b1072d9f..0a2c859f26 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -28,7 +28,7 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.store = self.hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index 12c540dfab..b5b58ff660 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -26,7 +26,7 @@ from tests.utils import MockClock, setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver(clock=MockClock()) + hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock()) self.store = PresenceStore(None, hs) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 5acbc8be0c..a1f6618bf9 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.store = ProfileStore(None, hs) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 85ce61e841..c4e9fb72bf 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -29,7 +29,7 @@ class RedactionTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + self.addCleanup, resource_for_federation=Mock(), http_client=None ) self.store = hs.get_datastore() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index bd96896bb3..4eda122edc 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -23,7 +23,7 @@ from tests.utils import setup_test_homeserver class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 84d49b55c1..a1ea23b068 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class RoomStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table @@ -57,7 +57,7 @@ class RoomStoreTestCase(unittest.TestCase): class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = setup_test_homeserver() + hs = setup_test_homeserver(self.addCleanup) # Room events need the full datastore, for persist_event() and # get_room_state() diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 0d9908926a..c83ef60062 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -29,7 +29,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + self.addCleanup, resource_for_federation=Mock(), http_client=None ) # We can't test the RoomMemberStore on its own without the other event # storage logic diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index ed5b41644a..6168c46248 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -33,7 +33,7 @@ class StateStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7a273eab48..b46e0ea7e2 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -29,7 +29,7 @@ BOBBY = "@bobby:a" class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.store = UserDirectoryStore(None, self.hs) # alice and bob are both in !room_id. bobby is not but shares diff --git a/tests/test_federation.py b/tests/test_federation.py index f40ff29b52..2540604fcc 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -18,7 +18,10 @@ class MessageAcceptTests(unittest.TestCase): self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, + http_client=self.http_client, + clock=self.hs_clock, + reactor=self.reactor, ) user_id = UserID("us", "test") diff --git a/tests/test_server.py b/tests/test_server.py index fc396226ea..895e490406 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -16,7 +16,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor = MemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor ) def test_handler_for_request(self): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 8643d63125..45a78338d6 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM" class FilterEventsForServerTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index 8668b5478f..90378326f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit import hashlib +import os +import uuid from inspect import getcallargs from mock import Mock, patch @@ -27,23 +30,80 @@ from synapse.http.server import HttpServer from synapse.server import HomeServer from synapse.storage import PostgresEngine from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database +from synapse.storage.prepare_database import ( + _get_or_create_schema_state, + _setup_new_database, + prepare_database, +) from synapse.util.logcontext import LoggingContext from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. -# It requires you to have a local postgres database called synapse_test, within -# which ALL TABLES WILL BE DROPPED -USE_POSTGRES_FOR_TESTS = False +USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) +POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") +POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) + + +def setupdb(): + + # If we're using PostgreSQL, set up the db once + if USE_POSTGRES_FOR_TESTS: + pgconfig = { + "name": "psycopg2", + "args": { + "database": POSTGRES_BASE_DB, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + config = Mock() + config.password_providers = [] + config.database_config = pgconfig + db_engine = create_engine(pgconfig) + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + # Set up in the db + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + cur = db_conn.cursor() + _get_or_create_schema_state(cur, db_engine) + _setup_new_database(cur, db_engine) + db_conn.commit() + cur.close() + db_conn.close() + + def _cleanup(): + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + atexit.register(_cleanup) @defer.inlineCallbacks def setup_test_homeserver( - name="test", datastore=None, config=None, reactor=None, **kargs + cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs ): - """Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. If no datastore is supplied a - datastore backed by an in-memory sqlite db will be given to the HS. + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. """ if reactor is None: from twisted.internet import reactor @@ -95,9 +155,11 @@ def setup_test_homeserver( kargs["clock"] = MockClock() if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + config.database_config = { "name": "psycopg2", - "args": {"database": "synapse_test", "cp_min": 1, "cp_max": 5}, + "args": {"database": test_db, "cp_min": 1, "cp_max": 5}, } else: config.database_config = { @@ -107,6 +169,21 @@ def setup_test_homeserver( db_engine = create_engine(config.database_config) + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if datastore is None and isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + # we need to configure the connection pool to run the on_new_connection # function, so that we can test code that uses custom sqlite functions # (like rank). @@ -125,15 +202,35 @@ def setup_test_homeserver( reactor=reactor, **kargs ) - db_conn = hs.get_db_conn() - # make sure that the database is empty - if isinstance(db_engine, PostgresEngine): - cur = db_conn.cursor() - cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") - rows = cur.fetchall() - for r in rows: - cur.execute("DROP TABLE %s CASCADE" % r[0]) - yield prepare_database(db_conn, db_engine, config) + + # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to + # date db + if not isinstance(db_engine, PostgresEngine): + db_conn = hs.get_db_conn() + yield prepare_database(db_conn, db_engine, config) + db_conn.commit() + db_conn.close() + + else: + # We need to do cleanup on PostgreSQL + def cleanup(): + # Close all the db pools + hs.get_db_pool().close() + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + cur.close() + db_conn.close() + + # Register the cleanup hook + cleanup_func(cleanup) + hs.setup() else: hs = HomeServer( diff --git a/tox.ini b/tox.ini index ed26644bd9..085f438989 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = packaging, py27, py36, pep8, check_isort -[testenv] +[base] deps = coverage Twisted>=15.1 @@ -15,6 +15,15 @@ deps = setenv = PYTHONDONTWRITEBYTECODE = no_byte_code +[testenv] +deps = + {[base]deps} + +setenv = + {[base]setenv} + +passenv = * + commands = /usr/bin/find "{toxinidir}" -name '*.pyc' -delete coverage run {env:COVERAGE_OPTS:} --source="{toxinidir}/synapse" \ @@ -46,6 +55,15 @@ commands = # ) usedevelop=true +[testenv:py27-postgres] +usedevelop=true +deps = + {[base]deps} + psycopg2 +setenv = + {[base]setenv} + SYNAPSE_POSTGRES = 1 + [testenv:py36] usedevelop=true commands = -- cgit 1.4.1 From 0d43f991a19840a224d3dac78d79f13d78212ee6 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 13 Aug 2018 18:00:23 +0100 Subject: support admin_email config and pass through into blocking errors, return AuthError in all cases --- synapse/api/auth.py | 8 ++++++-- synapse/api/errors.py | 13 +++++++++++-- synapse/config/server.py | 4 ++++ synapse/handlers/register.py | 27 ++++++++++++++------------- tests/api/test_auth.py | 6 +++++- tests/handlers/test_register.py | 8 ++++---- tests/utils.py | 1 + 7 files changed, 45 insertions(+), 22 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9c62ec4374..4f028078fa 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -781,11 +781,15 @@ class Auth(object): """ if self.hs.config.hs_disabled: raise AuthError( - 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED + 403, self.hs.config.hs_disabled_message, + errcode=Codes.HS_DISABLED, + admin_email=self.hs.config.admin_email, ) if self.hs.config.limit_usage_by_mau is True: current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + 403, "MAU Limit Exceeded", + admin_email=self.hs.config.admin_email, + errcode=Codes.MAU_LIMIT_EXCEEDED ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index dc3bed5fcb..d74848159e 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -225,11 +225,20 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" - def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - super(AuthError, self).__init__(*args, **kwargs) + self.admin_email = kwargs.get('admin_email') + self.msg = kwargs.get('msg') + self.errcode = kwargs.get('errcode') + super(AuthError, self).__init__(*args, errcode=kwargs["errcode"]) + + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + admin_email=self.admin_email, + ) class EventSizeError(SynapseError): diff --git a/synapse/config/server.py b/synapse/config/server.py index 3b078d72ca..64a5121a45 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -82,6 +82,10 @@ class ServerConfig(Config): self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") + # Admin email to direct users at should their instance become blocked + # due to resource constraints + self.admin_email = config.get("admin_email", None) + # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3526b20d5a..ef7222d7b8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - yield self._check_mau_limits() + + yield self.auth.check_auth_blocking() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() need_register = True try: @@ -534,13 +535,13 @@ class RegistrationHandler(BaseHandler): action="join", ) - @defer.inlineCallbacks - def _check_mau_limits(self): - """ - Do not accept registrations if monthly active user limits exceeded - and limiting is enabled - """ - try: - yield self.auth.check_auth_blocking() - except AuthError as e: - raise RegistrationError(e.code, str(e), e.errcode) + # @defer.inlineCallbacks + # def _s(self): + # """ + # Do not accept registrations if monthly active user limits exceeded + # and limiting is enabled + # """ + # try: + # yield self.auth.check_auth_blocking() + # except AuthError as e: + # raise RegistrationError(e.code, str(e), e.errcode) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index a65689ba89..e8a1894e65 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -455,8 +455,11 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() + self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) + self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED) + self.assertEquals(e.exception.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = Mock( @@ -470,5 +473,6 @@ class AuthTestCase(unittest.TestCase): self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() + self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) self.assertEquals(e.exception.code, 403) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index d48d40c8dd..35d1bcab3e 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import RegistrationError +from synapse.api.errors import AuthError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -127,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register_saml2(localpart="local_part") diff --git a/tests/utils.py b/tests/utils.py index 90378326f8..4af81624eb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -139,6 +139,7 @@ def setup_test_homeserver( config.hs_disabled_message = "" config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] + config.admin_email = None # we need a sane default_room_version, otherwise attempts to create rooms will # fail. -- cgit 1.4.1 From ce7de9ae6b74e8e5e89ff442bc29f8cd73328042 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 13 Aug 2018 18:06:18 +0100 Subject: Revert "support admin_email config and pass through into blocking errors, return AuthError in all cases" This reverts commit 0d43f991a19840a224d3dac78d79f13d78212ee6. --- synapse/api/auth.py | 8 ++------ synapse/api/errors.py | 13 ++----------- synapse/config/server.py | 4 ---- synapse/handlers/register.py | 27 +++++++++++++-------------- tests/api/test_auth.py | 6 +----- tests/handlers/test_register.py | 8 ++++---- tests/utils.py | 1 - 7 files changed, 22 insertions(+), 45 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 4f028078fa..9c62ec4374 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -781,15 +781,11 @@ class Auth(object): """ if self.hs.config.hs_disabled: raise AuthError( - 403, self.hs.config.hs_disabled_message, - errcode=Codes.HS_DISABLED, - admin_email=self.hs.config.admin_email, + 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED ) if self.hs.config.limit_usage_by_mau is True: current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise AuthError( - 403, "MAU Limit Exceeded", - admin_email=self.hs.config.admin_email, - errcode=Codes.MAU_LIMIT_EXCEEDED + 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d74848159e..dc3bed5fcb 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -225,20 +225,11 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" + def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - self.admin_email = kwargs.get('admin_email') - self.msg = kwargs.get('msg') - self.errcode = kwargs.get('errcode') - super(AuthError, self).__init__(*args, errcode=kwargs["errcode"]) - - def error_dict(self): - return cs_error( - self.msg, - self.errcode, - admin_email=self.admin_email, - ) + super(AuthError, self).__init__(*args, **kwargs) class EventSizeError(SynapseError): diff --git a/synapse/config/server.py b/synapse/config/server.py index 64a5121a45..3b078d72ca 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -82,10 +82,6 @@ class ServerConfig(Config): self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") - # Admin email to direct users at should their instance become blocked - # due to resource constraints - self.admin_email = config.get("admin_email", None) - # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ef7222d7b8..3526b20d5a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -144,8 +144,7 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - - yield self.auth.check_auth_blocking() + yield self._check_mau_limits() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -290,7 +289,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) - yield self.auth.check_auth_blocking() + yield self._check_mau_limits() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -440,7 +439,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self.auth.check_auth_blocking() + yield self._check_mau_limits() need_register = True try: @@ -535,13 +534,13 @@ class RegistrationHandler(BaseHandler): action="join", ) - # @defer.inlineCallbacks - # def _s(self): - # """ - # Do not accept registrations if monthly active user limits exceeded - # and limiting is enabled - # """ - # try: - # yield self.auth.check_auth_blocking() - # except AuthError as e: - # raise RegistrationError(e.code, str(e), e.errcode) + @defer.inlineCallbacks + def _check_mau_limits(self): + """ + Do not accept registrations if monthly active user limits exceeded + and limiting is enabled + """ + try: + yield self.auth.check_auth_blocking() + except AuthError as e: + raise RegistrationError(e.code, str(e), e.errcode) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e8a1894e65..a65689ba89 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -455,11 +455,8 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(lots_of_users) ) - with self.assertRaises(AuthError) as e: + with self.assertRaises(AuthError): yield self.auth.check_auth_blocking() - self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) - self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED) - self.assertEquals(e.exception.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = Mock( @@ -473,6 +470,5 @@ class AuthTestCase(unittest.TestCase): self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() - self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) self.assertEquals(e.exception.code, 403) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 35d1bcab3e..d48d40c8dd 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import AuthError +from synapse.api.errors import RegistrationError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -127,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.register_saml2(localpart="local_part") diff --git a/tests/utils.py b/tests/utils.py index 4af81624eb..90378326f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -139,7 +139,6 @@ def setup_test_homeserver( config.hs_disabled_message = "" config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] - config.admin_email = None # we need a sane default_room_version, otherwise attempts to create rooms will # fail. -- cgit 1.4.1 From f4b49152e27593dd6c863e71479a2ab712c4ada2 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 13 Aug 2018 18:00:23 +0100 Subject: support admin_email config and pass through into blocking errors, return AuthError in all cases --- synapse/api/auth.py | 8 ++++++-- synapse/api/errors.py | 13 +++++++++++-- synapse/config/server.py | 4 ++++ synapse/handlers/register.py | 27 ++++++++++++++------------- tests/api/test_auth.py | 6 +++++- tests/handlers/test_register.py | 8 ++++---- tests/utils.py | 1 + 7 files changed, 45 insertions(+), 22 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9c62ec4374..4f028078fa 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -781,11 +781,15 @@ class Auth(object): """ if self.hs.config.hs_disabled: raise AuthError( - 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED + 403, self.hs.config.hs_disabled_message, + errcode=Codes.HS_DISABLED, + admin_email=self.hs.config.admin_email, ) if self.hs.config.limit_usage_by_mau is True: current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + 403, "MAU Limit Exceeded", + admin_email=self.hs.config.admin_email, + errcode=Codes.MAU_LIMIT_EXCEEDED ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index dc3bed5fcb..d74848159e 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -225,11 +225,20 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" - def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - super(AuthError, self).__init__(*args, **kwargs) + self.admin_email = kwargs.get('admin_email') + self.msg = kwargs.get('msg') + self.errcode = kwargs.get('errcode') + super(AuthError, self).__init__(*args, errcode=kwargs["errcode"]) + + def error_dict(self): + return cs_error( + self.msg, + self.errcode, + admin_email=self.admin_email, + ) class EventSizeError(SynapseError): diff --git a/synapse/config/server.py b/synapse/config/server.py index 3b078d72ca..64a5121a45 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -82,6 +82,10 @@ class ServerConfig(Config): self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") + # Admin email to direct users at should their instance become blocked + # due to resource constraints + self.admin_email = config.get("admin_email", None) + # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3526b20d5a..ef7222d7b8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler): Raises: RegistrationError if there was a problem registering. """ - yield self._check_mau_limits() + + yield self.auth.check_auth_blocking() password_hash = None if password: password_hash = yield self.auth_handler().hash(password) @@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", ) - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() need_register = True try: @@ -534,13 +535,13 @@ class RegistrationHandler(BaseHandler): action="join", ) - @defer.inlineCallbacks - def _check_mau_limits(self): - """ - Do not accept registrations if monthly active user limits exceeded - and limiting is enabled - """ - try: - yield self.auth.check_auth_blocking() - except AuthError as e: - raise RegistrationError(e.code, str(e), e.errcode) + # @defer.inlineCallbacks + # def _s(self): + # """ + # Do not accept registrations if monthly active user limits exceeded + # and limiting is enabled + # """ + # try: + # yield self.auth.check_auth_blocking() + # except AuthError as e: + # raise RegistrationError(e.code, str(e), e.errcode) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index a65689ba89..e8a1894e65 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -455,8 +455,11 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() + self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) + self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED) + self.assertEquals(e.exception.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = Mock( @@ -470,5 +473,6 @@ class AuthTestCase(unittest.TestCase): self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(AuthError) as e: yield self.auth.check_auth_blocking() + self.assertEquals(e.exception.admin_email, self.hs.config.admin_email) self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) self.assertEquals(e.exception.code, 403) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index d48d40c8dd..35d1bcab3e 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import RegistrationError +from synapse.api.errors import AuthError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -127,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register_saml2(localpart="local_part") diff --git a/tests/utils.py b/tests/utils.py index 90378326f8..4af81624eb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -139,6 +139,7 @@ def setup_test_homeserver( config.hs_disabled_message = "" config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] + config.admin_email = None # we need a sane default_room_version, otherwise attempts to create rooms will # fail. -- cgit 1.4.1 From ed4bc3d2fc7242e27b3cdd36bc6c27c98fac09c8 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 14 Aug 2018 15:04:48 +0100 Subject: fix off by 1s on mau --- synapse/handlers/auth.py | 4 ++-- tests/handlers/test_auth.py | 39 ++++++++++++++++++++++++++++++++++++++- tests/handlers/test_register.py | 14 ++++++++++---- 3 files changed, 50 insertions(+), 7 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7ea8ce9f94..7baaa39447 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -520,7 +520,7 @@ class AuthHandler(BaseHandler): """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) - yield self.auth.check_auth_blocking() + yield self.auth.check_auth_blocking(user_id) # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -734,7 +734,6 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): - yield self.auth.check_auth_blocking() auth_api = self.hs.get_auth() user_id = None try: @@ -743,6 +742,7 @@ class AuthHandler(BaseHandler): auth_api.validate_macaroon(macaroon, "login", True, user_id) except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) + yield self.auth.check_auth_blocking(user_id) defer.returnValue(user_id) @defer.inlineCallbacks diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 56c0f87fb7..9ca7b2ee4e 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -124,7 +124,7 @@ class AuthTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_mau_limits_exceeded(self): + def test_mau_limits_exceeded_large(self): self.hs.config.limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) @@ -141,6 +141,43 @@ class AuthTestCase(unittest.TestCase): self._get_macaroon().serialize() ) + @defer.inlineCallbacks + def test_mau_limits_parity(self): + self.hs.config.limit_usage_by_mau = True + + # If not in monthly active cohort + self.hs.get_datastore().get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + with self.assertRaises(AuthError): + yield self.auth_handler.get_access_token_for_user_id('user_a') + + self.hs.get_datastore().get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + with self.assertRaises(AuthError): + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + # If in monthly active cohort + self.hs.get_datastore().user_last_seen_monthly_active = Mock( + return_value=defer.succeed(self.hs.get_clock().time_msec()) + ) + self.hs.get_datastore().get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + yield self.auth_handler.get_access_token_for_user_id('user_a') + self.hs.get_datastore().user_last_seen_monthly_active = Mock( + return_value=defer.succeed(self.hs.get_clock().time_msec()) + ) + self.hs.get_datastore().get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + yield self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) + + @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): self.hs.config.limit_usage_by_mau = True diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 35d1bcab3e..a821da0750 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import AuthError +from synapse.api.errors import RegistrationError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,13 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): + yield self.handler.register(localpart="local_part") + + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + with self.assertRaises(RegistrationError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -127,5 +133,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(RegistrationError): yield self.handler.register_saml2(localpart="local_part") -- cgit 1.4.1 From 8f9a7eb58de214cd489ba233c381521e9bf79dec Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 13 Aug 2018 18:00:23 +0100 Subject: support admin_email config and pass through into blocking errors, return AuthError in all cases --- synapse/handlers/register.py | 1 - tests/handlers/test_register.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 54e3434928..f03ee1476b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -534,4 +534,3 @@ class RegistrationHandler(BaseHandler): remote_room_hosts=remote_room_hosts, action="join", ) - diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a821da0750..6699d25121 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import RegistrationError +from synapse.api.errors import AuthError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register(localpart="local_part") self.store.get_monthly_active_count = Mock( @@ -133,5 +133,5 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register_saml2(localpart="local_part") -- cgit 1.4.1 From 06b331ff4035558500e196a5dce79ffe9d2da807 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 14 Aug 2018 15:28:15 +0100 Subject: fix off by 1 errors --- tests/handlers/test_auth.py | 1 - tests/handlers/test_register.py | 16 ++++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 9ca7b2ee4e..3046bd6093 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -177,7 +177,6 @@ class AuthTestCase(unittest.TestCase): self._get_macaroon().serialize() ) - @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): self.hs.config.limit_usage_by_mau = True diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 6699d25121..7154816a34 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -98,7 +98,7 @@ class RegistrationTestCase(unittest.TestCase): def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - return_value=defer.succeed(self.small_number_of_users) + return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception yield self.handler.get_or_create_user("@user:server", 'c', "User") @@ -112,6 +112,12 @@ class RegistrationTestCase(unittest.TestCase): with self.assertRaises(AuthError): yield self.handler.get_or_create_user("requester", 'b', "display_name") + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + with self.assertRaises(AuthError): + yield self.handler.get_or_create_user("requester", 'b', "display_name") + @defer.inlineCallbacks def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True @@ -124,7 +130,7 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(RegistrationError): + with self.assertRaises(AuthError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -135,3 +141,9 @@ class RegistrationTestCase(unittest.TestCase): ) with self.assertRaises(AuthError): yield self.handler.register_saml2(localpart="local_part") + + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) + with self.assertRaises(AuthError): + yield self.handler.register_saml2(localpart="local_part") -- cgit 1.4.1 From 7277216d01261f055886d0ac7b1ae5e5c5fc33cf Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 14 Aug 2018 17:14:39 +0100 Subject: fix setup_test_homeserver to be postgres compatible --- tests/handlers/test_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index cfd37f3138..8c8b65e04e 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -29,7 +29,7 @@ class SyncTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.sync_handler = SyncHandler(self.hs) self.store = self.hs.get_datastore() -- cgit 1.4.1 From 75c663c7b97733cf1e217d0a973a6c4c64228444 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 15 Aug 2018 11:27:48 +0100 Subject: update error codes --- tests/handlers/test_sync.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 8c8b65e04e..33d861bd64 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -51,7 +51,7 @@ class SyncTestCase(tests.unittest.TestCase): self.hs.config.hs_disabled = True with self.assertRaises(AuthError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) self.hs.config.hs_disabled = False @@ -59,7 +59,7 @@ class SyncTestCase(tests.unittest.TestCase): with self.assertRaises(AuthError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) def _generate_sync_config(self, user_id): return SyncConfig( -- cgit 1.4.1 From 13ad9930c8799ea54671a6ce00533528d89e061b Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 16 Aug 2018 18:02:02 +0100 Subject: add new error type ResourceLimit --- synapse/api/auth.py | 10 ++++++---- synapse/api/errors.py | 23 +++++++++++++++++++++-- synapse/config/server.py | 1 + tests/api/test_auth.py | 6 +++--- tests/handlers/test_auth.py | 10 +++++----- tests/handlers/test_register.py | 14 +++++++------- tests/handlers/test_sync.py | 6 +++--- tests/utils.py | 1 + 8 files changed, 47 insertions(+), 24 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3b2a2ab77a..6945c118d3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -25,7 +25,7 @@ from twisted.internet import defer import synapse.types from synapse import event_auth from synapse.api.constants import EventTypes, JoinRules, Membership -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, ResourceLimitError from synapse.types import UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache @@ -784,10 +784,11 @@ class Auth(object): MAU cohort """ if self.hs.config.hs_disabled: - raise AuthError( + raise ResourceLimitError( 403, self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEED, admin_uri=self.hs.config.admin_uri, + limit_type=self.hs.config.hs_disabled_limit_type ) if self.hs.config.limit_usage_by_mau is True: # If the user is already part of the MAU cohort @@ -798,8 +799,9 @@ class Auth(object): # Else if there is no room in the MAU bucket, bail current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: - raise AuthError( + raise ResourceLimitError( 403, "Monthly Active User Limits AU Limit Exceeded", admin_uri=self.hs.config.admin_uri, - errcode=Codes.RESOURCE_LIMIT_EXCEED + errcode=Codes.RESOURCE_LIMIT_EXCEED, + limit_type="monthly_active_user" ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 08f0cb5554..e26001ab12 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -224,15 +224,34 @@ class NotFoundError(SynapseError): class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" - def __init__(self, code, msg, errcode=Codes.FORBIDDEN, admin_uri=None): + + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.FORBIDDEN + super(AuthError, self).__init__(*args, **kwargs) + + +class ResourceLimitError(SynapseError): + """ + Any error raised when there is a problem with resource usage. + For instance, the monthly active user limit for the server has been exceeded + """ + def __init__( + self, code, msg, + errcode=Codes.RESOURCE_LIMIT_EXCEED, + admin_uri=None, + limit_type=None, + ): self.admin_uri = admin_uri - super(AuthError, self).__init__(code, msg, errcode=errcode) + self.limit_type = limit_type + super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) def error_dict(self): return cs_error( self.msg, self.errcode, admin_uri=self.admin_uri, + limit_type=self.limit_type ) diff --git a/synapse/config/server.py b/synapse/config/server.py index 2190f3210a..ae72c872d9 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -81,6 +81,7 @@ class ServerConfig(Config): # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") + self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "") # Admin uri to direct users at should their instance become blocked # due to resource constraints diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 32a2b5fc3d..022d81ce3e 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -21,7 +21,7 @@ from twisted.internet import defer import synapse.handlers.auth from synapse.api.auth import Auth -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, ResourceLimitError from synapse.types import UserID from tests import unittest @@ -455,7 +455,7 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(lots_of_users) ) - with self.assertRaises(AuthError) as e: + with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) @@ -471,7 +471,7 @@ class AuthTestCase(unittest.TestCase): def test_hs_disabled(self): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" - with self.assertRaises(AuthError) as e: + with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 3046bd6093..1e39fe0ec2 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -20,7 +20,7 @@ from twisted.internet import defer import synapse import synapse.api.errors -from synapse.api.errors import AuthError +from synapse.api.errors import ResourceLimitError from synapse.handlers.auth import AuthHandler from tests import unittest @@ -130,13 +130,13 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.large_number_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.auth_handler.get_access_token_for_user_id('user_a') self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) @@ -149,13 +149,13 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.auth_handler.get_access_token_for_user_id('user_a') self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 7154816a34..7b4ade3dfb 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.errors import AuthError +from synapse.api.errors import ResourceLimitError from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester @@ -109,13 +109,13 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.get_or_create_user("requester", 'b', "display_name") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.get_or_create_user("requester", 'b', "display_name") @defer.inlineCallbacks @@ -124,13 +124,13 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.register(localpart="local_part") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.register(localpart="local_part") @defer.inlineCallbacks @@ -139,11 +139,11 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.register_saml2(localpart="local_part") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(AuthError): + with self.assertRaises(ResourceLimitError): yield self.handler.register_saml2(localpart="local_part") diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 33d861bd64..a01ab471f5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -14,7 +14,7 @@ # limitations under the License. from twisted.internet import defer -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.handlers.sync import SyncConfig, SyncHandler from synapse.types import UserID @@ -49,7 +49,7 @@ class SyncTestCase(tests.unittest.TestCase): # Test that global lock works self.hs.config.hs_disabled = True - with self.assertRaises(AuthError) as e: + with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) @@ -57,7 +57,7 @@ class SyncTestCase(tests.unittest.TestCase): sync_config = self._generate_sync_config(user_id2) - with self.assertRaises(AuthError) as e: + with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) diff --git a/tests/utils.py b/tests/utils.py index 52326d4f67..6f8b1de3e7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -137,6 +137,7 @@ def setup_test_homeserver( config.limit_usage_by_mau = False config.hs_disabled = False config.hs_disabled_message = "" + config.hs_disabled_limit_type = "" config.max_mau_value = 50 config.mau_limits_reserved_threepids = [] config.admin_uri = None -- cgit 1.4.1 From ca87ad1defac1082462367854cb4a656b7a96e90 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Aug 2018 11:43:16 +0100 Subject: Split ProfileHandler into master and worker --- synapse/handlers/profile.py | 21 ++++++++++++++------- synapse/server.py | 7 +++++-- tests/handlers/test_profile.py | 4 ++-- 3 files changed, 21 insertions(+), 11 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6d1fbb1a5c..8b349f6ad6 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -33,12 +33,12 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -class ProfileHandler(BaseHandler): +class WorkerProfileHandler(BaseHandler): PROFILE_UPDATE_MS = 60 * 1000 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs): - super(ProfileHandler, self).__init__(hs) + super(WorkerProfileHandler, self).__init__(hs) self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -47,11 +47,6 @@ class ProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() - if hs.config.worker_app is None: - self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, - ) - self._notify_master_profile_change = ( ReplicationHandleProfileChangeRestServlet.make_client(hs) ) @@ -298,6 +293,18 @@ class ProfileHandler(BaseHandler): room_id, str(e.message) ) + +class MasterProfileHandler(WorkerProfileHandler): + PROFILE_UPDATE_MS = 60 * 1000 + PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 + + def __init__(self, hs): + super(MasterProfileHandler, self).__init__(hs) + + self.clock.looping_call( + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, + ) + def _start_update_remote_profile_cache(self): return run_as_background_process( "Update remote profile", self._update_remote_profile_cache, diff --git a/synapse/server.py b/synapse/server.py index 140be9ebe8..be85aad8cf 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -55,7 +55,7 @@ from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler from synapse.handlers.presence import PresenceHandler -from synapse.handlers.profile import ProfileHandler +from synapse.handlers.profile import MasterProfileHandler, WorkerProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.room import RoomContextHandler, RoomCreationHandler @@ -307,7 +307,10 @@ class HomeServer(object): return InitialSyncHandler(self) def build_profile_handler(self): - return ProfileHandler(self) + if self.config.worker_app: + return WorkerProfileHandler(self) + else: + return MasterProfileHandler(self) def build_event_creation_handler(self): return EventCreationHandler(self) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index dc17918a3d..07cf5f4c8e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -20,7 +20,7 @@ from twisted.internet import defer import synapse.types from synapse.api.errors import AuthError -from synapse.handlers.profile import ProfileHandler +from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest @@ -29,7 +29,7 @@ from tests.utils import setup_test_homeserver class ProfileHandlers(object): def __init__(self, hs): - self.profile_handler = ProfileHandler(hs) + self.profile_handler = MasterProfileHandler(hs) class ProfileTestCase(unittest.TestCase): -- cgit 1.4.1 From e07970165f852ccbc4542f1aaf0fd1b2bc54b973 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Sat, 18 Aug 2018 14:39:45 +0100 Subject: rename error code --- synapse/api/auth.py | 4 ++-- synapse/api/errors.py | 4 ++-- tests/api/test_auth.py | 4 ++-- tests/handlers/test_sync.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) (limited to 'tests/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 55384d6ffe..4207a48afd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -792,7 +792,7 @@ class Auth(object): if self.hs.config.hs_disabled: raise ResourceLimitError( 403, self.hs.config.hs_disabled_message, - errcode=Codes.RESOURCE_LIMIT_EXCEED, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_uri=self.hs.config.admin_uri, limit_type=self.hs.config.hs_disabled_limit_type ) @@ -809,6 +809,6 @@ class Auth(object): 403, "Monthly Active User Limit Exceeded", admin_uri=self.hs.config.admin_uri, - errcode=Codes.RESOURCE_LIMIT_EXCEED, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, limit_type="monthly_active_user" ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index e26001ab12..c4ddba9889 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -56,7 +56,7 @@ class Codes(object): SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" - RESOURCE_LIMIT_EXCEED = "M_RESOURCE_LIMIT_EXCEED" + RESOURCE_LIMIT_EXCEEDED = "M_RESOURCE_LIMIT_EXCEEDED" UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION" INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" @@ -238,7 +238,7 @@ class ResourceLimitError(SynapseError): """ def __init__( self, code, msg, - errcode=Codes.RESOURCE_LIMIT_EXCEED, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_uri=None, limit_type=None, ): diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index c4cbff4e8d..ed960090c4 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -458,7 +458,7 @@ class AuthTestCase(unittest.TestCase): with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) # Ensure does not throw an error @@ -474,7 +474,7 @@ class AuthTestCase(unittest.TestCase): with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @defer.inlineCallbacks diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index a01ab471f5..31f54bbd7d 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -51,7 +51,7 @@ class SyncTestCase(tests.unittest.TestCase): self.hs.config.hs_disabled = True with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.hs.config.hs_disabled = False @@ -59,7 +59,7 @@ class SyncTestCase(tests.unittest.TestCase): with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def _generate_sync_config(self, user_id): return SyncConfig( -- cgit 1.4.1