From 62ace05c4521caeed023e5b9dadd993afe5c154f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jul 2018 13:31:59 +0100 Subject: Move necessary storage functions to worker classes --- synapse/storage/events.py | 82 --------------------------------------- synapse/storage/events_worker.py | 84 ++++++++++++++++++++++++++++++++++++++++ synapse/storage/room.py | 32 +++++++-------- 3 files changed, 100 insertions(+), 98 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e8e5a0fe44..5374796fd7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1430,88 +1430,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore (event.event_id, event.redacts) ) - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): - """Given a list of event ids, check if we have already processed and - stored them as non outliers. - """ - rows = yield self._simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) - - defer.returnValue(set(r["event_id"] for r in rows)) - - @defer.inlineCallbacks - def have_seen_events(self, event_ids): - """Given a list of event ids, check if we have already processed them. - - Args: - event_ids (iterable[str]): - - Returns: - Deferred[set[str]]: The events we have already seen. - """ - results = set() - - def have_seen_events_txn(txn, chunk): - sql = ( - "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" - % (",".join("?" * len(chunk)), ) - ) - txn.execute(sql, chunk) - for (event_id, ) in txn: - results.add(event_id) - - # break the input up into chunks of 100 - input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), - []): - yield self.runInteraction( - "have_seen_events", - have_seen_events_txn, - chunk, - ) - defer.returnValue(results) - - def get_seen_events_with_rejections(self, event_ids): - """Given a list of event ids, check if we rejected them. - - Args: - event_ids (list[str]) - - Returns: - Deferred[dict[str, str|None): - Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps - to None. - """ - if not event_ids: - return defer.succeed({}) - - def f(txn): - sql = ( - "SELECT e.event_id, reason FROM events as e " - "LEFT JOIN rejections as r ON e.event_id = r.event_id " - "WHERE e.event_id = ?" - ) - - res = {} - for event_id in event_ids: - txn.execute(sql, (event_id,)) - row = txn.fetchone() - if row: - _, rejected = row - res[event_id] = rejected - - return res - - return self.runInteraction("get_rejection_reasons", f) - @defer.inlineCallbacks def count_daily_messages(self): """ diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 9b4cfeb899..e1dc2c62fd 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import itertools import logging + from collections import namedtuple from canonicaljson import json @@ -442,3 +444,85 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_cache.prefill((original_ev.event_id,), cache_entry) defer.returnValue(cache_entry) + + @defer.inlineCallbacks + def have_events_in_timeline(self, event_ids): + """Given a list of event ids, check if we have already processed and + stored them as non outliers. + """ + rows = yield self._simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) + + defer.returnValue(set(r["event_id"] for r in rows)) + + @defer.inlineCallbacks + def have_seen_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Args: + event_ids (iterable[str]): + + Returns: + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = ( + "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" + % (",".join("?" * len(chunk)), ) + ) + txn.execute(sql, chunk) + for (event_id, ) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), + []): + yield self.runInteraction( + "have_seen_events", + have_seen_events_txn, + chunk, + ) + defer.returnValue(results) + + def get_seen_events_with_rejections(self, event_ids): + """Given a list of event ids, check if we rejected them. + + Args: + event_ids (list[str]) + + Returns: + Deferred[dict[str, str|None): + Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps + to None. + """ + if not event_ids: + return defer.succeed({}) + + def f(txn): + sql = ( + "SELECT e.event_id, reason FROM events as e " + "LEFT JOIN rejections as r ON e.event_id = r.event_id " + "WHERE e.event_id = ?" + ) + + res = {} + for event_id in event_ids: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if row: + _, rejected = row + res[event_id] = rejected + + return res + + return self.runInteraction("get_rejection_reasons", f) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 3147fb6827..3378fc77d1 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -41,6 +41,22 @@ RatelimitOverride = collections.namedtuple( class RoomWorkerStore(SQLBaseStore): + def get_room(self, room_id): + """Retrieve a room. + + Args: + room_id (str): The ID of the room to retrieve. + Returns: + A namedtuple containing the room information, or an empty list. + """ + return self._simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("room_id", "is_public", "creator"), + desc="get_room", + allow_none=True, + ) + def get_public_room_ids(self): return self._simple_select_onecol( table="rooms", @@ -215,22 +231,6 @@ class RoomStore(RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - def get_room(self, room_id): - """Retrieve a room. - - Args: - room_id (str): The ID of the room to retrieve. - Returns: - A namedtuple containing the room information, or an empty list. 
- """ - return self._simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), - desc="get_room", - allow_none=True, - ) - @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): -- cgit 1.5.1 From 96a9a296455e2c08be4e0375cf784e4113fa51e1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jul 2018 13:35:36 +0100 Subject: Pull in necessary stores in federation_reader --- synapse/app/federation_reader.py | 2 ++ synapse/storage/events_worker.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index c512b4be87..0c557a8242 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -36,6 +36,7 @@ from synapse.replication.slave.storage.appservice import SlavedApplicationServic from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.profile import SlavedProfileStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore @@ -53,6 +54,7 @@ logger = logging.getLogger("synapse.app.federation_reader") class FederationReaderSlavedStore( + SlavedProfileStore, SlavedApplicationServiceStore, SlavedPusherStore, SlavedPushRuleStore, diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index e1dc2c62fd..59822178ff 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -14,7 +14,6 @@ # limitations under the License. import itertools import logging - from collections import namedtuple from canonicaljson import json -- cgit 1.5.1 From 885ea9c602ca4002385a6d73c94c35e0958de809 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 9 Aug 2018 18:02:12 +0100 Subject: rename _user_last_seen_monthly_active --- synapse/api/auth.py | 2 +- synapse/storage/monthly_active_users.py | 10 +++++----- tests/storage/test_client_ips.py | 12 ++++++------ tests/storage/test_monthly_active_users.py | 13 +++++++------ 4 files changed, 19 insertions(+), 18 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index df6022ff69..c31c6a6a08 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -789,7 +789,7 @@ class Auth(object): if self.hs.config.limit_usage_by_mau is True: # If the user is already part of the MAU cohort if user_id: - timestamp = yield self.store._user_last_seen_monthly_active(user_id) + timestamp = yield self.store.user_last_seen_monthly_active(user_id) if timestamp: return # Else if there is no room in the MAU bucket, bail diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index d47dcef3a0..07211432af 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -113,7 +113,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): # is racy. 
# Have resolved to invalidate the whole cache for now and do # something about it if and when the perf becomes significant - self._user_last_seen_monthly_active.invalidate_all() + self.user_last_seen_monthly_active.invalidate_all() self.get_monthly_active_count.invalidate_all() @cached(num_args=0) @@ -152,11 +152,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): lock=False, ) if is_insert: - self._user_last_seen_monthly_active.invalidate((user_id,)) + self.user_last_seen_monthly_active.invalidate((user_id,)) self.get_monthly_active_count.invalidate(()) @cached(num_args=1) - def _user_last_seen_monthly_active(self, user_id): + def user_last_seen_monthly_active(self, user_id): """ Checks if a given user is part of the monthly active user group Arguments: @@ -173,7 +173,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): }, retcol="timestamp", allow_none=True, - desc="_user_last_seen_monthly_active", + desc="user_last_seen_monthly_active", )) @defer.inlineCallbacks @@ -185,7 +185,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): user_id(str): the user_id to query """ if self.hs.config.limit_usage_by_mau: - last_seen_timestamp = yield self._user_last_seen_monthly_active(user_id) + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() # We want to reduce to the total number of db writes, and are happy diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 7a58c6eb24..6d2bb3058e 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -64,7 +64,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertFalse(active) @defer.inlineCallbacks @@ -80,7 +80,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertFalse(active) @defer.inlineCallbacks @@ -88,13 +88,13 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertFalse(active) yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertTrue(active) @defer.inlineCallbacks @@ -103,7 +103,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertFalse(active) yield self.store.insert_client_ip( @@ -112,5 +112,5 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertTrue(active) diff --git 
a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index cbd480cd42..be74c30218 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -66,9 +66,9 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): # Test user is marked as active - timestamp = yield self.store._user_last_seen_monthly_active(user1) + timestamp = yield self.store.user_last_seen_monthly_active(user1) self.assertTrue(timestamp) - timestamp = yield self.store._user_last_seen_monthly_active(user2) + timestamp = yield self.store.user_last_seen_monthly_active(user2) self.assertTrue(timestamp) # Test that users are never removed from the db. @@ -92,17 +92,18 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): self.assertEqual(1, count) @defer.inlineCallbacks - def test__user_last_seen_monthly_active(self): + def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = yield self.store._user_last_seen_monthly_active(user_id1) + + result = yield self.store.user_last_seen_monthly_active(user_id1) self.assertFalse(result == 0) yield self.store.upsert_monthly_active_user(user_id1) yield self.store.upsert_monthly_active_user(user_id2) - result = yield self.store._user_last_seen_monthly_active(user_id1) + result = yield self.store.user_last_seen_monthly_active(user_id1) self.assertTrue(result > 0) - result = yield self.store._user_last_seen_monthly_active(user_id3) + result = yield self.store.user_last_seen_monthly_active(user_id3) self.assertFalse(result == 0) @defer.inlineCallbacks -- cgit 1.5.1 From b37c4724191995eec4f5bc9d64f176b2a315f777 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 10 Aug 2018 23:50:21 +1000 Subject: Rename async to async_helpers because `async` is a keyword on Python 3.7 (#3678) --- changelog.d/3678.misc | 1 + synapse/app/federation_sender.py | 2 +- synapse/federation/federation_server.py | 8 +- synapse/handlers/device.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/handlers/read_marker.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/http/client.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/notifier.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/mailer.py | 2 +- synapse/rest/client/transactions.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/state.py | 2 +- synapse/storage/events.py | 2 +- synapse/storage/roommember.py | 2 +- synapse/util/async.py | 415 -------------------------- synapse/util/async_helpers.py | 415 ++++++++++++++++++++++++++ synapse/util/caches/descriptors.py | 2 +- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/snapshot_cache.py | 2 +- synapse/util/logcontext.py | 2 +- tests/storage/test__base.py | 2 +- tests/util/test_linearizer.py | 2 +- tests/util/test_rwlock.py | 2 +- 34 files changed, 450 insertions(+), 449 deletions(-) create mode 100644 changelog.d/3678.misc delete mode 100644 synapse/util/async.py create mode 100644 synapse/util/async_helpers.py (limited to 'synapse/storage') diff --git a/changelog.d/3678.misc b/changelog.d/3678.misc new file mode 100644 index 0000000000..0d7c8da64a --- /dev/null +++ 
b/changelog.d/3678.misc @@ -0,0 +1 @@ +Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7. diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index e082d45c94..7bbf0ad082 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -40,7 +40,7 @@ from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b62f687b6..a23136784a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -40,7 +40,7 @@ from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name from synapse.types import get_domain_from_id -from synapse.util import async +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.util.logutils import log_function @@ -67,8 +67,8 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler - self._server_linearizer = async.Linearizer("fed_server") - self._transaction_linearizer = async.Linearizer("fed_txn_handler") + self._server_linearizer = Linearizer("fed_server") + self._transaction_linearizer = Linearizer("fed_txn_handler") self.transaction_actions = TransactionActions(self.store) @@ -200,7 +200,7 @@ class FederationServer(FederationBase): event_id, f.getTraceback().rstrip(), ) - yield async.concurrently_execute( + yield concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2d44f15da3..9e017116a9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import FederationDeniedError from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import stringutils -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0dffd44e22..2380d17f4e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -52,7 +52,7 @@ from synapse.events.validator import EventValidator from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id from synapse.util import logcontext, unwrapFirstError -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room from synapse.util.frozenutils import unfreeze from synapse.util.logutils import log_function diff --git 
a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 40e7580a61..1fb17fd9a5 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index bcb093ba3e..01a362360e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -32,7 +32,7 @@ from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.types import RoomAlias, UserID -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index b2849783ed..a97d43550f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -22,7 +22,7 @@ from synapse.api.constants import Membership from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event from synapse.types import RoomStreamToken -from synapse.util.async import ReadWriteLock +from synapse.util.async_helpers import ReadWriteLock from synapse.util.logcontext import run_in_background from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3732830194..20fc3b0323 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -36,7 +36,7 @@ from synapse.api.errors import SynapseError from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.logcontext import run_in_background from synapse.util.logutils import log_function diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 995460f82a..32108568c6 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from ._base import BaseHandler diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0e16bbe0ee..3526b20d5a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -28,7 +28,7 @@ from synapse.api.errors import ( ) from synapse.http.client import CaptchaServerHttpClient from synapse.types import RoomAlias, RoomID, UserID, create_requester -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.threepids 
import check_3pid_allowed from ._base import BaseHandler diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 828229f5c3..37e41afd61 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.types import ThirdPartyInstanceID -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0d4a3f4677..fb94b5d7d4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -30,7 +30,7 @@ import synapse.types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.types import RoomID, UserID -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room logger = logging.getLogger(__name__) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dff1f67dcb..6393a9674b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user from synapse.types import RoomStreamToken -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.response_cache import ResponseCache diff --git a/synapse/http/client.py b/synapse/http/client.py index 3771e0b3f6..ab4fbf59b2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -42,7 +42,7 @@ from twisted.web.http_headers import Headers from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http.endpoint import SpiderEndpoint -from synapse.util.async import add_timeout_to_deferred +from synapse.util.async_helpers import add_timeout_to_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 762273f59b..44b61e70a4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -43,7 +43,7 @@ from synapse.api.errors import ( from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint from synapse.util import logcontext -from synapse.util.async import add_timeout_to_deferred +from synapse.util.async_helpers import add_timeout_to_deferred from synapse.util.logcontext import make_deferred_yieldable logger = logging.getLogger(__name__) diff --git a/synapse/notifier.py b/synapse/notifier.py index e650c3e494..82f391481c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -25,7 +25,7 @@ from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state from synapse.metrics import LaterGauge from synapse.types import StreamToken -from synapse.util.async 
import ( +from synapse.util.async_helpers import ( DeferredTimeoutError, ObservableDeferred, add_timeout_to_deferred, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 1d14d3639c..8f9a76147f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.event_auth import get_user_power_level from synapse.state import POWER_KEY -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import register_cache from synapse.util.caches.descriptors import cached diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 9d601208fd..bfa6df7b68 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -35,7 +35,7 @@ from synapse.push.presentable_names import ( name_from_member_event, ) from synapse.types import UserID -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 00b1b3066e..511e96ab00 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,7 +17,7 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 8fb413d825..4c589e05e0 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -36,7 +36,7 @@ from synapse.api.errors import ( ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import is_ascii, random_string diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 27aa0def2f..778ef97337 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -42,7 +42,7 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_integer, parse_string from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import is_ascii, random_string diff --git a/synapse/state.py b/synapse/state.py index e1092b97a9..8b92d4057a 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -28,7 +28,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext -from 
synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ce32e8fefd..d4aa192a0a 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 10dce21cea..9b4e6d6aa8 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii diff --git a/synapse/util/async.py b/synapse/util/async.py deleted file mode 100644 index a7094e2fb4..0000000000 --- a/synapse/util/async.py +++ /dev/null @@ -1,415 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import collections -import logging -from contextlib import contextmanager - -from six.moves import range - -from twisted.internet import defer -from twisted.internet.defer import CancelledError -from twisted.python import failure - -from synapse.util import Clock, logcontext, unwrapFirstError - -from .logcontext import ( - PreserveLoggingContext, - make_deferred_yieldable, - run_in_background, -) - -logger = logging.getLogger(__name__) - - -class ObservableDeferred(object): - """Wraps a deferred object so that we can add observer deferreds. These - observer deferreds do not affect the callback chain of the original - deferred. - - If consumeErrors is true errors will be captured from the origin deferred. - - Cancelling or otherwise resolving an observer will not affect the original - ObservableDeferred. 
- - NB that it does not attempt to do anything with logcontexts; in general - you should probably make_deferred_yieldable the deferreds - returned by `observe`, and ensure that the original deferred runs its - callbacks in the sentinel logcontext. - """ - - __slots__ = ["_deferred", "_observers", "_result"] - - def __init__(self, deferred, consumeErrors=False): - object.__setattr__(self, "_deferred", deferred) - object.__setattr__(self, "_result", None) - object.__setattr__(self, "_observers", set()) - - def callback(r): - object.__setattr__(self, "_result", (True, r)) - while self._observers: - try: - # TODO: Handle errors here. - self._observers.pop().callback(r) - except Exception: - pass - return r - - def errback(f): - object.__setattr__(self, "_result", (False, f)) - while self._observers: - try: - # TODO: Handle errors here. - self._observers.pop().errback(f) - except Exception: - pass - - if consumeErrors: - return None - else: - return f - - deferred.addCallbacks(callback, errback) - - def observe(self): - """Observe the underlying deferred. - - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. - """ - if not self._result: - d = defer.Deferred() - - def remove(r): - self._observers.discard(d) - return r - d.addBoth(remove) - - self._observers.add(d) - return d - else: - success, res = self._result - return res if success else defer.fail(res) - - def observers(self): - return self._observers - - def has_called(self): - return self._result is not None - - def has_succeeded(self): - return self._result is not None and self._result[0] is True - - def get_result(self): - return self._result[1] - - def __getattr__(self, name): - return getattr(self._deferred, name) - - def __setattr__(self, name, value): - setattr(self._deferred, name, value) - - def __repr__(self): - return "" % ( - id(self), self._result, self._deferred, - ) - - -def concurrently_execute(func, args, limit): - """Executes the function with each argument conncurrently while limiting - the number of concurrent executions. - - Args: - func (func): Function to execute, should return a deferred. - args (list): List of arguments to pass to func, each invocation of func - gets a signle argument. - limit (int): Maximum number of conccurent executions. - - Returns: - deferred: Resolved when all function invocations have finished. - """ - it = iter(args) - - @defer.inlineCallbacks - def _concurrently_execute_inner(): - try: - while True: - yield func(next(it)) - except StopIteration: - pass - - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(_concurrently_execute_inner) - for _ in range(limit) - ], consumeErrors=True)).addErrback(unwrapFirstError) - - -class Linearizer(object): - """Limits concurrent access to resources based on a key. Useful to ensure - only a few things happen at a time on a given resource. - - Example: - - with (yield limiter.queue("test_key")): - # do some work. 
- - """ - def __init__(self, name=None, max_count=1, clock=None): - """ - Args: - max_count(int): The maximum number of concurrent accesses - """ - if name is None: - self.name = id(self) - else: - self.name = name - - if not clock: - from twisted.internet import reactor - clock = Clock(reactor) - self._clock = clock - self.max_count = max_count - - # key_to_defer is a map from the key to a 2 element list where - # the first element is the number of things executing, and - # the second element is an OrderedDict, where the keys are deferreds for the - # things blocked from executing. - self.key_to_defer = {} - - @defer.inlineCallbacks - def queue(self, key): - entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) - - # If the number of things executing is greater than the maximum - # then add a deferred to the list of blocked items - # When on of the things currently executing finishes it will callback - # this item so that it can continue executing. - if entry[0] >= self.max_count: - new_defer = defer.Deferred() - entry[1][new_defer] = 1 - - logger.info( - "Waiting to acquire linearizer lock %r for key %r", self.name, key, - ) - try: - yield make_deferred_yieldable(new_defer) - except Exception as e: - if isinstance(e, CancelledError): - logger.info( - "Cancelling wait for linearizer lock %r for key %r", - self.name, key, - ) - else: - logger.warn( - "Unexpected exception waiting for linearizer lock %r for key %r", - self.name, key, - ) - - # we just have to take ourselves back out of the queue. - del entry[1][new_defer] - raise - - logger.info("Acquired linearizer lock %r for key %r", self.name, key) - entry[0] += 1 - - # if the code holding the lock completes synchronously, then it - # will recursively run the next claimant on the list. That can - # relatively rapidly lead to stack exhaustion. This is essentially - # the same problem as http://twistedmatrix.com/trac/ticket/9304. - # - # In order to break the cycle, we add a cheeky sleep(0) here to - # ensure that we fall back to the reactor between each iteration. - # - # (This needs to happen while we hold the lock, and the context manager's exit - # code must be synchronous, so this is the only sensible place.) - yield self._clock.sleep(0) - - else: - logger.info( - "Acquired uncontended linearizer lock %r for key %r", self.name, key, - ) - entry[0] += 1 - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - logger.info("Releasing linearizer lock %r for key %r", self.name, key) - - # We've finished executing so check if there are any things - # blocked waiting to execute and start one of them - entry[0] -= 1 - - if entry[1]: - (next_def, _) = entry[1].popitem(last=False) - - # we need to run the next thing in the sentinel context. - with PreserveLoggingContext(): - next_def.callback(None) - elif entry[0] == 0: - # We were the last thing for this key: remove it from the - # map. - del self.key_to_defer[key] - - defer.returnValue(_ctx_manager()) - - -class ReadWriteLock(object): - """A deferred style read write lock. - - Example: - - with (yield read_write_lock.read("test_key")): - # do some work - """ - - # IMPLEMENTATION NOTES - # - # We track the most recent queued reader and writer deferreds (which get - # resolved when they release the lock). - # - # Read: We know its safe to acquire a read lock when the latest writer has - # been resolved. The new reader is appeneded to the list of latest readers. 
- # - # Write: We know its safe to acquire the write lock when both the latest - # writers and readers have been resolved. The new writer replaces the latest - # writer. - - def __init__(self): - # Latest readers queued - self.key_to_current_readers = {} - - # Latest writer queued - self.key_to_current_writer = {} - - @defer.inlineCallbacks - def read(self, key): - new_defer = defer.Deferred() - - curr_readers = self.key_to_current_readers.setdefault(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) - - curr_readers.add(new_defer) - - # We wait for the latest writer to finish writing. We can safely ignore - # any existing readers... as they're readers. - yield make_deferred_yieldable(curr_writer) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - new_defer.callback(None) - self.key_to_current_readers.get(key, set()).discard(new_defer) - - defer.returnValue(_ctx_manager()) - - @defer.inlineCallbacks - def write(self, key): - new_defer = defer.Deferred() - - curr_readers = self.key_to_current_readers.get(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) - - # We wait on all latest readers and writer. - to_wait_on = list(curr_readers) - if curr_writer: - to_wait_on.append(curr_writer) - - # We can clear the list of current readers since the new writer waits - # for them to finish. - curr_readers.clear() - self.key_to_current_writer[key] = new_defer - - yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - new_defer.callback(None) - if self.key_to_current_writer[key] == new_defer: - self.key_to_current_writer.pop(key) - - defer.returnValue(_ctx_manager()) - - -class DeferredTimeoutError(Exception): - """ - This error is raised by default when a L{Deferred} times out. - """ - - -def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): - """ - Add a timeout to a deferred by scheduling it to be cancelled after - timeout seconds. - - This is essentially a backport of deferred.addTimeout, which was introduced - in twisted 16.5. - - If the deferred gets timed out, it errbacks with a DeferredTimeoutError, - unless a cancelable function was passed to its initialization or unless - a different on_timeout_cancel callable is provided. - - Args: - deferred (defer.Deferred): deferred to be timed out - timeout (Number): seconds to time out after - reactor (twisted.internet.reactor): the Twisted reactor to use - - on_timeout_cancel (callable): A callable which is called immediately - after the deferred times out, and not if this deferred is - otherwise cancelled before the timeout. - - It takes an arbitrary value, which is the value of the deferred at - that exact point in time (probably a CancelledError Failure), and - the timeout. - - The default callable (if none is provided) will translate a - CancelledError Failure into a DeferredTimeoutError. 
- """ - timed_out = [False] - - def time_it_out(): - timed_out[0] = True - deferred.cancel() - - delayed_call = reactor.callLater(timeout, time_it_out) - - def convert_cancelled(value): - if timed_out[0]: - to_call = on_timeout_cancel or _cancelled_to_timed_out_error - return to_call(value, timeout) - return value - - deferred.addBoth(convert_cancelled) - - def cancel_timeout(result): - # stop the pending call to cancel the deferred if it's been fired - if delayed_call.active(): - delayed_call.cancel() - return result - - deferred.addBoth(cancel_timeout) - - -def _cancelled_to_timed_out_error(value, timeout): - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise DeferredTimeoutError(timeout, "Deferred") - return value diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py new file mode 100644 index 0000000000..a7094e2fb4 --- /dev/null +++ b/synapse/util/async_helpers.py @@ -0,0 +1,415 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import logging +from contextlib import contextmanager + +from six.moves import range + +from twisted.internet import defer +from twisted.internet.defer import CancelledError +from twisted.python import failure + +from synapse.util import Clock, logcontext, unwrapFirstError + +from .logcontext import ( + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, +) + +logger = logging.getLogger(__name__) + + +class ObservableDeferred(object): + """Wraps a deferred object so that we can add observer deferreds. These + observer deferreds do not affect the callback chain of the original + deferred. + + If consumeErrors is true errors will be captured from the origin deferred. + + Cancelling or otherwise resolving an observer will not affect the original + ObservableDeferred. + + NB that it does not attempt to do anything with logcontexts; in general + you should probably make_deferred_yieldable the deferreds + returned by `observe`, and ensure that the original deferred runs its + callbacks in the sentinel logcontext. + """ + + __slots__ = ["_deferred", "_observers", "_result"] + + def __init__(self, deferred, consumeErrors=False): + object.__setattr__(self, "_deferred", deferred) + object.__setattr__(self, "_result", None) + object.__setattr__(self, "_observers", set()) + + def callback(r): + object.__setattr__(self, "_result", (True, r)) + while self._observers: + try: + # TODO: Handle errors here. + self._observers.pop().callback(r) + except Exception: + pass + return r + + def errback(f): + object.__setattr__(self, "_result", (False, f)) + while self._observers: + try: + # TODO: Handle errors here. + self._observers.pop().errback(f) + except Exception: + pass + + if consumeErrors: + return None + else: + return f + + deferred.addCallbacks(callback, errback) + + def observe(self): + """Observe the underlying deferred. 
+ + Can return either a deferred if the underlying deferred is still pending + (or has failed), or the actual value. Callers may need to use maybeDeferred. + """ + if not self._result: + d = defer.Deferred() + + def remove(r): + self._observers.discard(d) + return r + d.addBoth(remove) + + self._observers.add(d) + return d + else: + success, res = self._result + return res if success else defer.fail(res) + + def observers(self): + return self._observers + + def has_called(self): + return self._result is not None + + def has_succeeded(self): + return self._result is not None and self._result[0] is True + + def get_result(self): + return self._result[1] + + def __getattr__(self, name): + return getattr(self._deferred, name) + + def __setattr__(self, name, value): + setattr(self._deferred, name, value) + + def __repr__(self): + return "" % ( + id(self), self._result, self._deferred, + ) + + +def concurrently_execute(func, args, limit): + """Executes the function with each argument conncurrently while limiting + the number of concurrent executions. + + Args: + func (func): Function to execute, should return a deferred. + args (list): List of arguments to pass to func, each invocation of func + gets a signle argument. + limit (int): Maximum number of conccurent executions. + + Returns: + deferred: Resolved when all function invocations have finished. + """ + it = iter(args) + + @defer.inlineCallbacks + def _concurrently_execute_inner(): + try: + while True: + yield func(next(it)) + except StopIteration: + pass + + return logcontext.make_deferred_yieldable(defer.gatherResults([ + run_in_background(_concurrently_execute_inner) + for _ in range(limit) + ], consumeErrors=True)).addErrback(unwrapFirstError) + + +class Linearizer(object): + """Limits concurrent access to resources based on a key. Useful to ensure + only a few things happen at a time on a given resource. + + Example: + + with (yield limiter.queue("test_key")): + # do some work. + + """ + def __init__(self, name=None, max_count=1, clock=None): + """ + Args: + max_count(int): The maximum number of concurrent accesses + """ + if name is None: + self.name = id(self) + else: + self.name = name + + if not clock: + from twisted.internet import reactor + clock = Clock(reactor) + self._clock = clock + self.max_count = max_count + + # key_to_defer is a map from the key to a 2 element list where + # the first element is the number of things executing, and + # the second element is an OrderedDict, where the keys are deferreds for the + # things blocked from executing. + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) + + # If the number of things executing is greater than the maximum + # then add a deferred to the list of blocked items + # When on of the things currently executing finishes it will callback + # this item so that it can continue executing. + if entry[0] >= self.max_count: + new_defer = defer.Deferred() + entry[1][new_defer] = 1 + + logger.info( + "Waiting to acquire linearizer lock %r for key %r", self.name, key, + ) + try: + yield make_deferred_yieldable(new_defer) + except Exception as e: + if isinstance(e, CancelledError): + logger.info( + "Cancelling wait for linearizer lock %r for key %r", + self.name, key, + ) + else: + logger.warn( + "Unexpected exception waiting for linearizer lock %r for key %r", + self.name, key, + ) + + # we just have to take ourselves back out of the queue. 
+ del entry[1][new_defer] + raise + + logger.info("Acquired linearizer lock %r for key %r", self.name, key) + entry[0] += 1 + + # if the code holding the lock completes synchronously, then it + # will recursively run the next claimant on the list. That can + # relatively rapidly lead to stack exhaustion. This is essentially + # the same problem as http://twistedmatrix.com/trac/ticket/9304. + # + # In order to break the cycle, we add a cheeky sleep(0) here to + # ensure that we fall back to the reactor between each iteration. + # + # (This needs to happen while we hold the lock, and the context manager's exit + # code must be synchronous, so this is the only sensible place.) + yield self._clock.sleep(0) + + else: + logger.info( + "Acquired uncontended linearizer lock %r for key %r", self.name, key, + ) + entry[0] += 1 + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + logger.info("Releasing linearizer lock %r for key %r", self.name, key) + + # We've finished executing so check if there are any things + # blocked waiting to execute and start one of them + entry[0] -= 1 + + if entry[1]: + (next_def, _) = entry[1].popitem(last=False) + + # we need to run the next thing in the sentinel context. + with PreserveLoggingContext(): + next_def.callback(None) + elif entry[0] == 0: + # We were the last thing for this key: remove it from the + # map. + del self.key_to_defer[key] + + defer.returnValue(_ctx_manager()) + + +class ReadWriteLock(object): + """A deferred style read write lock. + + Example: + + with (yield read_write_lock.read("test_key")): + # do some work + """ + + # IMPLEMENTATION NOTES + # + # We track the most recent queued reader and writer deferreds (which get + # resolved when they release the lock). + # + # Read: We know its safe to acquire a read lock when the latest writer has + # been resolved. The new reader is appeneded to the list of latest readers. + # + # Write: We know its safe to acquire the write lock when both the latest + # writers and readers have been resolved. The new writer replaces the latest + # writer. + + def __init__(self): + # Latest readers queued + self.key_to_current_readers = {} + + # Latest writer queued + self.key_to_current_writer = {} + + @defer.inlineCallbacks + def read(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + curr_readers.add(new_defer) + + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + yield make_deferred_yieldable(curr_writer) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + self.key_to_current_readers.get(key, set()).discard(new_defer) + + defer.returnValue(_ctx_manager()) + + @defer.inlineCallbacks + def write(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) + + # We can clear the list of current readers since the new writer waits + # for them to finish. 
+ curr_readers.clear() + self.key_to_current_writer[key] = new_defer + + yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + if self.key_to_current_writer[key] == new_defer: + self.key_to_current_writer.pop(key) + + defer.returnValue(_ctx_manager()) + + +class DeferredTimeoutError(Exception): + """ + This error is raised by default when a L{Deferred} times out. + """ + + +def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): + """ + Add a timeout to a deferred by scheduling it to be cancelled after + timeout seconds. + + This is essentially a backport of deferred.addTimeout, which was introduced + in twisted 16.5. + + If the deferred gets timed out, it errbacks with a DeferredTimeoutError, + unless a cancelable function was passed to its initialization or unless + a different on_timeout_cancel callable is provided. + + Args: + deferred (defer.Deferred): deferred to be timed out + timeout (Number): seconds to time out after + reactor (twisted.internet.reactor): the Twisted reactor to use + + on_timeout_cancel (callable): A callable which is called immediately + after the deferred times out, and not if this deferred is + otherwise cancelled before the timeout. + + It takes an arbitrary value, which is the value of the deferred at + that exact point in time (probably a CancelledError Failure), and + the timeout. + + The default callable (if none is provided) will translate a + CancelledError Failure into a DeferredTimeoutError. + """ + timed_out = [False] + + def time_it_out(): + timed_out[0] = True + deferred.cancel() + + delayed_call = reactor.callLater(timeout, time_it_out) + + def convert_cancelled(value): + if timed_out[0]: + to_call = on_timeout_cancel or _cancelled_to_timed_out_error + return to_call(value, timeout) + return value + + deferred.addBoth(convert_cancelled) + + def cancel_timeout(result): + # stop the pending call to cancel the deferred if it's been fired + if delayed_call.active(): + delayed_call.cancel() + return result + + deferred.addBoth(cancel_timeout) + + +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise DeferredTimeoutError(timeout, "Deferred") + return value diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 861c24809c..187510576a 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,7 +25,7 @@ from six import itervalues, string_types from twisted.internet import defer from synapse.util import logcontext, unwrapFirstError -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index a8491b42d5..afb03b2e1b 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -16,7 +16,7 @@ import logging from twisted.internet import defer -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache from synapse.util.logcontext import make_deferred_yieldable, run_in_background diff --git a/synapse/util/caches/snapshot_cache.py 
b/synapse/util/caches/snapshot_cache.py index d03678b8c8..8318db8d2c 100644 --- a/synapse/util/caches/snapshot_cache.py +++ b/synapse/util/caches/snapshot_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred class SnapshotCache(object): diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 8dcae50b39..07e83fadda 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -526,7 +526,7 @@ _to_ignore = [ "synapse.util.logcontext", "synapse.http.server", "synapse.storage._base", - "synapse.util.async", + "synapse.util.async_helpers", ] diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 6d6f00c5c5..52eff2a104 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -18,7 +18,7 @@ from mock import Mock from twisted.internet import defer -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import Cache, cached from tests import unittest diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 4729bd5a0a..7f48a72de8 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -20,7 +20,7 @@ from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError from synapse.util import Clock, logcontext -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from tests import unittest diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 24194e3b25..7cd470be67 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -14,7 +14,7 @@ # limitations under the License. -from synapse.util.async import ReadWriteLock +from synapse.util.async_helpers import ReadWriteLock from tests import unittest -- cgit 1.5.1 From 31fa743567c3f70c66c7af2f1332560b04b78f7a Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Sat, 11 Aug 2018 22:38:34 +0100 Subject: fix sqlite/postgres incompatibility in reap_monthly_active_users --- synapse/storage/monthly_active_users.py | 44 +++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index d47dcef3a0..9a6e699386 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -64,23 +64,27 @@ class MonthlyActiveUsersStore(SQLBaseStore): Deferred[] """ def _reap_users(txn): + # Purge stale users thirty_days_ago = ( int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) ) - # Purge stale users - - # questionmarks is a hack to overcome sqlite not supporting - # tuples in 'WHERE IN %s' - questionmarks = '?' * len(self.reserved_users) query_args = [thirty_days_ago] - query_args.extend(self.reserved_users) - - sql = """ - DELETE FROM monthly_active_users - WHERE timestamp < ? - AND user_id NOT IN ({}) - """.format(','.join(questionmarks)) + base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" + + # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres + # when len(reserved_users) == 0. Works fine on sqlite. + if len(self.reserved_users) > 0: + # questionmarks is a hack to overcome sqlite not supporting + # tuples in 'WHERE IN %s' + questionmarks = '?' 
* len(self.reserved_users) + + query_args.extend(self.reserved_users) + sql = base_sql + """ AND user_id NOT IN ({})""".format( + ','.join(questionmarks) + ) + else: + sql = base_sql txn.execute(sql, query_args) @@ -93,16 +97,24 @@ class MonthlyActiveUsersStore(SQLBaseStore): # negative LIMIT values. So there is no way to write it that both can # support query_args = [self.hs.config.max_mau_value] - query_args.extend(self.reserved_users) - sql = """ + + base_sql = """ DELETE FROM monthly_active_users WHERE user_id NOT IN ( SELECT user_id FROM monthly_active_users ORDER BY timestamp DESC LIMIT ? ) - AND user_id NOT IN ({}) - """.format(','.join(questionmarks)) + """ + # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres + # when len(reserved_users) == 0. Works fine on sqlite. + if len(self.reserved_users) > 0: + query_args.extend(self.reserved_users) + sql = base_sql + """ AND user_id NOT IN ({})""".format( + ','.join(questionmarks) + ) + else: + sql = base_sql txn.execute(sql, query_args) yield self.runInteraction("reap_monthly_active_users", _reap_users) -- cgit 1.5.1 From 99dd975dae7baaaef2a3b0a92fa51965b121ae34 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Mon, 13 Aug 2018 16:47:46 +1000 Subject: Run tests under PostgreSQL (#3423) --- .travis.yml | 10 +++ changelog.d/3423.misc | 1 + synapse/handlers/presence.py | 5 ++ synapse/storage/client_ips.py | 5 ++ tests/__init__.py | 3 + tests/api/test_auth.py | 2 +- tests/api/test_filtering.py | 5 +- tests/crypto/test_keyring.py | 6 +- tests/handlers/test_auth.py | 2 +- tests/handlers/test_device.py | 2 +- tests/handlers/test_directory.py | 1 + tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_profile.py | 1 + tests/handlers/test_register.py | 1 + tests/handlers/test_typing.py | 1 + tests/replication/slave/storage/_base.py | 1 + tests/rest/client/v1/test_admin.py | 2 +- tests/rest/client/v1/test_events.py | 1 + tests/rest/client/v1/test_profile.py | 1 + tests/rest/client/v1/test_register.py | 2 +- tests/rest/client/v1/test_rooms.py | 1 + tests/rest/client/v1/test_typing.py | 1 + tests/rest/client/v2_alpha/test_filter.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 2 +- tests/rest/client/v2_alpha/test_sync.py | 2 +- tests/server.py | 10 ++- tests/storage/test_appservice.py | 13 ++- tests/storage/test_background_update.py | 4 +- tests/storage/test_client_ips.py | 2 +- tests/storage/test_devices.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_end_to_end_keys.py | 3 +- tests/storage/test_event_federation.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_keys.py | 2 +- tests/storage/test_monthly_active_users.py | 2 +- tests/storage/test_presence.py | 2 +- tests/storage/test_profile.py | 2 +- tests/storage/test_redaction.py | 2 +- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 4 +- tests/storage/test_roommember.py | 2 +- tests/storage/test_state.py | 2 +- tests/storage/test_user_directory.py | 2 +- tests/test_federation.py | 5 +- tests/test_server.py | 2 +- tests/test_visibility.py | 2 +- tests/utils.py | 133 ++++++++++++++++++++++++---- tox.ini | 20 ++++- 49 files changed, 227 insertions(+), 59 deletions(-) create mode 100644 changelog.d/3423.misc (limited to 'synapse/storage') diff --git a/.travis.yml b/.travis.yml index b34b17af75..318701c9f8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,6 +8,9 @@ before_script: - git remote set-branches --add origin develop - git fetch origin develop +services: + - postgresql + matrix: fast_finish: true 
include: @@ -20,6 +23,9 @@ matrix: - python: 2.7 env: TOX_ENV=py27 + - python: 2.7 + env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" + - python: 3.6 env: TOX_ENV=py36 @@ -29,6 +35,10 @@ matrix: - python: 3.6 env: TOX_ENV=check-newsfragment + allow_failures: + - python: 2.7 + env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" + install: - pip install tox diff --git a/changelog.d/3423.misc b/changelog.d/3423.misc new file mode 100644 index 0000000000..51768c6d14 --- /dev/null +++ b/changelog.d/3423.misc @@ -0,0 +1 @@ +The test suite now can run under PostgreSQL. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 20fc3b0323..3671d24f60 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -95,6 +95,7 @@ class PresenceHandler(object): Args: hs (synapse.server.HomeServer): """ + self.hs = hs self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.clock = hs.get_clock() @@ -230,6 +231,10 @@ class PresenceHandler(object): earlier than they should when synapse is restarted. This affect of this is some spurious presence changes that will self-correct. """ + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + logger.info( "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 2489527f2c..8fc678fa67 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -96,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) def _update_client_ips_batch(self): + + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + def update(): to_update = self._batch_row_update self._batch_row_update = {} diff --git a/tests/__init__.py b/tests/__init__.py index 24006c949e..9d9ca22829 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -15,4 +15,7 @@ from twisted.trial import util +from tests import utils + util.DEFAULT_TIMEOUT_DURATION = 10 +utils.setupdb() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index f8e28876bb..a65689ba89 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -39,7 +39,7 @@ class AuthTestCase(unittest.TestCase): self.state_handler = Mock() self.store = Mock() - self.hs = yield setup_test_homeserver(handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) self.hs.get_datastore = Mock(return_value=self.store) self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 1c2d71052c..48b2d3d663 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -46,7 +46,10 @@ class FilteringTestCase(unittest.TestCase): self.mock_http_client.put_json = DeferredMockCallable() hs = yield setup_test_homeserver( - handlers=None, http_client=self.mock_http_client, keyring=Mock() + self.addCleanup, + handlers=None, + http_client=self.mock_http_client, + keyring=Mock(), ) self.filtering = hs.get_filtering() diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 0c6f510d11..8299dc72c8 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -58,12 +58,10 @@ class KeyringTestCase(unittest.TestCase): self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() self.hs = yield utils.setup_test_homeserver( - 
handlers=None, http_client=self.http_client + self.addCleanup, handlers=None, http_client=self.http_client ) keys = self.mock_perspective_server.get_verify_keys() - self.hs.config.perspectives = { - self.mock_perspective_server.server_name: keys - } + self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys} def check_context(self, _, expected): self.assertEquals( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index ede01f8099..56c0f87fb7 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -35,7 +35,7 @@ class AuthHandlers(object): class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index d70d645504..56e7acd37c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -34,7 +34,7 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield utils.setup_test_homeserver() + hs = yield utils.setup_test_homeserver(self.addCleanup) self.handler = hs.get_device_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 06de9f5eca..ec7355688b 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -46,6 +46,7 @@ class DirectoryTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 57ab228455..8dccc6826e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - handlers=None, federation_client=mock.Mock() + self.addCleanup, handlers=None, federation_client=mock.Mock() ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 9268a6fe2b..62dc69003c 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -48,6 +48,7 @@ class ProfileTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, handlers=None, resource_for_federation=Mock(), diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index dbec81076f..d48d40c8dd 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -40,6 +40,7 @@ class RegistrationTestCase(unittest.TestCase): self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() self.hs = yield setup_test_homeserver( + self.addCleanup, handlers=None, http_client=None, expire_access_token=True, diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index becfa77bfa..ad58073a14 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -67,6 +67,7 @@ class 
TypingNotificationsTestCase(unittest.TestCase): self.state_handler = Mock() hs = yield setup_test_homeserver( + self.addCleanup, "test", auth=self.auth, clock=self.clock, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index c23b6e2cfd..65df116efc 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -54,6 +54,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver( + self.addCleanup, "blue", http_client=None, federation_client=Mock(), diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py index 67d9ab94e2..1a553fa3f9 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py @@ -51,7 +51,7 @@ class UserRegisterTestCase(unittest.TestCase): self.secrets = Mock() self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.config.registration_shared_secret = u"shared" diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 0316b74fa1..956f7fc4c4 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -41,6 +41,7 @@ class EventStreamPermissionsTestCase(RestTestCase): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 9ba0ffc19f..1eab9c3bdb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase): ) hs = yield setup_test_homeserver( + self.addCleanup, "test", http_client=None, resource_for_client=self.mock_resource, diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 6f15d69ecd..4be88b8a39 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -49,7 +49,7 @@ class CreateUserServletTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_handlers = Mock(return_value=handlers) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 00fc796787..9fe0760496 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -50,6 +50,7 @@ class RoomBase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( + self.addCleanup, "red", http_client=None, clock=self.hs_clock, diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 7f1a435e7b..677265edf6 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -44,6 +44,7 @@ class RoomTypingTestCase(RestTestCase): self.auth_user_id = self.user_id hs = yield setup_test_homeserver( + self.addCleanup, "red", clock=self.clock, http_client=None, diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index de33b10a5f..8260c130f8 100644 --- 
a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -43,7 +43,7 @@ class FilterTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9487babac3..b72bd0fb7f 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -47,7 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase): login_handler=self.login_handler, ) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index bafc0d1df0..2e1d06c509 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -40,7 +40,7 @@ class FilterTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() diff --git a/tests/server.py b/tests/server.py index 05708be8b9..beb24cf032 100644 --- a/tests/server.py +++ b/tests/server.py @@ -147,12 +147,15 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return d -def setup_test_homeserver(*args, **kwargs): +def setup_test_homeserver(cleanup_func, *args, **kwargs): """ Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(*args, **kwargs).result + d = _sth(cleanup_func, *args, **kwargs).result + + if isinstance(d, Failure): + d.raiseException() # Make the thread pool synchronous. 
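
Throughout this commit, setup_test_homeserver gains a cleanup-registration callable as its first argument, and every test passes self.addCleanup so that teardown work (such as dropping a per-test PostgreSQL database, registered further down in tests/utils.py) is tied to the test case. A minimal sketch of the new calling convention, using a hypothetical store test and the usual tests.utils helper:

    from twisted.internet import defer

    from tests import unittest, utils

    class ExampleStoreTestCase(unittest.TestCase):
        @defer.inlineCallbacks
        def setUp(self):
            # self.addCleanup lets the helper register its own teardown
            # (e.g. dropping the per-test database) against this test case.
            hs = yield utils.setup_test_homeserver(self.addCleanup)
            self.store = hs.get_datastore()
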
clock = d.get_clock() @@ -189,6 +192,9 @@ def setup_test_homeserver(*args, **kwargs): def start(self): pass + def stop(self): + pass + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): def _(res): if isinstance(res, Failure): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index fbb25a8844..c893990454 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, federation_sender=Mock(), federation_client=Mock() + self.addCleanup, + config=config, + federation_sender=Mock(), + federation_client=Mock(), ) self.as_token = "token1" @@ -108,7 +111,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, federation_sender=Mock(), federation_client=Mock() + self.addCleanup, + config=config, + federation_sender=Mock(), + federation_client=Mock(), ) self.db_pool = hs.get_db_pool() @@ -392,6 +398,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), @@ -409,6 +416,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), @@ -432,6 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index b4f6baf441..81403727c5 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,7 +9,9 @@ from tests.utils import setup_test_homeserver class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() # type: synapse.server.HomeServer + hs = yield setup_test_homeserver( + self.addCleanup + ) # type: synapse.server.HomeServer self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index ea00bbe84c..fa60d949ba 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -28,7 +28,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield tests.utils.setup_test_homeserver() + self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 63bc42d9e0..aef4dfaf57 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -28,7 +28,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py 
index 9a8ba2fcfe..b4510c1c8d 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.store = DirectoryStore(None, hs) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index d45c775c2d..8f0aaece40 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -26,8 +26,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() - + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 66eb119581..2fdf34fdf6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -22,7 +22,7 @@ import tests.utils class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 5e87b4530d..b114c6fb1d 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -33,7 +33,7 @@ HIGHLIGHT = [ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index ad0a55b324..47f4a8ceac 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -28,7 +28,7 @@ class KeyStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 22b1072d9f..0a2c859f26 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -28,7 +28,7 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.store = self.hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index 12c540dfab..b5b58ff660 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -26,7 +26,7 @@ from tests.utils import MockClock, setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver(clock=MockClock()) + hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock()) self.store = PresenceStore(None, hs) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 5acbc8be0c..a1f6618bf9 100644 --- 
a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.store = ProfileStore(None, hs) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 85ce61e841..c4e9fb72bf 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -29,7 +29,7 @@ class RedactionTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + self.addCleanup, resource_for_federation=Mock(), http_client=None ) self.store = hs.get_datastore() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index bd96896bb3..4eda122edc 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -23,7 +23,7 @@ from tests.utils import setup_test_homeserver class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 84d49b55c1..a1ea23b068 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class RoomStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table @@ -57,7 +57,7 @@ class RoomStoreTestCase(unittest.TestCase): class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = setup_test_homeserver() + hs = setup_test_homeserver(self.addCleanup) # Room events need the full datastore, for persist_event() and # get_room_state() diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 0d9908926a..c83ef60062 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -29,7 +29,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + self.addCleanup, resource_for_federation=Mock(), http_client=None ) # We can't test the RoomMemberStore on its own without the other event # storage logic diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index ed5b41644a..6168c46248 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -33,7 +33,7 @@ class StateStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7a273eab48..b46e0ea7e2 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -29,7 +29,7 @@ BOBBY = "@bobby:a" class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = 
yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.store = UserDirectoryStore(None, self.hs) # alice and bob are both in !room_id. bobby is not but shares diff --git a/tests/test_federation.py b/tests/test_federation.py index f40ff29b52..2540604fcc 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -18,7 +18,10 @@ class MessageAcceptTests(unittest.TestCase): self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, + http_client=self.http_client, + clock=self.hs_clock, + reactor=self.reactor, ) user_id = UserID("us", "test") diff --git a/tests/test_server.py b/tests/test_server.py index fc396226ea..895e490406 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -16,7 +16,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor = MemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor ) def test_handler_for_request(self): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 8643d63125..45a78338d6 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM" class FilterEventsForServerTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index 8668b5478f..90378326f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit import hashlib +import os +import uuid from inspect import getcallargs from mock import Mock, patch @@ -27,23 +30,80 @@ from synapse.http.server import HttpServer from synapse.server import HomeServer from synapse.storage import PostgresEngine from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database +from synapse.storage.prepare_database import ( + _get_or_create_schema_state, + _setup_new_database, + prepare_database, +) from synapse.util.logcontext import LoggingContext from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. 
-# It requires you to have a local postgres database called synapse_test, within -# which ALL TABLES WILL BE DROPPED -USE_POSTGRES_FOR_TESTS = False +USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) +POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") +POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) + + +def setupdb(): + + # If we're using PostgreSQL, set up the db once + if USE_POSTGRES_FOR_TESTS: + pgconfig = { + "name": "psycopg2", + "args": { + "database": POSTGRES_BASE_DB, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + config = Mock() + config.password_providers = [] + config.database_config = pgconfig + db_engine = create_engine(pgconfig) + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + # Set up in the db + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + cur = db_conn.cursor() + _get_or_create_schema_state(cur, db_engine) + _setup_new_database(cur, db_engine) + db_conn.commit() + cur.close() + db_conn.close() + + def _cleanup(): + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + atexit.register(_cleanup) @defer.inlineCallbacks def setup_test_homeserver( - name="test", datastore=None, config=None, reactor=None, **kargs + cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs ): - """Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. If no datastore is supplied a - datastore backed by an in-memory sqlite db will be given to the HS. + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. """ if reactor is None: from twisted.internet import reactor @@ -95,9 +155,11 @@ def setup_test_homeserver( kargs["clock"] = MockClock() if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + config.database_config = { "name": "psycopg2", - "args": {"database": "synapse_test", "cp_min": 1, "cp_max": 5}, + "args": {"database": test_db, "cp_min": 1, "cp_max": 5}, } else: config.database_config = { @@ -107,6 +169,21 @@ def setup_test_homeserver( db_engine = create_engine(config.database_config) + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if datastore is None and isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + # we need to configure the connection pool to run the on_new_connection # function, so that we can test code that uses custom sqlite functions # (like rank). 
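
The per-test databases are not prepared from scratch: setupdb() builds one schema-complete template database per test run, and each homeserver then clones it with CREATE DATABASE ... WITH TEMPLATE, which is much cheaper than replaying the schema scripts for every test. A rough sketch of the clone step, with placeholder names, assuming psycopg2:

    import uuid

    import psycopg2

    base_db = "_synapse_unit_tests_base_1234"         # template built once by setupdb()
    test_db = "synapse_test_%s" % (uuid.uuid4().hex,)

    conn = psycopg2.connect(dbname=base_db, user="postgres")
    conn.autocommit = True                            # CREATE/DROP DATABASE cannot run in a transaction
    cur = conn.cursor()
    cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
    cur.execute("CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, base_db))
    cur.close()
    conn.close()
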
@@ -125,15 +202,35 @@ def setup_test_homeserver( reactor=reactor, **kargs ) - db_conn = hs.get_db_conn() - # make sure that the database is empty - if isinstance(db_engine, PostgresEngine): - cur = db_conn.cursor() - cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") - rows = cur.fetchall() - for r in rows: - cur.execute("DROP TABLE %s CASCADE" % r[0]) - yield prepare_database(db_conn, db_engine, config) + + # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to + # date db + if not isinstance(db_engine, PostgresEngine): + db_conn = hs.get_db_conn() + yield prepare_database(db_conn, db_engine, config) + db_conn.commit() + db_conn.close() + + else: + # We need to do cleanup on PostgreSQL + def cleanup(): + # Close all the db pools + hs.get_db_pool().close() + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + cur.close() + db_conn.close() + + # Register the cleanup hook + cleanup_func(cleanup) + hs.setup() else: hs = HomeServer( diff --git a/tox.ini b/tox.ini index ed26644bd9..085f438989 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = packaging, py27, py36, pep8, check_isort -[testenv] +[base] deps = coverage Twisted>=15.1 @@ -15,6 +15,15 @@ deps = setenv = PYTHONDONTWRITEBYTECODE = no_byte_code +[testenv] +deps = + {[base]deps} + +setenv = + {[base]setenv} + +passenv = * + commands = /usr/bin/find "{toxinidir}" -name '*.pyc' -delete coverage run {env:COVERAGE_OPTS:} --source="{toxinidir}/synapse" \ @@ -46,6 +55,15 @@ commands = # ) usedevelop=true +[testenv:py27-postgres] +usedevelop=true +deps = + {[base]deps} + psycopg2 +setenv = + {[base]setenv} + SYNAPSE_POSTGRES = 1 + [testenv:py36] usedevelop=true commands = -- cgit 1.5.1 From 9ecbaf8ba8b12537862e3045a650fffd438ba8a5 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 14 Aug 2018 16:55:28 +0100 Subject: adding missing yield --- synapse/storage/monthly_active_users.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 7b36accdb6..7e417f811e 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -46,7 +46,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): tp["medium"], tp["address"] ) if user_id: - self.upsert_monthly_active_user(user_id) + yield self.upsert_monthly_active_user(user_id) reserved_user_list.append(user_id) else: logger.warning( -- cgit 1.5.1 From 2f78f432c421702e3756d029b3c669db837d8bcd Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 15 Aug 2018 17:35:22 +0200 Subject: speed up /members and add at= and membership params (#3568) --- changelog.d/3331.feature | 1 + changelog.d/3567.feature | 1 + changelog.d/3568.feature | 1 + synapse/handlers/message.py | 88 ++++++++++++++++++++++++++++++++++++------ synapse/rest/client/v1/room.py | 32 +++++++++++++-- synapse/storage/events.py | 2 +- synapse/storage/state.py | 66 ++++++++++++++++++++++++++++++- tests/storage/test_state.py | 2 +- 8 files changed, 174 insertions(+), 19 deletions(-) create mode 100644 changelog.d/3331.feature create mode 100644 changelog.d/3567.feature create mode 100644 changelog.d/3568.feature (limited to 'synapse/storage') diff --git a/changelog.d/3331.feature b/changelog.d/3331.feature new file mode 100644 index 
0000000000..e574b9bcc3 --- /dev/null +++ b/changelog.d/3331.feature @@ -0,0 +1 @@ +add support for the include_redundant_members filter param as per MSC1227 diff --git a/changelog.d/3567.feature b/changelog.d/3567.feature new file mode 100644 index 0000000000..c74c1f57a9 --- /dev/null +++ b/changelog.d/3567.feature @@ -0,0 +1 @@ +make the /context API filter & lazy-load aware as per MSC1227 diff --git a/changelog.d/3568.feature b/changelog.d/3568.feature new file mode 100644 index 0000000000..247f02ba4e --- /dev/null +++ b/changelog.d/3568.feature @@ -0,0 +1 @@ +speed up /members API and add `at` and `membership` params as per MSC1227 diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 01a362360e..893c9bcdc4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -25,7 +25,13 @@ from twisted.internet import defer from twisted.internet.defer import succeed from synapse.api.constants import MAX_DEPTH, EventTypes, Membership -from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + ConsentNotGivenError, + NotFoundError, + SynapseError, +) from synapse.api.urls import ConsentURIBuilder from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event @@ -36,6 +42,7 @@ from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func +from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -82,28 +89,85 @@ class MessageHandler(object): defer.returnValue(data) @defer.inlineCallbacks - def get_state_events(self, user_id, room_id, is_guest=False): + def get_state_events( + self, user_id, room_id, types=None, filtered_types=None, + at_token=None, is_guest=False, + ): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has - left the room return the state events from when they left. + left the room return the state events from when they left. If an explicit + 'at' parameter is passed, return the state events as of that event, if + visible. Args: user_id(str): The user requesting state events. room_id(str): The room ID to get all state events from. + types(list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + May be None, which matches any key. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. + at_token(StreamToken|None): the stream token of the at which we are requesting + the stats. If the user is not allowed to view the state as of that + stream token, we raise a 403 SynapseError. If None, returns the current + state based on the current_state_events table. + is_guest(bool): whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] + Raises: + NotFoundError (404) if the at token does not yield an event + + AuthError (403) if the user doesn't have permission to view + members of this room. 
""" - membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + if at_token: + # FIXME this claims to get the state at a stream position, but + # get_recent_events_for_room operates by topo ordering. This therefore + # does not reliably give you the state at the given stream position. + # (https://github.com/matrix-org/synapse/issues/3305) + last_events, _ = yield self.store.get_recent_events_for_room( + room_id, end_token=at_token.room_key, limit=1, + ) - if membership == Membership.JOIN: - room_state = yield self.state.get_current_state(room_id) - elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( - [membership_event_id], None + if not last_events: + raise NotFoundError("Can't find event for token %s" % (at_token, )) + + visible_events = yield filter_events_for_client( + self.store, user_id, last_events, + ) + + event = last_events[0] + if visible_events: + room_state = yield self.store.get_state_for_events( + [event.event_id], types, filtered_types=filtered_types, + ) + room_state = room_state[event.event_id] + else: + raise AuthError( + 403, + "User %s not allowed to view events in room %s at token %s" % ( + user_id, room_id, at_token, + ) + ) + else: + membership, membership_event_id = ( + yield self.auth.check_in_room_or_world_readable( + room_id, user_id, + ) ) - room_state = room_state[membership_event_id] + + if membership == Membership.JOIN: + state_ids = yield self.store.get_filtered_current_state_ids( + room_id, types, filtered_types=filtered_types, + ) + room_state = yield self.store.get_events(state_ids.values()) + elif membership == Membership.LEAVE: + room_state = yield self.store.get_state_for_events( + [membership_event_id], types, filtered_types=filtered_types, + ) + room_state = room_state[membership_event_id] now = self.clock.time_msec() defer.returnValue( diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index fa5989e74e..fcc1091760 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -34,7 +34,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.streams.config import PaginationConfig -from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID +from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from .base import ClientV1RestServlet, client_path_patterns @@ -384,15 +384,39 @@ class RoomMemberListRestServlet(ClientV1RestServlet): def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) requester = yield self.auth.get_user_by_req(request) - events = yield self.message_handler.get_state_events( + handler = self.message_handler + + # request the state as of a given event, as identified by a stream token, + # for consistency with /messages etc. + # useful for getting the membership in retrospect as of a given /sync + # response. + at_token_string = parse_string(request, "at") + if at_token_string is None: + at_token = None + else: + at_token = StreamToken.from_string(at_token_string) + + # let you filter down on particular memberships. + # XXX: this may not be the best shape for this API - we could pass in a filter + # instead, except filters aren't currently aware of memberships. + # See https://github.com/matrix-org/matrix-doc/issues/1337 for more details. 
+ membership = parse_string(request, "membership") + not_membership = parse_string(request, "not_membership") + + events = yield handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), + at_token=at_token, + types=[(EventTypes.Member, None)], ) chunk = [] for event in events: - if event["type"] != EventTypes.Member: + if ( + (membership and event['content'].get("membership") != membership) or + (not_membership and event['content'].get("membership") == not_membership) + ): continue chunk.append(event) @@ -401,6 +425,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet): })) +# deprecated in favour of /members?membership=join? +# except it does custom AS logic and has a simpler return format class JoinedRoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/joined_members$") diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 135af54fa9..025a7fb6d9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1911,7 +1911,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore max_depth = max(row[0] for row in rows) if max_depth <= token.topological: - # We need to ensure we don't delete all the events from the datanase + # We need to ensure we don't delete all the events from the database # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) raise SynapseError( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 17b14d464b..754dfa6973 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -116,6 +116,69 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): _get_current_state_ids_txn, ) + # FIXME: how should this be cached? + def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): + """Get the current state event of a given type for a room based on the + current_state_events table. This may not be as up-to-date as the result + of doing a fresh state resolution as per state_handler.get_current_state + Args: + room_id (str) + types (list[(Str, (Str|None))]): List of (type, state_key) tuples + which are used to filter the state fetched. `state_key` may be + None, which matches any `state_key` + filtered_types (list[Str]|None): List of types to apply the above filter to. + Returns: + deferred: dict of (type, state_key) -> event + """ + + include_other_types = False if filtered_types is None else True + + def _get_filtered_current_state_ids_txn(txn): + results = {} + sql = """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? %s""" + # Turns out that postgres doesn't like doing a list of OR's and + # is about 1000x slower, so we just issue a query for each specific + # type seperately. + if types: + clause_to_args = [ + ( + "AND type = ? AND state_key = ?", + (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) + ) + for etype, state_key in types + ] + + if include_other_types: + unique_types = set(filtered_types) + clause_to_args.append( + ( + "AND type <> ? " * len(unique_types), + list(unique_types) + ) + ) + else: + # If types is None we fetch all the state, and so just use an + # empty where clause with no extra args. 
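
To make the clause building above concrete, here is what clause_to_args would hold for a hypothetical call with types=[("m.room.name", ""), ("m.room.member", None)] and filtered_types=["m.room.member"], i.e. "the room name, every member event, and all other types unfiltered":

    # illustrative only; mirrors the branches above for this particular input
    clause_to_args = [
        ("AND type = ? AND state_key = ?", ("m.room.name", "")),  # exact (type, state_key)
        ("AND type = ?", ("m.room.member",)),                     # any state_key for this type
        ("AND type <> ? ", ["m.room.member"]),                    # every type not in filtered_types
    ]
    # the loop below then runs the base SELECT once per clause and merges
    # the rows into a single (type, state_key) -> event_id dict
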
+ clause_to_args = [("", [])] + for where_clause, where_args in clause_to_args: + args = [room_id] + args.extend(where_args) + txn.execute(sql % (where_clause,), args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + return results + + return self.runInteraction( + "get_filtered_current_state_ids", + _get_filtered_current_state_ids_txn, + ) + @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between @@ -389,8 +452,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - deferred: A list of dicts corresponding to the event_ids given. - The dicts are mappings from (type, state_key) -> state_events + deferred: A dict of (event_id) -> (type, state_key) -> [state_events] """ event_to_groups = yield self._get_state_group_for_events( event_ids, diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 6168c46248..ebfd969b36 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -148,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase): {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state ) - # check we can use filter_types to grab a specific room member + # check we can use filtered_types to grab a specific room member # without filtering out the other event types state = yield self.store.get_state_for_event( e5.event_id, -- cgit 1.5.1 From 3f543dc021dd9456c8ed2da7a1e4769a68c07729 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 16 Aug 2018 10:46:50 +0200 Subject: initial cut at a room summary API (#3574) --- changelog.d/3574.feature | 1 + synapse/handlers/sync.py | 159 ++++++++++++++++++++++++++++++++--- synapse/rest/client/v2_alpha/sync.py | 1 + synapse/storage/_base.py | 7 +- synapse/storage/state.py | 5 +- synapse/storage/stream.py | 4 +- 6 files changed, 159 insertions(+), 18 deletions(-) create mode 100644 changelog.d/3574.feature (limited to 'synapse/storage') diff --git a/changelog.d/3574.feature b/changelog.d/3574.feature new file mode 100644 index 0000000000..87ac32df72 --- /dev/null +++ b/changelog.d/3574.feature @@ -0,0 +1 @@ +implement `summary` block in /sync response as per MSC688 diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3b21a04a5d..ac3edf0cc9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -75,6 +75,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ "ephemeral", "account_data", "unread_notifications", + "summary", ])): __slots__ = [] @@ -503,10 +504,142 @@ class SyncHandler(object): state = {} defer.returnValue(state) + @defer.inlineCallbacks + def compute_summary(self, room_id, sync_config, batch, state, now_token): + """ Works out a room summary block for this room, summarising the number + of joined members in the room, and providing the 'hero' members if the + room has no name so clients can consistently name rooms. Also adds + state events to 'state' if needed to describe the heroes. + + Args: + room_id(str): + sync_config(synapse.handlers.sync.SyncConfig): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + state(dict): dict of (type, state_key) -> Event as returned by + compute_state_delta + now_token(str): Token of the end of the current batch. 
+ + Returns: + A deferred dict describing the room summary + """ + + # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 + last_events, _ = yield self.store.get_recent_event_ids_for_room( + room_id, end_token=now_token.room_key, limit=1, + ) + + if not last_events: + defer.returnValue(None) + return + + last_event = last_events[-1] + state_ids = yield self.store.get_state_ids_for_event( + last_event.event_id, [ + (EventTypes.Member, None), + (EventTypes.Name, ''), + (EventTypes.CanonicalAlias, ''), + ] + ) + + member_ids = { + state_key: event_id + for (t, state_key), event_id in state_ids.iteritems() + if t == EventTypes.Member + } + name_id = state_ids.get((EventTypes.Name, '')) + canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) + + summary = {} + + # FIXME: it feels very heavy to load up every single membership event + # just to calculate the counts. + member_events = yield self.store.get_events(member_ids.values()) + + joined_user_ids = [] + invited_user_ids = [] + + for ev in member_events.values(): + if ev.content.get("membership") == Membership.JOIN: + joined_user_ids.append(ev.state_key) + elif ev.content.get("membership") == Membership.INVITE: + invited_user_ids.append(ev.state_key) + + # TODO: only send these when they change. + summary["m.joined_member_count"] = len(joined_user_ids) + summary["m.invited_member_count"] = len(invited_user_ids) + + if name_id or canonical_alias_id: + defer.returnValue(summary) + + # FIXME: order by stream ordering, not alphabetic + + me = sync_config.user.to_string() + if (joined_user_ids or invited_user_ids): + summary['m.heroes'] = sorted( + [ + user_id + for user_id in (joined_user_ids + invited_user_ids) + if user_id != me + ] + )[0:5] + else: + summary['m.heroes'] = sorted( + [user_id for user_id in member_ids.keys() if user_id != me] + )[0:5] + + if not sync_config.filter_collection.lazy_load_members(): + defer.returnValue(summary) + + # ensure we send membership events for heroes if needed + cache_key = (sync_config.user.to_string(), sync_config.device_id) + cache = self.get_lazy_loaded_members_cache(cache_key) + + # track which members the client should already know about via LL: + # Ones which are already in state... + existing_members = set( + user_id for (typ, user_id) in state.keys() + if typ == EventTypes.Member + ) + + # ...or ones which are in the timeline... + for ev in batch.events: + if ev.type == EventTypes.Member: + existing_members.add(ev.state_key) + + # ...and then ensure any missing ones get included in state. 
+ missing_hero_event_ids = [ + member_ids[hero_id] + for hero_id in summary['m.heroes'] + if ( + cache.get(hero_id) != member_ids[hero_id] and + hero_id not in existing_members + ) + ] + + missing_hero_state = yield self.store.get_events(missing_hero_event_ids) + missing_hero_state = missing_hero_state.values() + + for s in missing_hero_state: + cache.set(s.state_key, s.event_id) + state[(EventTypes.Member, s.state_key)] = s + + defer.returnValue(summary) + + def get_lazy_loaded_members_cache(self, cache_key): + cache = self.lazy_loaded_members_cache.get(cache_key) + if cache is None: + logger.debug("creating LruCache for %r", cache_key) + cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) + self.lazy_loaded_members_cache[cache_key] = cache + else: + logger.debug("found LruCache for %r", cache_key) + return cache + @defer.inlineCallbacks def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token, full_state): - """ Works out the differnce in state between the start of the timeline + """ Works out the difference in state between the start of the timeline and the previous sync. Args: @@ -520,7 +653,7 @@ class SyncHandler(object): full_state(bool): Whether to force returning the full state. Returns: - A deferred new event dictionary + A deferred dict of (type, state_key) -> Event """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -618,13 +751,7 @@ class SyncHandler(object): if lazy_load_members and not include_redundant_members: cache_key = (sync_config.user.to_string(), sync_config.device_id) - cache = self.lazy_loaded_members_cache.get(cache_key) - if cache is None: - logger.debug("creating LruCache for %r", cache_key) - cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) - self.lazy_loaded_members_cache[cache_key] = cache - else: - logger.debug("found LruCache for %r", cache_key) + cache = self.get_lazy_loaded_members_cache(cache_key) # if it's a new sync sequence, then assume the client has had # amnesia and doesn't want any recent lazy-loaded members @@ -1425,7 +1552,6 @@ class SyncHandler(object): if events == [] and tags is None: return - since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config @@ -1468,6 +1594,18 @@ class SyncHandler(object): full_state=full_state ) + summary = {} + if ( + sync_config.filter_collection.lazy_load_members() and + ( + any(ev.type == EventTypes.Member for ev in batch.events) or + since_token is None + ) + ): + summary = yield self.compute_summary( + room_id, sync_config, batch, state, now_token + ) + if room_builder.rtype == "joined": unread_notifications = {} room_sync = JoinedSyncResult( @@ -1477,6 +1615,7 @@ class SyncHandler(object): ephemeral=ephemeral, account_data=account_data_events, unread_notifications=unread_notifications, + summary=summary, ) if room_sync or always_include: diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 8aa06faf23..1275baa1ba 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -370,6 +370,7 @@ class SyncRestServlet(RestServlet): ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + result["summary"] = room.summary return result diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 44f37b4c1e..08dffd774f 100644 --- a/synapse/storage/_base.py 
+++ b/synapse/storage/_base.py
@@ -1150,17 +1150,16 @@ class SQLBaseStore(object):
         defer.returnValue(retval)
 
     def get_user_count_txn(self, txn):
-        """Get a total number of registerd users in the users list.
+        """Get a total number of registered users in the users list.
 
         Args:
             txn : Transaction object
         Returns:
-            defer.Deferred: resolves to int
+            int : number of users
         """
         sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
         txn.execute(sql_count)
-        count = txn.fetchone()[0]
-        defer.returnValue(count)
+        return txn.fetchone()[0]
 
     def _simple_search_list(self, table, term, col, retcols,
                             desc="_simple_search_list"):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 754dfa6973..dd03c4168b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -480,7 +480,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     @defer.inlineCallbacks
     def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
         """
-        Get the state dicts corresponding to a list of events
+        Get the state dicts corresponding to a list of events, containing the event_ids
+        of the state events (as opposed to the events themselves)
 
         Args:
             event_ids(list(str)): events whose state should be returned
@@ -493,7 +494,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 If None, `types` filtering is applied to all events.
 
         Returns:
-            A deferred dict from event_id -> (type, state_key) -> state_event
+            A deferred dict from event_id -> (type, state_key) -> event_id
         """
         event_to_groups = yield self._get_state_group_for_events(
             event_ids,
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index b9f2b74ac6..4c296d72c0 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -348,7 +348,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             end_token (str): The stream token representing now.
 
         Returns:
-            Deferred[tuple[list[FrozenEvent], str]]: Returns a list of 
+            Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
             events and a token pointing to the start of the returned
             events.
             The events returned are in ascending order.
@@ -379,7 +379,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             end_token (str): The stream token representing now.
 
         Returns:
-            Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of 
+            Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
            _EventDictReturn and a token pointing to the start of the
            returned events.
            The events returned are in ascending order.
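
The compute_summary changes above boil down to three fields per room: a joined count, an invited count, and up to five "heroes" (other users a client can fall back on to title a room that has no name or canonical alias). As a minimal standalone sketch of that selection rule — build_room_summary is a hypothetical helper for illustration, not Synapse's actual API, and it assumes membership is already available as plain (user_id, membership) pairs:

    def build_room_summary(memberships, me, has_name_or_alias=False):
        """Illustrative only: mirrors the counting and hero-selection rule."""
        joined = [uid for uid, membership in memberships if membership == "join"]
        invited = [uid for uid, membership in memberships if membership == "invite"]
        summary = {
            "m.joined_member_count": len(joined),
            "m.invited_member_count": len(invited),
        }
        if has_name_or_alias:
            # Rooms with an explicit name or canonical alias don't need heroes.
            return summary
        # Alphabetical for now; the FIXME in the patch notes it should
        # eventually follow stream ordering instead.
        summary["m.heroes"] = sorted(uid for uid in joined + invited if uid != me)[:5]
        return summary

    # build_room_summary([("@a:hs", "join"), ("@b:hs", "invite")], me="@a:hs")
    # -> {'m.joined_member_count': 1, 'm.invited_member_count': 1, 'm.heroes': ['@b:hs']}

The lazy-loading half of the patch then only has to ship membership events for heroes the client has not already been sent, which is what the per-(user, device) LruCache introduced above keeps track of.
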
-- cgit 1.5.1


From bcfeb44afe750dadd4199e7c02302b0157bdab11 Mon Sep 17 00:00:00 2001
From: Neil Johnson
Date: Thu, 16 Aug 2018 22:55:32 +0100
Subject: call reap on start up and fix under reaping bug

---
 synapse/api/auth.py                        |  2 +-
 synapse/app/homeserver.py                  |  1 +
 synapse/storage/monthly_active_users.py    |  5 ++++-
 tests/storage/test_monthly_active_users.py | 13 +++++++++++++
 4 files changed, 19 insertions(+), 2 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 3b2a2ab77a..ab1e3a4e35 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -799,7 +799,7 @@ class Auth(object):
             current_mau = yield self.store.get_monthly_active_count()
             if current_mau >= self.hs.config.max_mau_value:
                 raise AuthError(
-                    403, "Monthly Active User Limits AU Limit Exceeded",
+                    403, "Monthly Active User Limit Exceeded",
                     admin_uri=self.hs.config.admin_uri,
                     errcode=Codes.RESOURCE_LIMIT_EXCEED
                 )
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index a98bb506e5..800b9c0e3c 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -525,6 +525,7 @@ def run(hs):
         clock.looping_call(
             hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60
         )
+        yield hs.get_datastore().reap_monthly_active_users()
 
     @defer.inlineCallbacks
     def generate_monthly_active_users():
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 7e417f811e..06f9a75a97 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -96,7 +96,10 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             # While Postgres does not require 'LIMIT', but also does not support
             # negative LIMIT values. So there is no way to write it that both can
             # support
-            query_args = [self.hs.config.max_mau_value]
+            safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
+            # Must be greater than zero for postgres
+            safe_guard = safe_guard if safe_guard > 0 else 0
+            query_args = [safe_guard]
 
             base_sql = """
                 DELETE FROM monthly_active_users
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 511acbde9b..f2ed866ae7 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -75,6 +75,19 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
         active_count = yield self.store.get_monthly_active_count()
         self.assertEquals(active_count, user_num)
 
+        # Test that regular users are removed from the db
+        ru_count = 2
+        yield self.store.upsert_monthly_active_user("@ru1:server")
+        yield self.store.upsert_monthly_active_user("@ru2:server")
+        active_count = yield self.store.get_monthly_active_count()
+
+        self.assertEqual(active_count, user_num + ru_count)
+        self.hs.config.max_mau_value = user_num
+        yield self.store.reap_monthly_active_users()
+
+        active_count = yield self.store.get_monthly_active_count()
+        self.assertEquals(active_count, user_num)
+
     @defer.inlineCallbacks
     def test_can_insert_and_count_mau(self):
         count = yield self.store.get_monthly_active_count()
-- cgit 1.5.1
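
The under-reaping fix in monthly_active_users.py is, at heart, one clamped subtraction: reserved users are kept unconditionally, so the number of rows the DELETE may retain is the configured cap minus the reserved count, floored at zero because the value feeds a LIMIT that Postgres rejects when negative. A small sketch of just that arithmetic (the helper name and asserts are illustrative, not part of the store's real interface):

    def mau_rows_to_keep(max_mau_value, reserved_user_count):
        # Reserved users never count against the cap, and the result is bound
        # into a LIMIT clause, so it must not drop below zero.
        return max(max_mau_value - reserved_user_count, 0)

    assert mau_rows_to_keep(50, 3) == 47  # cap of 50 with 3 reserved users
    assert mau_rows_to_keep(2, 5) == 0    # over-reserved: keep only the reserved rows

The homeserver.py hunk pairs this with an immediate reap at startup, so a server that drifted over the cap while down is trimmed as soon as it restarts instead of waiting up to an hour for the next run of the looping call.
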