From 6ef983ce5cc0a1cd7323ac82c8eed41d72ff3a99 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 31 Jul 2018 16:36:24 +0100 Subject: api into monthly_active_users table --- synapse/storage/monthly_active_users.py | 89 ++++++++++++++++++++++ synapse/storage/prepare_database.py | 2 +- .../schema/delta/50/make_event_content_nullable.py | 2 +- .../schema/delta/51/monthly_active_users.sql | 23 ++++++ 4 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 synapse/storage/monthly_active_users.py create mode 100644 synapse/storage/schema/delta/51/monthly_active_users.sql (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py new file mode 100644 index 0000000000..373e828c0a --- /dev/null +++ b/synapse/storage/monthly_active_users.py @@ -0,0 +1,89 @@ +from twisted.internet import defer + +from ._base import SQLBaseStore + + +class MonthlyActiveUsersStore(SQLBaseStore): + def __init__(self, hs): + super(MonthlyActiveUsersStore, self).__init__(None, hs) + self._clock = hs.get_clock() + + def reap_monthly_active_users(self): + """ + Cleans out monthly active user table to ensure that no stale + entries exist. + Return: + defered, no return type + """ + def _reap_users(txn): + thirty_days_ago = ( + int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + ) + sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" + txn.execute(sql, (thirty_days_ago,)) + + return self.runInteraction("reap_monthly_active_users", _reap_users) + + def get_monthly_active_count(self): + """ + Generates current count of monthly active users.abs + return: + defered resolves to int + """ + def _count_users(txn): + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM monthly_active_users + ) u + """ + txn.execute(sql) + count, = txn.fetchone() + return count + return self.runInteraction("count_users", _count_users) + + def upsert_monthly_active_user(self, user_id): + """ + Updates or inserts monthly active user member + Arguments: + user_id (str): user to add/update + """ + return self._simple_upsert( + desc="upsert_monthly_active_user", + table="monthly_active_users", + keyvalues={ + "user_id": user_id, + }, + values={ + "timestamp": int(self._clock.time_msec()), + }, + lock=False, + ) + + def clean_out_monthly_active_users(self): + pass + + @defer.inlineCallbacks + def is_user_monthly_active(self, user_id): + """ + Checks if a given user is part of the monthly active user group + Arguments: + user_id (str): user to add/update + Return: + bool : True if user part of group, False otherwise + """ + user_present = yield self._simple_select_onecol( + table="monthly_active_users", + keyvalues={ + "user_id": user_id, + }, + retcol="user_id", + desc="is_user_monthly_active", + ) + # jeff = self.cursor_to_dict(res) + result = False + if user_present: + result = True + + defer.returnValue( + result + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b290f834b3..b364719312 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 50 +SCHEMA_VERSION = 51 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py index 7d27342e39..6dd467b6c5 100644 --- a/synapse/storage/schema/delta/50/make_event_content_nullable.py +++ b/synapse/storage/schema/delta/50/make_event_content_nullable.py @@ -88,5 +88,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs): "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", (sql, ), ) - cur.execute("PRAGMA schema_version=%i" % (oldver+1,)) + cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql new file mode 100644 index 0000000000..b3c0e1a676 --- /dev/null +++ b/synapse/storage/schema/delta/51/monthly_active_users.sql @@ -0,0 +1,23 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a table of users who have requested that their details be erased +CREATE TABLE monthly_active_users ( + user_id TEXT NOT NULL, + timestamp BIGINT NOT NULL +); + +CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); -- cgit 1.5.1 From f9f55599718ba9849b2dee18e253dcaf8f5fcfe7 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 1 Aug 2018 12:03:42 +0100 Subject: fix comment --- synapse/storage/schema/delta/51/monthly_active_users.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql index b3c0e1a676..f2b6d3e31e 100644 --- a/synapse/storage/schema/delta/51/monthly_active_users.sql +++ b/synapse/storage/schema/delta/51/monthly_active_users.sql @@ -13,7 +13,7 @@ * limitations under the License. */ --- a table of users who have requested that their details be erased +-- a table of monthly active users, for use where blocking based on mau limits CREATE TABLE monthly_active_users ( user_id TEXT NOT NULL, timestamp BIGINT NOT NULL -- cgit 1.5.1 From 4e5ac901ddd15e8938605c5ac4312be804997ed5 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 1 Aug 2018 12:03:57 +0100 Subject: clean up --- synapse/storage/monthly_active_users.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 373e828c0a..03eeea792e 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -31,11 +31,8 @@ class MonthlyActiveUsersStore(SQLBaseStore): defered resolves to int """ def _count_users(txn): - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM monthly_active_users - ) u - """ + sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + txn.execute(sql) count, = txn.fetchone() return count @@ -59,9 +56,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): lock=False, ) - def clean_out_monthly_active_users(self): - pass - @defer.inlineCallbacks def is_user_monthly_active(self, user_id): """ @@ -79,11 +73,5 @@ class MonthlyActiveUsersStore(SQLBaseStore): retcol="user_id", desc="is_user_monthly_active", ) - # jeff = self.cursor_to_dict(res) - result = False - if user_present: - result = True - defer.returnValue( - result - ) + defer.returnValue(bool(user_present)) -- cgit 1.5.1 From ec716a35b219d147dee51733b55573952799a549 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 1 Aug 2018 17:54:37 +0100 Subject: change monthly_active_users table to be a single column --- synapse/storage/monthly_active_users.py | 10 +++------- synapse/storage/schema/delta/51/monthly_active_users.sql | 4 +--- tests/storage/test_monthly_active_users.py | 6 +++--- 3 files changed, 7 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 03eeea792e..7b3f13aedf 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -38,22 +38,18 @@ class MonthlyActiveUsersStore(SQLBaseStore): return count return self.runInteraction("count_users", _count_users) - def upsert_monthly_active_user(self, user_id): + def insert_monthly_active_user(self, user_id): """ Updates or inserts monthly active user member Arguments: user_id (str): user to add/update """ - return self._simple_upsert( + return self._simple_insert( desc="upsert_monthly_active_user", table="monthly_active_users", - keyvalues={ - "user_id": user_id, - }, values={ - "timestamp": int(self._clock.time_msec()), + "user_id": user_id, }, - lock=False, ) @defer.inlineCallbacks diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql index f2b6d3e31e..590b1abab2 100644 --- a/synapse/storage/schema/delta/51/monthly_active_users.sql +++ b/synapse/storage/schema/delta/51/monthly_active_users.sql @@ -15,9 +15,7 @@ -- a table of monthly active users, for use where blocking based on mau limits CREATE TABLE monthly_active_users ( - user_id TEXT NOT NULL, - timestamp BIGINT NOT NULL + user_id TEXT NOT NULL ); CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); -CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 9b1ffc6369..7a8432ce69 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -22,7 +22,7 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): count = yield self.mau.get_monthly_active_count() self.assertEqual(0, count) - yield self.mau.upsert_monthly_active_user("@user:server") + yield self.mau.insert_monthly_active_user("@user:server") count = yield self.mau.get_monthly_active_count() self.assertEqual(1, count) @@ -34,8 +34,8 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): user_id3 = "@user3:server" result = yield self.mau.is_user_monthly_active(user_id1) self.assertFalse(result) - yield self.mau.upsert_monthly_active_user(user_id1) - yield self.mau.upsert_monthly_active_user(user_id2) + yield self.mau.insert_monthly_active_user(user_id1) + yield self.mau.insert_monthly_active_user(user_id2) result = yield self.mau.is_user_monthly_active(user_id1) self.assertTrue(result) result = yield self.mau.is_user_monthly_active(user_id3) -- cgit 1.5.1 From c21d82bab3b32e2f59c9ef09a1a10b337ce45db4 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 1 Aug 2018 23:24:38 +0100 Subject: normalise reaping query --- synapse/storage/monthly_active_users.py | 41 ++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 7b3f13aedf..0741c7fa61 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -7,6 +7,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): def __init__(self, hs): super(MonthlyActiveUsersStore, self).__init__(None, hs) self._clock = hs.get_clock() + self.max_mau_value = hs.config.max_mau_value def reap_monthly_active_users(self): """ @@ -19,8 +20,42 @@ class MonthlyActiveUsersStore(SQLBaseStore): thirty_days_ago = ( int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) ) - sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" - txn.execute(sql, (thirty_days_ago,)) + + # Query deletes the union of users that have either: + # * not visited in the last 30 days + # * exceeded the total max_mau_value threshold. Where there is + # an excess, more recent users are favoured - this is to cover + # the case where the limit has been step change reduced. + # + sql = """ + DELETE FROM monthly_active_users + WHERE user_id + IN ( + SELECT * FROM ( + SELECT monthly_active_users.user_id + FROM monthly_active_users + LEFT JOIN ( + SELECT user_id, max(last_seen) AS last_seen + FROM user_ips + GROUP BY user_id + ) AS uip ON uip.user_id=monthly_active_users.user_id + ORDER BY uip.last_seen desc LIMIT -1 OFFSET ? + ) + UNION + SELECT * FROM ( + SELECT monthly_active_users.user_id + FROM monthly_active_users + LEFT JOIN ( + SELECT user_id, max(last_seen) AS last_seen + FROM user_ips + GROUP BY user_id + ) AS uip ON uip.user_id=monthly_active_users.user_id + WHERE uip.last_seen < ? + ) + ) + """ + + txn.execute(sql, (self.max_mau_value, thirty_days_ago,)) return self.runInteraction("reap_monthly_active_users", _reap_users) @@ -45,7 +80,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): user_id (str): user to add/update """ return self._simple_insert( - desc="upsert_monthly_active_user", + desc="insert_monthly_active_user", table="monthly_active_users", values={ "user_id": user_id, -- cgit 1.5.1 From 08281fe6b7cdd9620d28d2aa6b725bdb0a351bbc Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 1 Aug 2018 23:26:24 +0100 Subject: self.db_conn unused --- synapse/storage/__init__.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 134e4a80f1..04ff1f8006 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -94,7 +94,6 @@ class DataStore(RoomMemberStore, RoomStore, self._clock = hs.get_clock() self.database_engine = hs.database_engine - self.db_conn = db_conn self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", extra_tables=[("local_invites", "stream_id")] -- cgit 1.5.1 From 165e06703340667dd88eb73dfe299a4d3298b94b Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 2 Aug 2018 10:59:58 +0100 Subject: Revert "change monthly_active_users table to be a single column" This reverts commit ec716a35b219d147dee51733b55573952799a549. --- synapse/storage/monthly_active_users.py | 10 +++++++--- synapse/storage/schema/delta/51/monthly_active_users.sql | 4 +++- tests/storage/test_monthly_active_users.py | 6 +++--- 3 files changed, 13 insertions(+), 7 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 7b3f13aedf..03eeea792e 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -38,18 +38,22 @@ class MonthlyActiveUsersStore(SQLBaseStore): return count return self.runInteraction("count_users", _count_users) - def insert_monthly_active_user(self, user_id): + def upsert_monthly_active_user(self, user_id): """ Updates or inserts monthly active user member Arguments: user_id (str): user to add/update """ - return self._simple_insert( + return self._simple_upsert( desc="upsert_monthly_active_user", table="monthly_active_users", - values={ + keyvalues={ "user_id": user_id, }, + values={ + "timestamp": int(self._clock.time_msec()), + }, + lock=False, ) @defer.inlineCallbacks diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql index 590b1abab2..f2b6d3e31e 100644 --- a/synapse/storage/schema/delta/51/monthly_active_users.sql +++ b/synapse/storage/schema/delta/51/monthly_active_users.sql @@ -15,7 +15,9 @@ -- a table of monthly active users, for use where blocking based on mau limits CREATE TABLE monthly_active_users ( - user_id TEXT NOT NULL + user_id TEXT NOT NULL, + timestamp BIGINT NOT NULL ); CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id); +CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp); diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 7a8432ce69..9b1ffc6369 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -22,7 +22,7 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): count = yield self.mau.get_monthly_active_count() self.assertEqual(0, count) - yield self.mau.insert_monthly_active_user("@user:server") + yield self.mau.upsert_monthly_active_user("@user:server") count = yield self.mau.get_monthly_active_count() self.assertEqual(1, count) @@ -34,8 +34,8 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): user_id3 = "@user3:server" result = yield self.mau.is_user_monthly_active(user_id1) self.assertFalse(result) - yield self.mau.insert_monthly_active_user(user_id1) - yield self.mau.insert_monthly_active_user(user_id2) + yield self.mau.upsert_monthly_active_user(user_id1) + yield self.mau.upsert_monthly_active_user(user_id2) result = yield self.mau.is_user_monthly_active(user_id1) self.assertTrue(result) result = yield self.mau.is_user_monthly_active(user_id3) -- cgit 1.5.1 From 00f99f74b1b875bb7ac6b0623994cabad4a59cc6 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 2 Aug 2018 13:47:19 +0100 Subject: insertion into monthly_active_users --- synapse/api/auth.py | 2 +- synapse/storage/__init__.py | 2 + synapse/storage/client_ips.py | 22 ++++++++++- synapse/storage/monthly_active_users.py | 18 ++++++--- tests/storage/test_client_ips.py | 66 +++++++++++++++++++++++++++++++-- 5 files changed, 99 insertions(+), 11 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 5bbbe8e2e7..d8022bcf8e 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -213,7 +213,7 @@ class Auth(object): default=[b""] )[0] if user and access_token and ip_addr: - self.store.insert_client_ip( + yield self.store.insert_client_ip( user_id=user.to_string(), access_token=access_token, ip=ip_addr, diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 04ff1f8006..2120f46ed5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -39,6 +39,7 @@ from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore from .media_repository import MediaRepositoryStore +from .monthly_active_users import MonthlyActiveUsersStore from .openid import OpenIdStore from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore @@ -87,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore, UserDirectoryStore, GroupServerStore, UserErasureStore, + MonthlyActiveUsersStore, ): def __init__(self, db_conn, hs): diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index b8cefd43d6..506915a1ef 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -35,6 +35,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpStore(background_updates.BackgroundUpdateStore): def __init__(self, db_conn, hs): + self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, @@ -74,6 +75,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "before", "shutdown", self._update_client_ips_batch ) + @defer.inlineCallbacks def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, now=None): if not now: @@ -84,7 +86,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - + yield self._populate_monthly_active_users(user_id) # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return @@ -93,6 +95,24 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) + @defer.inlineCallbacks + def _populate_monthly_active_users(self, user_id): + store = self.hs.get_datastore() + print "entering _populate_monthly_active_users" + if self.hs.config.limit_usage_by_mau: + print "self.hs.config.limit_usage_by_mau is TRUE" + is_user_monthly_active = yield store.is_user_monthly_active(user_id) + print "is_user_monthly_active is %r" % is_user_monthly_active + if is_user_monthly_active: + yield store.upsert_monthly_active_user(user_id) + else: + count = yield store.get_monthly_active_count() + print "count is %d" % count + if count < self.hs.config.max_mau_value: + print "count is less than self.hs.config.max_mau_value " + res = yield store.upsert_monthly_active_user(user_id) + print "upsert response is %r" % res + def _update_client_ips_batch(self): def update(): to_update = self._batch_row_update diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 2337438c58..aada9bd2b9 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -4,7 +4,7 @@ from ._base import SQLBaseStore class MonthlyActiveUsersStore(SQLBaseStore): - def __init__(self, hs): + def __init__(self, dbconn, hs): super(MonthlyActiveUsersStore, self).__init__(None, hs) self._clock = hs.get_clock() self.max_mau_value = hs.config.max_mau_value @@ -14,24 +14,28 @@ class MonthlyActiveUsersStore(SQLBaseStore): Cleans out monthly active user table to ensure that no stale entries exist. Return: - defered, no return type + Defered() """ def _reap_users(txn): thirty_days_ago = ( int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) ) - sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" - txn.execute(sql, (thirty_days_ago,)) + sql = """ + DELETE FROM monthly_active_users + ORDER BY timestamp desc + LIMIT -1 OFFSET ? + """ + txn.execute(sql, (self.max_mau_value,)) return self.runInteraction("reap_monthly_active_users", _reap_users) def get_monthly_active_count(self): """ Generates current count of monthly active users.abs - return: - defered resolves to int + Return: + Defered(int): Number of current monthly active users """ def _count_users(txn): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" @@ -46,6 +50,8 @@ class MonthlyActiveUsersStore(SQLBaseStore): Updates or inserts monthly active user member Arguments: user_id (str): user to add/update + Deferred(bool): True if a new entry was created, False if an + existing one was updated. """ return self._simple_upsert( desc="upsert_monthly_active_user", diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index bd6fda6cb1..e1552510cc 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock from twisted.internet import defer @@ -27,9 +28,9 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() - self.store = hs.get_datastore() - self.clock = hs.get_clock() + self.hs = yield tests.utils.setup_test_homeserver() + self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() @defer.inlineCallbacks def test_insert_new_client_ip(self): @@ -54,3 +55,62 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): }, r ) + + @defer.inlineCallbacks + def test_disabled_monthly_active_user(self): + self.hs.config.limit_usage_by_mau = False + self.hs.config.max_mau_value = 50 + user_id = "@user:server" + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + active = yield self.store.is_user_monthly_active(user_id) + self.assertFalse(active) + + @defer.inlineCallbacks + def test_adding_monthly_active_user_when_full(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 50 + lots_of_users = 100 + user_id = "@user:server" + + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(lots_of_users) + ) + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + active = yield self.store.is_user_monthly_active(user_id) + self.assertFalse(active) + + @defer.inlineCallbacks + def test_adding_monthly_active_user_when_space(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 50 + user_id = "@user:server" + active = yield self.store.is_user_monthly_active(user_id) + self.assertFalse(active) + + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + active = yield self.store.is_user_monthly_active(user_id) + self.assertTrue(active) + + @defer.inlineCallbacks + def test_updating_monthly_active_user_when_space(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 50 + user_id = "@user:server" + + active = yield self.store.is_user_monthly_active(user_id) + self.assertFalse(active) + + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + yield self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id", + ) + active = yield self.store.is_user_monthly_active(user_id) + self.assertTrue(active) -- cgit 1.5.1 From 9180061b49907da363f2e3bfa0061f913d7ccf7a Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 2 Aug 2018 15:55:29 +0100 Subject: remove unused count_monthly_users --- synapse/storage/__init__.py | 25 ----------------- tests/storage/test__init__.py | 65 ------------------------------------------- 2 files changed, 90 deletions(-) delete mode 100644 tests/storage/test__init__.py (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 2120f46ed5..23b4a8d76d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -268,31 +268,6 @@ class DataStore(RoomMemberStore, RoomStore, return self.runInteraction("count_users", _count_users) - def count_monthly_users(self): - """Counts the number of users who used this homeserver in the last 30 days - - This method should be refactored with count_daily_users - the only - reason not to is waiting on definition of mau - - Returns: - Defered[int] - """ - def _count_monthly_users(txn): - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - - txn.execute(sql, (thirty_days_ago,)) - count, = txn.fetchone() - return count - - return self.runInteraction("count_monthly_users", _count_monthly_users) - def count_r30_users(self): """ Counts the number of 30 day retained users, defined as:- diff --git a/tests/storage/test__init__.py b/tests/storage/test__init__.py deleted file mode 100644 index f19cb1265c..0000000000 --- a/tests/storage/test__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -import tests.utils - - -class InitTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(InitTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver() - - hs.config.max_mau_value = 50 - hs.config.limit_usage_by_mau = True - self.store = hs.get_datastore() - self.clock = hs.get_clock() - - @defer.inlineCallbacks - def test_count_monthly_users(self): - count = yield self.store.count_monthly_users() - self.assertEqual(0, count) - - yield self._insert_user_ips("@user:server1") - yield self._insert_user_ips("@user:server2") - - count = yield self.store.count_monthly_users() - self.assertEqual(2, count) - - @defer.inlineCallbacks - def _insert_user_ips(self, user): - """ - Helper function to populate user_ips without using batch insertion infra - args: - user (str): specify username i.e. @user:server.com - """ - yield self.store._simple_upsert( - table="user_ips", - keyvalues={ - "user_id": user, - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": "device_id", - }, - values={ - "last_seen": self.clock.time_msec(), - } - ) -- cgit 1.5.1 From 74b1d46ad9ae692774f2e9d71cbbe1cea91b4070 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 2 Aug 2018 16:57:35 +0100 Subject: do mau checks based on monthly_active_users table --- synapse/api/auth.py | 13 ++++++++ synapse/handlers/auth.py | 10 +++--- synapse/handlers/register.py | 10 +++--- synapse/storage/client_ips.py | 15 +++++---- tests/api/test_auth.py | 31 +++++++++++++++++- tests/handlers/test_auth.py | 8 ++--- tests/handlers/test_register.py | 71 ++++++++++++++++++++--------------------- 7 files changed, 97 insertions(+), 61 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index d8022bcf8e..943a488339 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -773,3 +773,16 @@ class Auth(object): raise AuthError( 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN ) + + @defer.inlineCallbacks + def check_auth_blocking(self, error): + """Checks if the user should be rejected for some external reason, + such as monthly active user limiting or global disable flag + Args: + error (Error): The error that should be raised if user is to be + blocked + """ + if self.hs.config.limit_usage_by_mau is True: + current_mau = yield self.store.get_monthly_active_count() + if current_mau >= self.hs.config.max_mau_value: + raise error diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 184eef09d0..8f9cff92e8 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -913,12 +913,10 @@ class AuthHandler(BaseHandler): Ensure that if mau blocking is enabled that invalid users cannot log in. """ - if self.hs.config.limit_usage_by_mau is True: - current_mau = yield self.store.count_monthly_users() - if current_mau >= self.hs.config.max_mau_value: - raise AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) + error = AuthError( + 403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) + yield self.auth.check_auth_blocking(error) @attr.s diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 289704b241..706ed8c292 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -540,9 +540,7 @@ class RegistrationHandler(BaseHandler): Do not accept registrations if monthly active user limits exceeded and limiting is enabled """ - if self.hs.config.limit_usage_by_mau is True: - current_mau = yield self.store.count_monthly_users() - if current_mau >= self.hs.config.max_mau_value: - raise RegistrationError( - 403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED - ) + error = RegistrationError( + 403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) + yield self.auth.check_auth_blocking(error) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 506915a1ef..83d64d1563 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -97,21 +97,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): @defer.inlineCallbacks def _populate_monthly_active_users(self, user_id): + """Checks on the state of monthly active user limits and optionally + add the user to the monthly active tables + + Args: + user_id(str): the user_id to query + """ + store = self.hs.get_datastore() - print "entering _populate_monthly_active_users" if self.hs.config.limit_usage_by_mau: - print "self.hs.config.limit_usage_by_mau is TRUE" is_user_monthly_active = yield store.is_user_monthly_active(user_id) - print "is_user_monthly_active is %r" % is_user_monthly_active if is_user_monthly_active: yield store.upsert_monthly_active_user(user_id) else: count = yield store.get_monthly_active_count() - print "count is %d" % count if count < self.hs.config.max_mau_value: - print "count is less than self.hs.config.max_mau_value " - res = yield store.upsert_monthly_active_user(user_id) - print "upsert response is %r" % res + yield store.upsert_monthly_active_user(user_id) def _update_client_ips_batch(self): def update(): diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index a82d737e71..54bdf28663 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -21,7 +21,7 @@ from twisted.internet import defer import synapse.handlers.auth from synapse.api.auth import Auth -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, Codes from synapse.types import UserID from tests import unittest @@ -444,3 +444,32 @@ class AuthTestCase(unittest.TestCase): self.assertEqual("Guest access token used for regular user", cm.exception.msg) self.store.get_user_by_id.assert_called_with(USER_ID) + + @defer.inlineCallbacks + def test_blocking_mau(self): + self.hs.config.limit_usage_by_mau = False + self.hs.config.max_mau_value = 50 + lots_of_users = 100 + small_number_of_users = 1 + + error = AuthError( + 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) + + # Ensure no error thrown + yield self.auth.check_auth_blocking(error) + + self.hs.config.limit_usage_by_mau = True + + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(lots_of_users) + ) + + with self.assertRaises(AuthError): + yield self.auth.check_auth_blocking(error) + + # Ensure does not throw an error + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(small_number_of_users) + ) + yield self.auth.check_auth_blocking(error) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 55eab9e9cf..8a9bf2d5fd 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -132,14 +132,14 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_mau_limits_exceeded(self): self.hs.config.limit_usage_by_mau = True - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(AuthError): yield self.auth_handler.get_access_token_for_user_id('user_a') - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(AuthError): @@ -151,13 +151,13 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_not_exceeded(self): self.hs.config.limit_usage_by_mau = True - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception yield self.auth_handler.get_access_token_for_user_id('user_a') - self.hs.get_datastore().count_monthly_users = Mock( + self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) yield self.auth_handler.validate_short_term_login_token_and_get_user_id( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 0937d71cf6..6b5b8b3772 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -50,6 +50,10 @@ class RegistrationTestCase(unittest.TestCase): self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler + self.store = self.hs.get_datastore() + self.hs.config.max_mau_value = 50 + self.lots_of_users = 100 + self.small_number_of_users = 1 @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): @@ -80,51 +84,44 @@ class RegistrationTestCase(unittest.TestCase): self.assertEquals(result_token, 'secret') @defer.inlineCallbacks - def test_cannot_register_when_mau_limits_exceeded(self): - local_part = "someone" - display_name = "someone" - requester = create_requester("@as:test") - store = self.hs.get_datastore() + def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 - lots_of_users = 100 - small_number_users = 1 - - store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) - # Ensure does not throw exception - yield self.handler.get_or_create_user(requester, 'a', display_name) + yield self.handler.get_or_create_user("requester", 'a', "display_name") + @defer.inlineCallbacks + def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True - - with self.assertRaises(RegistrationError): - yield self.handler.get_or_create_user(requester, 'b', display_name) - - store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users)) - - self._macaroon_mock_generator("another_secret") - + self.store.count_monthly_users = Mock( + return_value=defer.succeed(self.small_number_of_users) + ) # Ensure does not throw exception - yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil") + yield self.handler.get_or_create_user("@user:server", 'c', "User") - self._macaroon_mock_generator("another another secret") - store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + @defer.inlineCallbacks + def test_get_or_create_user_mau_blocked(self): + self.hs.config.limit_usage_by_mau = True + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.lots_of_users) + ) with self.assertRaises(RegistrationError): - yield self.handler.register(localpart=local_part) + yield self.handler.get_or_create_user("requester", 'b', "display_name") - self._macaroon_mock_generator("another another secret") - store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users)) + @defer.inlineCallbacks + def test_register_mau_blocked(self): + self.hs.config.limit_usage_by_mau = True + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.lots_of_users) + ) + with self.assertRaises(RegistrationError): + yield self.handler.register(localpart="local_part") + @defer.inlineCallbacks + def test_register_saml2_mau_blocked(self): + self.hs.config.limit_usage_by_mau = True + self.store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.lots_of_users) + ) with self.assertRaises(RegistrationError): - yield self.handler.register_saml2(local_part) - - def _macaroon_mock_generator(self, secret): - """ - Reset macaroon generator in the case where the test creates multiple users - """ - macaroon_generator = Mock( - generate_access_token=Mock(return_value=secret)) - self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator) - self.hs.handlers = RegistrationHandlers(self.hs) - self.handler = self.hs.get_handlers().registration_handler + yield self.handler.register_saml2(localpart="local_part") -- cgit 1.5.1 From 4ecb4bdac963ae967da810de6134d85b50c999f9 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 2 Aug 2018 22:41:05 +0100 Subject: wip attempt at caching --- synapse/storage/monthly_active_users.py | 56 +++++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 9 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index aada9bd2b9..19846cefd9 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -1,4 +1,6 @@ from twisted.internet import defer +from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from ._base import SQLBaseStore @@ -20,17 +22,53 @@ class MonthlyActiveUsersStore(SQLBaseStore): thirty_days_ago = ( int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) ) - sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" - txn.execute(sql, (thirty_days_ago,)) - sql = """ - DELETE FROM monthly_active_users - ORDER BY timestamp desc - LIMIT -1 OFFSET ? - """ - txn.execute(sql, (self.max_mau_value,)) + + if isinstance(self.database_engine, PostgresEngine): + sql = """ + DELETE FROM monthly_active_users + WHERE timestamp < ? + RETURNING user_id + """ + txn.execute(sql, (thirty_days_ago,)) + res = txn.fetchall() + for r in res: + self.is_user_monthly_active.invalidate(r) + + sql = """ + DELETE FROM monthly_active_users + ORDER BY timestamp desc + LIMIT -1 OFFSET ? + RETURNING user_id + """ + txn.execute(sql, (self.max_mau_value,)) + res = txn.fetchall() + for r in res: + self.is_user_monthly_active.invalidate(r) + print r + self.get_monthly_active_count.invalidate() + elif isinstance(self.database_engine, Sqlite3Engine): + sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" + + txn.execute(sql, (thirty_days_ago,)) + sql = """ + DELETE FROM monthly_active_users + ORDER BY timestamp desc + LIMIT -1 OFFSET ? + """ + txn.execute(sql, (self.max_mau_value,)) + + # It seems poor to invalidate the whole cache, but the alternative + # is to select then delete which has it's own problem. + # It seems unlikely that anyone using this feature on large datasets + # would be using sqlite and if they are then there will be + # larger perf issues than this one to encourage an upgrade to postgres. + + self.is_user_monthly_active.invalidate_all() + self.get_monthly_active_count.invalidate_all() return self.runInteraction("reap_monthly_active_users", _reap_users) + @cachedInlineCallbacks(num_args=0) def get_monthly_active_count(self): """ Generates current count of monthly active users.abs @@ -65,7 +103,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): lock=False, ) - @defer.inlineCallbacks + @cachedInlineCallbacks(num_args=1) def is_user_monthly_active(self, user_id): """ Checks if a given user is part of the monthly active user group -- cgit 1.5.1 From 5e2f7b80844bc8ee401beab494948c093338f404 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 2 Aug 2018 22:47:48 +0100 Subject: typo --- synapse/storage/monthly_active_users.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 19846cefd9..2872ba4cae 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -58,7 +58,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): txn.execute(sql, (self.max_mau_value,)) # It seems poor to invalidate the whole cache, but the alternative - # is to select then delete which has it's own problem. + # is to select then delete which has its own problems. # It seems unlikely that anyone using this feature on large datasets # would be using sqlite and if they are then there will be # larger perf issues than this one to encourage an upgrade to postgres. -- cgit 1.5.1 From 950807d93a264b6d10ece386d227dc4069f7d0da Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Fri, 3 Aug 2018 13:49:53 +0100 Subject: fix caching and tests --- synapse/app/homeserver.py | 1 - synapse/storage/monthly_active_users.py | 91 ++++++++++++++---------------- tests/storage/test_monthly_active_users.py | 50 +++++++++++----- 3 files changed, 80 insertions(+), 62 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7d4ea493bc..3a67db8b30 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -64,7 +64,6 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer from synapse.storage import are_all_users_on_domain from synapse.storage.engines import IncorrectDatabaseSetup, create_engine -from synapse.storage.monthly_active_users import MonthlyActiveUsersStore from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 2872ba4cae..d06c90cbcc 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -1,6 +1,21 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from twisted.internet import defer -from synapse.util.caches.descriptors import cachedInlineCallbacks -from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from ._base import SQLBaseStore @@ -9,7 +24,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): def __init__(self, dbconn, hs): super(MonthlyActiveUsersStore, self).__init__(None, hs) self._clock = hs.get_clock() - self.max_mau_value = hs.config.max_mau_value + self.hs = hs def reap_monthly_active_users(self): """ @@ -19,62 +34,41 @@ class MonthlyActiveUsersStore(SQLBaseStore): Defered() """ def _reap_users(txn): + thirty_days_ago = ( int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) ) - if isinstance(self.database_engine, PostgresEngine): - sql = """ - DELETE FROM monthly_active_users - WHERE timestamp < ? - RETURNING user_id - """ - txn.execute(sql, (thirty_days_ago,)) - res = txn.fetchall() - for r in res: - self.is_user_monthly_active.invalidate(r) - - sql = """ - DELETE FROM monthly_active_users - ORDER BY timestamp desc - LIMIT -1 OFFSET ? - RETURNING user_id - """ - txn.execute(sql, (self.max_mau_value,)) - res = txn.fetchall() - for r in res: - self.is_user_monthly_active.invalidate(r) - print r - self.get_monthly_active_count.invalidate() - elif isinstance(self.database_engine, Sqlite3Engine): - sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" + sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" - txn.execute(sql, (thirty_days_ago,)) - sql = """ - DELETE FROM monthly_active_users - ORDER BY timestamp desc - LIMIT -1 OFFSET ? - """ - txn.execute(sql, (self.max_mau_value,)) + txn.execute(sql, (thirty_days_ago,)) + sql = """ + DELETE FROM monthly_active_users + ORDER BY timestamp desc + LIMIT -1 OFFSET ? + """ + txn.execute(sql, (self.hs.config.max_mau_value,)) - # It seems poor to invalidate the whole cache, but the alternative - # is to select then delete which has its own problems. - # It seems unlikely that anyone using this feature on large datasets - # would be using sqlite and if they are then there will be - # larger perf issues than this one to encourage an upgrade to postgres. + res = self.runInteraction("reap_monthly_active_users", _reap_users) + # It seems poor to invalidate the whole cache, Postgres supports + # 'Returning' which would allow me to invalidate only the + # specific users, but sqlite has no way to do this and instead + # I would need to SELECT and the DELETE which without locking + # is racy. + # Have resolved to invalidate the whole cache for now and do + # something about it if and when the perf becomes significant + self.is_user_monthly_active.invalidate_all() + self.get_monthly_active_count.invalidate_all() + return res - self.is_user_monthly_active.invalidate_all() - self.get_monthly_active_count.invalidate_all() - - return self.runInteraction("reap_monthly_active_users", _reap_users) - - @cachedInlineCallbacks(num_args=0) + @cached(num_args=0) def get_monthly_active_count(self): """ Generates current count of monthly active users.abs Return: Defered(int): Number of current monthly active users """ + def _count_users(txn): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" @@ -91,7 +85,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): Deferred(bool): True if a new entry was created, False if an existing one was updated. """ - return self._simple_upsert( + self._simple_upsert( desc="upsert_monthly_active_user", table="monthly_active_users", keyvalues={ @@ -102,6 +96,8 @@ class MonthlyActiveUsersStore(SQLBaseStore): }, lock=False, ) + self.is_user_monthly_active.invalidate((user_id,)) + self.get_monthly_active_count.invalidate(()) @cachedInlineCallbacks(num_args=1) def is_user_monthly_active(self, user_id): @@ -120,5 +116,4 @@ class MonthlyActiveUsersStore(SQLBaseStore): retcol="user_id", desc="is_user_monthly_active", ) - defer.returnValue(bool(user_present)) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d8d25a6069..c32109ecc5 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -1,6 +1,19 @@ -from twisted.internet import defer +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from synapse.storage.monthly_active_users import MonthlyActiveUsersStore +from twisted.internet import defer import tests.unittest import tests.utils @@ -10,20 +23,19 @@ from tests.utils import setup_test_homeserver class MonthlyActiveUsersTestCase(tests.unittest.TestCase): def __init__(self, *args, **kwargs): super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs) - self.mau = None @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() - self.mau = MonthlyActiveUsersStore(None, hs) + self.hs = yield setup_test_homeserver() + self.store = self.hs.get_datastore() @defer.inlineCallbacks def test_can_insert_and_count_mau(self): - count = yield self.mau.get_monthly_active_count() + count = yield self.store.get_monthly_active_count() self.assertEqual(0, count) - yield self.mau.upsert_monthly_active_user("@user:server") - count = yield self.mau.get_monthly_active_count() + yield self.store.upsert_monthly_active_user("@user:server") + count = yield self.store.get_monthly_active_count() self.assertEqual(1, count) @@ -32,11 +44,23 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = yield self.mau.is_user_monthly_active(user_id1) + result = yield self.store.is_user_monthly_active(user_id1) self.assertFalse(result) - yield self.mau.upsert_monthly_active_user(user_id1) - yield self.mau.upsert_monthly_active_user(user_id2) - result = yield self.mau.is_user_monthly_active(user_id1) + yield self.store.upsert_monthly_active_user(user_id1) + yield self.store.upsert_monthly_active_user(user_id2) + result = yield self.store.is_user_monthly_active(user_id1) self.assertTrue(result) - result = yield self.mau.is_user_monthly_active(user_id3) + result = yield self.store.is_user_monthly_active(user_id3) self.assertFalse(result) + + @defer.inlineCallbacks + def test_reap_monthly_active_users(self): + self.hs.config.max_mau_value = 5 + initial_users = 10 + for i in range(initial_users): + yield self.store.upsert_monthly_active_user("@user%d:server" % i) + count = yield self.store.get_monthly_active_count() + self.assertTrue(count, initial_users) + yield self.store.reap_monthly_active_users() + count = yield self.store.get_monthly_active_count() + self.assertTrue(count, initial_users - self.hs.config.max_mau_value) -- cgit 1.5.1 From 0ca459ea334ff86016bda241c0c823178789c215 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 25 Jul 2018 22:10:39 +0100 Subject: Basic support for room versioning This is the first tranche of support for room versioning. It includes: * setting the default room version in the config file * new room_version param on the createRoom API * storing the version of newly-created rooms in the m.room.create event * fishing the version of existing rooms out of the m.room.create event --- synapse/api/constants.py | 6 ++++++ synapse/api/errors.py | 2 ++ synapse/config/server.py | 14 ++++++++++++ synapse/handlers/room.py | 27 ++++++++++++++++++++++- synapse/replication/slave/storage/events.py | 2 +- synapse/storage/state.py | 33 ++++++++++++++++++++++++++--- tests/utils.py | 4 ++++ 7 files changed, 83 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 4df930c8d1..a27bf3b32d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,3 +95,8 @@ class RoomCreationPreset(object): class ThirdPartyEntityKind(object): USER = "user" LOCATION = "location" + + +# vdh-test-version is a placeholder to get room versioning support working and tested +# until we have a working v2. +KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"} diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b41d595059..477ca07a24 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -56,6 +57,7 @@ class Codes(object): CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED" + UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION" class CodeMessageException(RuntimeError): diff --git a/synapse/config/server.py b/synapse/config/server.py index 6a471a0a5e..68ef2789d3 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -16,6 +16,7 @@ import logging +from synapse.api.constants import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name from ._base import Config, ConfigError @@ -75,6 +76,16 @@ class ServerConfig(Config): ) else: self.max_mau_value = 0 + + # the version of rooms created by default on this server + self.default_room_version = str(config.get( + "default_room_version", "1", + )) + if self.default_room_version not in KNOWN_ROOM_VERSIONS: + raise ConfigError("Unrecognised value '%s' for default_room_version" % ( + self.default_room_version, + )) + # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None federation_domain_whitelist = config.get( @@ -249,6 +260,9 @@ class ServerConfig(Config): # (except those sent by local server admins). The default is False. # block_non_admin_invites: True + # The room_version of rooms which are created by default by this server. + # default_room_version: 1 + # Restrict federation to the following whitelist of domains. # N.B. we recommend also firewalling your federation listener to limit # inbound federation traffic as early as possible, rather than relying diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7b7804d9b2..a526b684e9 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -21,9 +21,16 @@ import math import string from collections import OrderedDict +from six import string_types + from twisted.internet import defer -from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset +from synapse.api.constants import ( + KNOWN_ROOM_VERSIONS, + EventTypes, + JoinRules, + RoomCreationPreset, +) from synapse.api.errors import AuthError, Codes, StoreError, SynapseError from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils @@ -99,6 +106,21 @@ class RoomCreationHandler(BaseHandler): if ratelimit: yield self.ratelimit(requester) + room_version = config.get("room_version", self.hs.config.default_room_version) + if not isinstance(room_version, string_types): + raise SynapseError( + 400, + "room_version must be a string", + Codes.BAD_JSON, + ) + + if room_version not in KNOWN_ROOM_VERSIONS: + raise SynapseError( + 400, + "Your homeserver does not support this room version", + Codes.UNSUPPORTED_ROOM_VERSION, + ) + if "room_alias_name" in config: for wchar in string.whitespace: if wchar in config["room_alias_name"]: @@ -184,6 +206,9 @@ class RoomCreationHandler(BaseHandler): creation_content = config.get("creation_content", {}) + # override any attempt to set room versions via the creation_content + creation_content["room_version"] = room_version + room_member_handler = self.hs.get_room_member_handler() yield self._send_events_for_new_room( diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index bdb5eee4af..4830c68f35 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -44,8 +44,8 @@ class SlavedEventStore(EventFederationWorkerStore, RoomMemberWorkerStore, EventPushActionsWorkerStore, StreamWorkerStore, - EventsWorkerStore, StateGroupWorkerStore, + EventsWorkerStore, SignatureWorkerStore, UserErasureWorkerStore, BaseSlavedStore): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b27b3ae144..17b14d464b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -21,15 +21,17 @@ from six.moves import range from twisted.internet import defer +from synapse.api.constants import EventTypes +from synapse.api.errors import NotFoundError +from synapse.storage._base import SQLBaseStore from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.engines import PostgresEngine +from synapse.storage.events_worker import EventsWorkerStore from synapse.util.caches import get_cache_factor_for, intern_string from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.stringutils import to_ascii -from ._base import SQLBaseStore - logger = logging.getLogger(__name__) @@ -46,7 +48,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 -class StateGroupWorkerStore(SQLBaseStore): +# this inherits from EventsWorkerStore because it calls self.get_events +class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. """ @@ -61,6 +64,30 @@ class StateGroupWorkerStore(SQLBaseStore): "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") ) + @defer.inlineCallbacks + def get_room_version(self, room_id): + """Get the room_version of a given room + + Args: + room_id (str) + + Returns: + Deferred[str] + + Raises: + NotFoundError if the room is unknown + """ + # for now we do this by looking at the create event. We may want to cache this + # more intelligently in future. + state_ids = yield self.get_current_state_ids(room_id) + create_id = state_ids.get((EventTypes.Create, "")) + + if not create_id: + raise NotFoundError("Unknown room") + + create_event = yield self.get_event(create_id) + defer.returnValue(create_event.content.get("room_version", "1")) + @cached(max_entries=100000, iterable=True) def get_current_state_ids(self, room_id): """Get the current state event ids for a room based on the diff --git a/tests/utils.py b/tests/utils.py index 9bff3ff3b9..9e188a8ed4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -74,6 +74,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None config.media_storage_providers = [] config.auto_join_rooms = [] + # we need a sane default_room_version, otherwise attempts to create rooms will + # fail. + config.default_room_version = "1" + # disable user directory updates, because they get done in the # background, which upsets the test runner. config.update_user_directory = False -- cgit 1.5.1 From e10830e9765eb3897da5af6ffb4809badb8e3009 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Fri, 3 Aug 2018 17:55:50 +0100 Subject: wip commit - tests failing --- synapse/api/auth.py | 6 ++- synapse/storage/client_ips.py | 21 +--------- synapse/storage/monthly_active_users.py | 66 ++++++++++++++++++++++-------- tests/storage/test_client_ips.py | 12 +++--- tests/storage/test_monthly_active_users.py | 14 +++---- 5 files changed, 66 insertions(+), 53 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 943a488339..d8ebbbc6e8 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -775,7 +775,7 @@ class Auth(object): ) @defer.inlineCallbacks - def check_auth_blocking(self, error): + def check_auth_blocking(self): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag Args: @@ -785,4 +785,6 @@ class Auth(object): if self.hs.config.limit_usage_by_mau is True: current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: - raise error + raise AuthError( + 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED + ) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 83d64d1563..2489527f2c 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -86,7 +86,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - yield self._populate_monthly_active_users(user_id) + yield self.populate_monthly_active_users(user_id) # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return @@ -95,25 +95,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) - @defer.inlineCallbacks - def _populate_monthly_active_users(self, user_id): - """Checks on the state of monthly active user limits and optionally - add the user to the monthly active tables - - Args: - user_id(str): the user_id to query - """ - - store = self.hs.get_datastore() - if self.hs.config.limit_usage_by_mau: - is_user_monthly_active = yield store.is_user_monthly_active(user_id) - if is_user_monthly_active: - yield store.upsert_monthly_active_user(user_id) - else: - count = yield store.get_monthly_active_count() - if count < self.hs.config.max_mau_value: - yield store.upsert_monthly_active_user(user_id) - def _update_client_ips_batch(self): def update(): to_update = self._batch_row_update diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index d06c90cbcc..6def6830d0 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -19,6 +19,10 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from ._base import SQLBaseStore +# Number of msec of granularity to store the monthly_active_user timestamp +# This means it is not necessary to update the table on every request +LAST_SEEN_GRANULARITY = 60 * 60 * 1000 + class MonthlyActiveUsersStore(SQLBaseStore): def __init__(self, dbconn, hs): @@ -30,8 +34,9 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ Cleans out monthly active user table to ensure that no stale entries exist. - Return: - Defered() + + Returns: + Deferred[] """ def _reap_users(txn): @@ -49,7 +54,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ txn.execute(sql, (self.hs.config.max_mau_value,)) - res = self.runInteraction("reap_monthly_active_users", _reap_users) + res = yield self.runInteraction("reap_monthly_active_users", _reap_users) # It seems poor to invalidate the whole cache, Postgres supports # 'Returning' which would allow me to invalidate only the # specific users, but sqlite has no way to do this and instead @@ -57,16 +62,16 @@ class MonthlyActiveUsersStore(SQLBaseStore): # is racy. # Have resolved to invalidate the whole cache for now and do # something about it if and when the perf becomes significant - self.is_user_monthly_active.invalidate_all() + self._user_last_seen_monthly_active.invalidate_all() self.get_monthly_active_count.invalidate_all() return res @cached(num_args=0) def get_monthly_active_count(self): - """ - Generates current count of monthly active users.abs - Return: - Defered(int): Number of current monthly active users + """Generates current count of monthly active users.abs + + Returns: + Defered[int]: Number of current monthly active users """ def _count_users(txn): @@ -82,10 +87,10 @@ class MonthlyActiveUsersStore(SQLBaseStore): Updates or inserts monthly active user member Arguments: user_id (str): user to add/update - Deferred(bool): True if a new entry was created, False if an + Deferred[bool]: True if a new entry was created, False if an existing one was updated. """ - self._simple_upsert( + is_insert = self._simple_upsert( desc="upsert_monthly_active_user", table="monthly_active_users", keyvalues={ @@ -96,24 +101,49 @@ class MonthlyActiveUsersStore(SQLBaseStore): }, lock=False, ) - self.is_user_monthly_active.invalidate((user_id,)) - self.get_monthly_active_count.invalidate(()) + if is_insert: + self._user_last_seen_monthly_active.invalidate((user_id,)) + self.get_monthly_active_count.invalidate(()) @cachedInlineCallbacks(num_args=1) - def is_user_monthly_active(self, user_id): + def _user_last_seen_monthly_active(self, user_id): """ Checks if a given user is part of the monthly active user group Arguments: user_id (str): user to add/update Return: - bool : True if user part of group, False otherwise + int : timestamp since last seen, None if never seen + """ - user_present = yield self._simple_select_onecol( + result = yield self._simple_select_onecol( table="monthly_active_users", keyvalues={ "user_id": user_id, }, - retcol="user_id", - desc="is_user_monthly_active", + retcol="timestamp", + desc="_user_last_seen_monthly_active", ) - defer.returnValue(bool(user_present)) + timestamp = None + if len(result) > 0: + timestamp = result[0] + defer.returnValue(timestamp) + + @defer.inlineCallbacks + def populate_monthly_active_users(self, user_id): + """Checks on the state of monthly active user limits and optionally + add the user to the monthly active tables + + Args: + user_id(str): the user_id to query + """ + + if self.hs.config.limit_usage_by_mau: + last_seen_timestamp = yield self._user_last_seen_monthly_active(user_id) + now = self.hs.get_clock().time_msec() + + if last_seen_timestamp is None: + count = yield self.get_monthly_active_count() + if count < self.hs.config.max_mau_value: + yield self.upsert_monthly_active_user(user_id) + elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: + yield self.upsert_monthly_active_user(user_id) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index e1552510cc..7a58c6eb24 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -64,7 +64,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store.is_user_monthly_active(user_id) + active = yield self.store._user_last_seen_monthly_active(user_id) self.assertFalse(active) @defer.inlineCallbacks @@ -80,7 +80,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store.is_user_monthly_active(user_id) + active = yield self.store._user_last_seen_monthly_active(user_id) self.assertFalse(active) @defer.inlineCallbacks @@ -88,13 +88,13 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store.is_user_monthly_active(user_id) + active = yield self.store._user_last_seen_monthly_active(user_id) self.assertFalse(active) yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store.is_user_monthly_active(user_id) + active = yield self.store._user_last_seen_monthly_active(user_id) self.assertTrue(active) @defer.inlineCallbacks @@ -103,7 +103,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store.is_user_monthly_active(user_id) + active = yield self.store._user_last_seen_monthly_active(user_id) self.assertFalse(active) yield self.store.insert_client_ip( @@ -112,5 +112,5 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store.is_user_monthly_active(user_id) + active = yield self.store._user_last_seen_monthly_active(user_id) self.assertTrue(active) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index c32109ecc5..0bfd24a7fb 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -40,18 +40,18 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): self.assertEqual(1, count) @defer.inlineCallbacks - def test_is_user_monthly_active(self): + def test__user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = yield self.store.is_user_monthly_active(user_id1) - self.assertFalse(result) + result = yield self.store._user_last_seen_monthly_active(user_id1) + self.assertFalse(result == 0) yield self.store.upsert_monthly_active_user(user_id1) yield self.store.upsert_monthly_active_user(user_id2) - result = yield self.store.is_user_monthly_active(user_id1) - self.assertTrue(result) - result = yield self.store.is_user_monthly_active(user_id3) - self.assertFalse(result) + result = yield self.store._user_last_seen_monthly_active(user_id1) + self.assertTrue(result > 0) + result = yield self.store._user_last_seen_monthly_active(user_id3) + self.assertFalse(result == 0) @defer.inlineCallbacks def test_reap_monthly_active_users(self): -- cgit 1.5.1 From 886be75ad1bc60e016611b453b9644e8db17a9f1 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Fri, 3 Aug 2018 22:29:03 +0100 Subject: bug fixes --- synapse/handlers/auth.py | 15 ++------------- synapse/handlers/register.py | 8 ++++---- synapse/storage/monthly_active_users.py | 3 +-- tests/api/test_auth.py | 10 +++------- tests/handlers/test_register.py | 1 - 5 files changed, 10 insertions(+), 27 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8f9cff92e8..7ea8ce9f94 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -520,7 +520,7 @@ class AuthHandler(BaseHandler): """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() # the device *should* have been registered before we got here; however, # it's possible we raced against a DELETE operation. The thing we @@ -734,7 +734,7 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def validate_short_term_login_token_and_get_user_id(self, login_token): - yield self._check_mau_limits() + yield self.auth.check_auth_blocking() auth_api = self.hs.get_auth() user_id = None try: @@ -907,17 +907,6 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) - @defer.inlineCallbacks - def _check_mau_limits(self): - """ - Ensure that if mau blocking is enabled that invalid users cannot - log in. - """ - error = AuthError( - 403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) - yield self.auth.check_auth_blocking(error) - @attr.s class MacaroonGenerator(object): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 706ed8c292..8cf0a36a8f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -540,7 +540,7 @@ class RegistrationHandler(BaseHandler): Do not accept registrations if monthly active user limits exceeded and limiting is enabled """ - error = RegistrationError( - 403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) - yield self.auth.check_auth_blocking(error) + try: + yield self.auth.check_auth_blocking() + except AuthError as e: + raise RegistrationError(e.code, e.message, e.errcode) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 6def6830d0..135837507a 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -54,7 +54,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ txn.execute(sql, (self.hs.config.max_mau_value,)) - res = yield self.runInteraction("reap_monthly_active_users", _reap_users) + yield self.runInteraction("reap_monthly_active_users", _reap_users) # It seems poor to invalidate the whole cache, Postgres supports # 'Returning' which would allow me to invalidate only the # specific users, but sqlite has no way to do this and instead @@ -64,7 +64,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): # something about it if and when the perf becomes significant self._user_last_seen_monthly_active.invalidate_all() self.get_monthly_active_count.invalidate_all() - return res @cached(num_args=0) def get_monthly_active_count(self): diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 54bdf28663..e963963c73 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -452,12 +452,8 @@ class AuthTestCase(unittest.TestCase): lots_of_users = 100 small_number_of_users = 1 - error = AuthError( - 403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED - ) - # Ensure no error thrown - yield self.auth.check_auth_blocking(error) + yield self.auth.check_auth_blocking() self.hs.config.limit_usage_by_mau = True @@ -466,10 +462,10 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(AuthError): - yield self.auth.check_auth_blocking(error) + yield self.auth.check_auth_blocking() # Ensure does not throw an error self.store.get_monthly_active_count = Mock( return_value=defer.succeed(small_number_of_users) ) - yield self.auth.check_auth_blocking(error) + yield self.auth.check_auth_blocking() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 6b5b8b3772..4ea59a58de 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -104,7 +104,6 @@ class RegistrationTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(RegistrationError): yield self.handler.get_or_create_user("requester", 'b', "display_name") -- cgit 1.5.1 From 62ace05c4521caeed023e5b9dadd993afe5c154f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jul 2018 13:31:59 +0100 Subject: Move necessary storage functions to worker classes --- synapse/storage/events.py | 82 --------------------------------------- synapse/storage/events_worker.py | 84 ++++++++++++++++++++++++++++++++++++++++ synapse/storage/room.py | 32 +++++++-------- 3 files changed, 100 insertions(+), 98 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e8e5a0fe44..5374796fd7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1430,88 +1430,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore (event.event_id, event.redacts) ) - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): - """Given a list of event ids, check if we have already processed and - stored them as non outliers. - """ - rows = yield self._simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) - - defer.returnValue(set(r["event_id"] for r in rows)) - - @defer.inlineCallbacks - def have_seen_events(self, event_ids): - """Given a list of event ids, check if we have already processed them. - - Args: - event_ids (iterable[str]): - - Returns: - Deferred[set[str]]: The events we have already seen. - """ - results = set() - - def have_seen_events_txn(txn, chunk): - sql = ( - "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" - % (",".join("?" * len(chunk)), ) - ) - txn.execute(sql, chunk) - for (event_id, ) in txn: - results.add(event_id) - - # break the input up into chunks of 100 - input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), - []): - yield self.runInteraction( - "have_seen_events", - have_seen_events_txn, - chunk, - ) - defer.returnValue(results) - - def get_seen_events_with_rejections(self, event_ids): - """Given a list of event ids, check if we rejected them. - - Args: - event_ids (list[str]) - - Returns: - Deferred[dict[str, str|None): - Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps - to None. - """ - if not event_ids: - return defer.succeed({}) - - def f(txn): - sql = ( - "SELECT e.event_id, reason FROM events as e " - "LEFT JOIN rejections as r ON e.event_id = r.event_id " - "WHERE e.event_id = ?" - ) - - res = {} - for event_id in event_ids: - txn.execute(sql, (event_id,)) - row = txn.fetchone() - if row: - _, rejected = row - res[event_id] = rejected - - return res - - return self.runInteraction("get_rejection_reasons", f) - @defer.inlineCallbacks def count_daily_messages(self): """ diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 9b4cfeb899..e1dc2c62fd 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -12,7 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging + from collections import namedtuple from canonicaljson import json @@ -442,3 +444,85 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_cache.prefill((original_ev.event_id,), cache_entry) defer.returnValue(cache_entry) + + @defer.inlineCallbacks + def have_events_in_timeline(self, event_ids): + """Given a list of event ids, check if we have already processed and + stored them as non outliers. + """ + rows = yield self._simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) + + defer.returnValue(set(r["event_id"] for r in rows)) + + @defer.inlineCallbacks + def have_seen_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Args: + event_ids (iterable[str]): + + Returns: + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = ( + "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" + % (",".join("?" * len(chunk)), ) + ) + txn.execute(sql, chunk) + for (event_id, ) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), + []): + yield self.runInteraction( + "have_seen_events", + have_seen_events_txn, + chunk, + ) + defer.returnValue(results) + + def get_seen_events_with_rejections(self, event_ids): + """Given a list of event ids, check if we rejected them. + + Args: + event_ids (list[str]) + + Returns: + Deferred[dict[str, str|None): + Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps + to None. + """ + if not event_ids: + return defer.succeed({}) + + def f(txn): + sql = ( + "SELECT e.event_id, reason FROM events as e " + "LEFT JOIN rejections as r ON e.event_id = r.event_id " + "WHERE e.event_id = ?" + ) + + res = {} + for event_id in event_ids: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if row: + _, rejected = row + res[event_id] = rejected + + return res + + return self.runInteraction("get_rejection_reasons", f) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 3147fb6827..3378fc77d1 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -41,6 +41,22 @@ RatelimitOverride = collections.namedtuple( class RoomWorkerStore(SQLBaseStore): + def get_room(self, room_id): + """Retrieve a room. + + Args: + room_id (str): The ID of the room to retrieve. + Returns: + A namedtuple containing the room information, or an empty list. + """ + return self._simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("room_id", "is_public", "creator"), + desc="get_room", + allow_none=True, + ) + def get_public_room_ids(self): return self._simple_select_onecol( table="rooms", @@ -215,22 +231,6 @@ class RoomStore(RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - def get_room(self, room_id): - """Retrieve a room. - - Args: - room_id (str): The ID of the room to retrieve. - Returns: - A namedtuple containing the room information, or an empty list. - """ - return self._simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), - desc="get_room", - allow_none=True, - ) - @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): -- cgit 1.5.1 From 96a9a296455e2c08be4e0375cf784e4113fa51e1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jul 2018 13:35:36 +0100 Subject: Pull in necessary stores in federation_reader --- synapse/app/federation_reader.py | 2 ++ synapse/storage/events_worker.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index c512b4be87..0c557a8242 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -36,6 +36,7 @@ from synapse.replication.slave.storage.appservice import SlavedApplicationServic from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.profile import SlavedProfileStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore @@ -53,6 +54,7 @@ logger = logging.getLogger("synapse.app.federation_reader") class FederationReaderSlavedStore( + SlavedProfileStore, SlavedApplicationServiceStore, SlavedPusherStore, SlavedPushRuleStore, diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index e1dc2c62fd..59822178ff 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -14,7 +14,6 @@ # limitations under the License. import itertools import logging - from collections import namedtuple from canonicaljson import json -- cgit 1.5.1 From 16d78be315fd71d8bdb39af53317b3f7dd1dbb32 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 6 Aug 2018 17:51:15 +0100 Subject: make use of _simple_select_one_onecol, improved comments --- synapse/storage/monthly_active_users.py | 19 +++++++++++-------- .../storage/schema/delta/51/monthly_active_users.sql | 4 ++++ 2 files changed, 15 insertions(+), 8 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 135837507a..c28c698c6c 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from ._base import SQLBaseStore @@ -104,7 +104,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._user_last_seen_monthly_active.invalidate((user_id,)) self.get_monthly_active_count.invalidate(()) - @cachedInlineCallbacks(num_args=1) + @cached(num_args=1) def _user_last_seen_monthly_active(self, user_id): """ Checks if a given user is part of the monthly active user group @@ -114,18 +114,16 @@ class MonthlyActiveUsersStore(SQLBaseStore): int : timestamp since last seen, None if never seen """ - result = yield self._simple_select_onecol( + + return(self._simple_select_one_onecol( table="monthly_active_users", keyvalues={ "user_id": user_id, }, retcol="timestamp", + allow_none=True, desc="_user_last_seen_monthly_active", - ) - timestamp = None - if len(result) > 0: - timestamp = result[0] - defer.returnValue(timestamp) + )) @defer.inlineCallbacks def populate_monthly_active_users(self, user_id): @@ -140,6 +138,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): last_seen_timestamp = yield self._user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() + # We want to reduce to the total number of db writes, and are happy + # to trade accuracy of timestamp in order to lighten load. This means + # We always insert new users (where MAU threshold has not been reached), + # but only update if we have not previously seen the user for + # LAST_SEEN_GRANULARITY ms if last_seen_timestamp is None: count = yield self.get_monthly_active_count() if count < self.hs.config.max_mau_value: diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql index f2b6d3e31e..10aac90ce1 100644 --- a/synapse/storage/schema/delta/51/monthly_active_users.sql +++ b/synapse/storage/schema/delta/51/monthly_active_users.sql @@ -16,6 +16,10 @@ -- a table of monthly active users, for use where blocking based on mau limits CREATE TABLE monthly_active_users ( user_id TEXT NOT NULL, + -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting + -- on updates, Granularity of updates governed by + -- syanpse.storage.monthly_active_users.LAST_SEEN_GRANULARITY + -- Measured in ms since epoch. timestamp BIGINT NOT NULL ); -- cgit 1.5.1 From e54794f5b658992627eab5400b13199128eec0a1 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 6 Aug 2018 21:55:54 +0100 Subject: Fix postgres compatibility bug --- synapse/storage/monthly_active_users.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index c28c698c6c..abe1e6bb99 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -43,14 +43,25 @@ class MonthlyActiveUsersStore(SQLBaseStore): thirty_days_ago = ( int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) ) - + # Purge stale users sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" - txn.execute(sql, (thirty_days_ago,)) + + # If MAU user count still exceeds the MAU threshold, then delete on + # a least recently active basis. + # Note it is not possible to write this query using OFFSET due to + # incompatibilities in how sqlite an postgres support the feature. + # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present + # While Postgres does not require 'LIMIT', but also does not support + # negative LIMIT values. So there is no way to write it that both can + # support sql = """ DELETE FROM monthly_active_users - ORDER BY timestamp desc - LIMIT -1 OFFSET ? + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + ORDER BY timestamp DESC + LIMIT ? + ) """ txn.execute(sql, (self.hs.config.max_mau_value,)) -- cgit 1.5.1 From a74b25faaaf47b4696cb30207ef2eae6a3a5d8c7 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Mon, 6 Aug 2018 23:25:25 +0100 Subject: WIP building out mau reserved users --- synapse/storage/monthly_active_users.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index abe1e6bb99..6a37d6fc22 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -19,16 +19,30 @@ from synapse.util.caches.descriptors import cached from ._base import SQLBaseStore + # Number of msec of granularity to store the monthly_active_user timestamp # This means it is not necessary to update the table on every request LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersStore(SQLBaseStore): + @defer.inlineCallbacks def __init__(self, dbconn, hs): super(MonthlyActiveUsersStore, self).__init__(None, hs) self._clock = hs.get_clock() self.hs = hs + threepids = self.hs.config.mau_limits_reserved_threepids + self.reserved_user_ids = set() + for tp in threepids: + user_id = yield hs.get_datastore().get_user_id_by_threepid( + tp["medium"], tp["address"] + ) + if user_id: + self.reserved_user_ids.add(user_id) + else: + logger.warning( + "mau limit reserved threepid %s not found in db" % tp + ) def reap_monthly_active_users(self): """ @@ -78,7 +92,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): @cached(num_args=0) def get_monthly_active_count(self): - """Generates current count of monthly active users.abs + """Generates current count of monthly active users Returns: Defered[int]: Number of current monthly active users -- cgit 1.5.1 From ca9bc1f4feaaabce48595ee2e387529f5bd365e6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 7 Aug 2018 10:00:03 +0100 Subject: Fix occasional glitches in the synapse_event_persisted_position metric Every so often this metric glitched to a negative number. I'm assuming it was due to backfilled events. --- synapse/storage/events.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e8e5a0fe44..ce32e8fefd 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -485,9 +485,14 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc(len(chunk)) - synapse.metrics.event_persisted_position.set( - chunk[-1][0].internal_metadata.stream_ordering, - ) + + if not backfilled: + # backfilled events have negative stream orderings, so we don't + # want to set the event_persisted_position to that. + synapse.metrics.event_persisted_position.set( + chunk[-1][0].internal_metadata.stream_ordering, + ) + for event, context in chunk: if context.app_service: origin_type = "local" -- cgit 1.5.1 From 495cb100d127212d55a46c177706d732950e70be Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Aug 2018 14:46:17 +0100 Subject: Allow profile changes to happen on workers --- synapse/app/event_creator.py | 8 ++++ synapse/handlers/profile.py | 26 ++++++++++--- synapse/replication/http/__init__.py | 3 +- synapse/replication/http/profile.py | 71 ++++++++++++++++++++++++++++++++++++ synapse/storage/profile.py | 6 ++- 5 files changed, 106 insertions(+), 8 deletions(-) create mode 100644 synapse/replication/http/profile.py (limited to 'synapse/storage') diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index 374f115644..0c56fc777b 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -45,6 +45,11 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1.profile import ( + ProfileAvatarURLRestServlet, + ProfileDisplaynameRestServlet, + ProfileRestServlet, +) from synapse.rest.client.v1.room import ( JoinRoomAliasServlet, RoomMembershipRestServlet, @@ -101,6 +106,9 @@ class EventCreatorServer(HomeServer): RoomMembershipRestServlet(self).register(resource) RoomStateEventRestServlet(self).register(resource) JoinRoomAliasServlet(self).register(resource) + ProfileAvatarURLRestServlet(self).register(resource) + ProfileDisplaynameRestServlet(self).register(resource) + ProfileRestServlet(self).register(resource) resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index cb5c6d587e..a3bdb1830f 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, CodeMessageException, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http.profile import ReplicationHandleProfileChangeRestServlet from synapse.types import UserID, get_domain_from_id from ._base import BaseHandler @@ -45,6 +46,10 @@ class ProfileHandler(BaseHandler): self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, ) + self._notify_master_profile_change = ( + ReplicationHandleProfileChangeRestServlet.make_client(hs) + ) + @defer.inlineCallbacks def get_profile(self, user_id): target_user = UserID.from_string(user_id) @@ -147,10 +152,16 @@ class ProfileHandler(BaseHandler): ) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( - target_user.to_string(), profile - ) + if self.hs.config.worker_app is None: + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + target_user.to_string(), profile + ) + else: + yield self._notify_master_profile_change( + requester=requester, + user_id=target_user.to_string(), + ) yield self._update_join_states(requester, target_user) @@ -196,11 +207,16 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_avatar_url ) - if self.hs.config.user_directory_search_all_users: + if self.hs.config.worker_app is None: profile = yield self.store.get_profileinfo(target_user.localpart) yield self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) + else: + yield self._notify_master_profile_change( + requester=requester, + user_id=target_user.to_string(), + ) yield self._update_join_states(requester, target_user) diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 589ee94c66..1fbbdf53e9 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.http.server import JsonResource -from synapse.replication.http import membership, send_event +from synapse.replication.http import membership, profile, send_event REPLICATION_PREFIX = "/_synapse/replication" @@ -27,3 +27,4 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs): send_event.register_servlets(hs, self) membership.register_servlets(hs, self) + profile.register_servlets(hs, self) diff --git a/synapse/replication/http/profile.py b/synapse/replication/http/profile.py new file mode 100644 index 0000000000..c4d54c936f --- /dev/null +++ b/synapse/replication/http/profile.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import Requester, UserID + +logger = logging.getLogger(__name__) + + +class ReplicationHandleProfileChangeRestServlet(ReplicationEndpoint): + NAME = "profile_changed" + PATH_ARGS = ("user_id",) + POST = True + + def __init__(self, hs): + super(ReplicationHandleProfileChangeRestServlet, self).__init__(hs) + + self.user_directory_handler = hs.get_user_directory_handler() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @staticmethod + def _serialize_payload(requester, user_id): + """ + Args: + requester (Requester) + user_id (str) + """ + + return { + "requester": requester.serialize(), + } + + @defer.inlineCallbacks + def _handle_request(self, request, user_id): + content = parse_json_object_from_request(request) + + requester = Requester.deserialize(self.store, content["requester"]) + + if requester.user: + request.authenticated_entity = requester.user.to_string() + + target_user = UserID.from_string(user_id) + + profile = yield self.store.get_profileinfo(target_user.localpart) + yield self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + ReplicationHandleProfileChangeRestServlet(hs).register(http_server) diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 60295da254..246ab836bd 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - -class ProfileStore(ProfileWorkerStore): def create_profile(self, user_localpart): return self._simple_insert( table="profiles", @@ -182,3 +180,7 @@ class ProfileStore(ProfileWorkerStore): if res: defer.returnValue(True) + + +class ProfileStore(ProfileWorkerStore): + pass -- cgit 1.5.1 From cd9765805e2792ad7d6b0b014447f51d2b9add8c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Aug 2018 10:48:31 +0100 Subject: Allow ratelimiting on workers --- synapse/storage/room.py | 58 ++++++++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 29 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 3147fb6827..1e473cc188 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -170,6 +170,35 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + defer.returnValue(RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + )) + else: + defer.returnValue(None) + class RoomStore(RoomWorkerStore, SearchStore): @@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore): "get_all_new_public_rooms", get_all_new_public_rooms ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): - """Check if there are any overrides for ratelimiting for the given - user - - Args: - user_id (str) - - Returns: - RatelimitOverride if there is an override, else None. If the contents - of RatelimitOverride are None or 0 then ratelimitng has been - disabled for that user entirely. - """ - row = yield self._simple_select_one( - table="ratelimit_override", - keyvalues={"user_id": user_id}, - retcols=("messages_per_second", "burst_count"), - allow_none=True, - desc="get_ratelimit_for_user", - ) - - if row: - defer.returnValue(RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - )) - else: - defer.returnValue(None) - @defer.inlineCallbacks def block_room(self, room_id, user_id): yield self._simple_insert( -- cgit 1.5.1 From e8eba2b4e3a99d35f08c96205ebb18211adddcb9 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 7 Aug 2018 17:49:43 +0100 Subject: implement reserved users for mau limits --- synapse/app/homeserver.py | 6 +++ synapse/config/server.py | 2 +- synapse/storage/monthly_active_users.py | 45 +++++++++++++++++------ tests/storage/test_monthly_active_users.py | 59 +++++++++++++++++++++++++++++- 4 files changed, 99 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 3a67db8b30..a4a65e7286 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -518,6 +518,8 @@ def run(hs): # If you increase the loop period, the accuracy of user_daily_visits # table will decrease clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) + + # monthly active user limiting functionality clock.looping_call( hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60 ) @@ -530,9 +532,13 @@ def run(hs): current_mau_gauge.set(float(count)) max_mau_value_gauge.set(float(hs.config.max_mau_value)) + hs.get_datastore().initialise_reserved_users( + hs.config.mau_limits_reserved_threepids + ) generate_monthly_active_users() if hs.config.limit_usage_by_mau: clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) + # End of monthly active user settings if hs.config.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") diff --git a/synapse/config/server.py b/synapse/config/server.py index cacac253ab..114d7a9815 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -75,7 +75,7 @@ class ServerConfig(Config): "max_mau_value", 0, ) self.mau_limits_reserved_threepids = config.get( - "mau_limit_reserved_threepid", [] + "mau_limit_reserved_threepids", [] ) # FIXME: federation_domain_whitelist needs sytests diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 6a37d6fc22..168f564ed5 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from twisted.internet import defer @@ -19,6 +20,7 @@ from synapse.util.caches.descriptors import cached from ._base import SQLBaseStore +logger = logging.getLogger(__name__) # Number of msec of granularity to store the monthly_active_user timestamp # This means it is not necessary to update the table on every request @@ -26,24 +28,31 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersStore(SQLBaseStore): - @defer.inlineCallbacks def __init__(self, dbconn, hs): super(MonthlyActiveUsersStore, self).__init__(None, hs) self._clock = hs.get_clock() self.hs = hs - threepids = self.hs.config.mau_limits_reserved_threepids - self.reserved_user_ids = set() + self.reserved_users = () + + @defer.inlineCallbacks + def initialise_reserved_users(self, threepids): + # TODO Why can't I do this in init? + store = self.hs.get_datastore() + reserved_user_list = [] for tp in threepids: - user_id = yield hs.get_datastore().get_user_id_by_threepid( + user_id = yield store.get_user_id_by_threepid( tp["medium"], tp["address"] ) if user_id: - self.reserved_user_ids.add(user_id) + self.upsert_monthly_active_user(user_id) + reserved_user_list.append(user_id) else: logger.warning( "mau limit reserved threepid %s not found in db" % tp ) + self.reserved_users = tuple(reserved_user_list) + @defer.inlineCallbacks def reap_monthly_active_users(self): """ Cleans out monthly active user table to ensure that no stale @@ -58,8 +67,20 @@ class MonthlyActiveUsersStore(SQLBaseStore): int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) ) # Purge stale users - sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" - txn.execute(sql, (thirty_days_ago,)) + + # questionmarks is a hack to overcome sqlite not supporting + # tuples in 'WHERE IN %s' + questionmarks = '?' * len(self.reserved_users) + query_args = [thirty_days_ago] + query_args.extend(self.reserved_users) + + sql = """ + DELETE FROM monthly_active_users + WHERE timestamp < ? + AND user_id NOT IN ({}) + """.format(','.join(questionmarks)) + + txn.execute(sql, query_args) # If MAU user count still exceeds the MAU threshold, then delete on # a least recently active basis. @@ -69,6 +90,8 @@ class MonthlyActiveUsersStore(SQLBaseStore): # While Postgres does not require 'LIMIT', but also does not support # negative LIMIT values. So there is no way to write it that both can # support + query_args = [self.hs.config.max_mau_value] + query_args.extend(self.reserved_users) sql = """ DELETE FROM monthly_active_users WHERE user_id NOT IN ( @@ -76,8 +99,9 @@ class MonthlyActiveUsersStore(SQLBaseStore): ORDER BY timestamp DESC LIMIT ? ) - """ - txn.execute(sql, (self.hs.config.max_mau_value,)) + AND user_id NOT IN ({}) + """.format(','.join(questionmarks)) + txn.execute(sql, query_args) yield self.runInteraction("reap_monthly_active_users", _reap_users) # It seems poor to invalidate the whole cache, Postgres supports @@ -136,7 +160,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): Arguments: user_id (str): user to add/update Return: - int : timestamp since last seen, None if never seen + Deferred[int] : timestamp since last seen, None if never seen """ @@ -158,7 +182,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): Args: user_id(str): the user_id to query """ - if self.hs.config.limit_usage_by_mau: last_seen_timestamp = yield self._user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 0bfd24a7fb..cbd480cd42 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -19,6 +19,8 @@ import tests.unittest import tests.utils from tests.utils import setup_test_homeserver +FORTY_DAYS = 40 * 24 * 60 * 60 + class MonthlyActiveUsersTestCase(tests.unittest.TestCase): def __init__(self, *args, **kwargs): @@ -29,6 +31,56 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): self.hs = yield setup_test_homeserver() self.store = self.hs.get_datastore() + @defer.inlineCallbacks + def test_initialise_reserved_users(self): + + user1 = "@user1:server" + user1_email = "user1@matrix.org" + user2 = "@user2:server" + user2_email = "user2@matrix.org" + threepids = [ + {'medium': 'email', 'address': user1_email}, + {'medium': 'email', 'address': user2_email} + ] + user_num = len(threepids) + + yield self.store.register( + user_id=user1, + token="123", + password_hash=None) + + yield self.store.register( + user_id=user2, + token="456", + password_hash=None) + + now = int(self.hs.get_clock().time_msec()) + yield self.store.user_add_threepid(user1, "email", user1_email, now, now) + yield self.store.user_add_threepid(user2, "email", user2_email, now, now) + yield self.store.initialise_reserved_users(threepids) + + active_count = yield self.store.get_monthly_active_count() + + # Test total counts + self.assertEquals(active_count, user_num) + + # Test user is marked as active + + timestamp = yield self.store._user_last_seen_monthly_active(user1) + self.assertTrue(timestamp) + timestamp = yield self.store._user_last_seen_monthly_active(user2) + self.assertTrue(timestamp) + + # Test that users are never removed from the db. + self.hs.config.max_mau_value = 0 + + self.hs.get_clock().advance_time(FORTY_DAYS) + + yield self.store.reap_monthly_active_users() + + active_count = yield self.store.get_monthly_active_count() + self.assertEquals(active_count, user_num) + @defer.inlineCallbacks def test_can_insert_and_count_mau(self): count = yield self.store.get_monthly_active_count() @@ -63,4 +115,9 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): self.assertTrue(count, initial_users) yield self.store.reap_monthly_active_users() count = yield self.store.get_monthly_active_count() - self.assertTrue(count, initial_users - self.hs.config.max_mau_value) + self.assertEquals(count, initial_users - self.hs.config.max_mau_value) + + self.hs.get_clock().advance_time(FORTY_DAYS) + yield self.store.reap_monthly_active_users() + count = yield self.store.get_monthly_active_count() + self.assertEquals(count, 0) -- cgit 1.5.1 From ef3589063adddd1eee89f136e6a54250cce414a1 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 7 Aug 2018 17:59:27 +0100 Subject: prevent total number of reserved users being too large --- synapse/storage/monthly_active_users.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 168f564ed5..54de5a8686 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -39,7 +39,9 @@ class MonthlyActiveUsersStore(SQLBaseStore): # TODO Why can't I do this in init? store = self.hs.get_datastore() reserved_user_list = [] - for tp in threepids: + + # Do not add more reserved users than the total allowable number + for tp in threepids[:self.hs.config.max_mau_value]: user_id = yield store.get_user_id_by_threepid( tp["medium"], tp["address"] ) -- cgit 1.5.1 From 312ae747466f7c60c0148b75446f357263330148 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Wed, 8 Aug 2018 13:33:16 +0100 Subject: typos --- synapse/storage/monthly_active_users.py | 4 ++-- synapse/storage/schema/delta/51/monthly_active_users.sql | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index abe1e6bb99..8b3beaf26a 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -50,7 +50,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): # If MAU user count still exceeds the MAU threshold, then delete on # a least recently active basis. # Note it is not possible to write this query using OFFSET due to - # incompatibilities in how sqlite an postgres support the feature. + # incompatibilities in how sqlite and postgres support the feature. # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present # While Postgres does not require 'LIMIT', but also does not support # negative LIMIT values. So there is no way to write it that both can @@ -78,7 +78,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): @cached(num_args=0) def get_monthly_active_count(self): - """Generates current count of monthly active users.abs + """Generates current count of monthly active users Returns: Defered[int]: Number of current monthly active users diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql index 10aac90ce1..c9d537d5a3 100644 --- a/synapse/storage/schema/delta/51/monthly_active_users.sql +++ b/synapse/storage/schema/delta/51/monthly_active_users.sql @@ -18,7 +18,7 @@ CREATE TABLE monthly_active_users ( user_id TEXT NOT NULL, -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting -- on updates, Granularity of updates governed by - -- syanpse.storage.monthly_active_users.LAST_SEEN_GRANULARITY + -- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY -- Measured in ms since epoch. timestamp BIGINT NOT NULL ); -- cgit 1.5.1 From ce6db0e5473e31292567b11599eb334c3275c564 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 8 Aug 2018 17:01:57 +0100 Subject: Choose state algorithm based on room version --- synapse/handlers/federation.py | 8 +++- synapse/handlers/room_member.py | 5 +- synapse/state/__init__.py | 104 +++++++++++++++++++++++++++++++++++----- synapse/storage/events.py | 4 +- 4 files changed, 105 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0dffd44e22..75a819dd11 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -274,8 +274,9 @@ class FederationHandler(BaseHandler): ev_ids, get_prev_content=False, check_redacted=False ) + room_version = yield self.store.get_room_version(pdu.room_id) state_map = yield resolve_events_with_factory( - state_groups, {pdu.event_id: pdu}, fetch + room_version, state_groups, {pdu.event_id: pdu}, fetch ) state = (yield self.store.get_events(state_map.values())).values() @@ -1811,7 +1812,10 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d in different_events if d }) - new_state = self.state_handler.resolve_events( + room_version = yield self.store.get_room_version(event.room_id) + + new_state = yield self.state_handler.resolve_events( + room_version, [list(local_view.values()), list(remote_view.values())], event ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0d4a3f4677..b6010bb41e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -341,9 +341,10 @@ class RoomMemberHandler(object): prev_events_and_hashes = yield self.store.get_prev_events_for_room( room_id, ) - latest_event_ids = ( + latest_event_ids = [ event_id for (event_id, _, _) in prev_events_and_hashes - ) + ] + current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 8c091d07c9..222daa0b28 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -23,16 +23,15 @@ from frozendict import frozendict from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, RoomVersions from synapse.events.snapshot import EventContext +from synapse.state import v1 from synapse.util.async import Linearizer from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.metrics import Measure -from .v1 import resolve_events_with_factory, resolve_events_with_state_map - logger = logging.getLogger(__name__) @@ -263,8 +262,14 @@ class StateHandler(object): defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") + if event.type == EventTypes.Create: + room_version = event.content.get("room_version", RoomVersions.V1) + else: + room_version = None + entry = yield self.resolve_state_groups_for_events( event.room_id, [e for e, _ in event.prev_events], + explicit_room_version=room_version, ) prev_state_ids = entry.state @@ -332,13 +337,17 @@ class StateHandler(object): defer.returnValue(context) @defer.inlineCallbacks - def resolve_state_groups_for_events(self, room_id, event_ids): + def resolve_state_groups_for_events(self, room_id, event_ids, + explicit_room_version=None): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. Args: - room_id (str): - event_ids (list[str]): + room_id (str) + event_ids (list[str]) + explicit_room_version (str|None): If set uses the the given room + version to choose the resolution algorithm. If None, then + checks the database for room version. Returns: Deferred[_StateCacheEntry]: resolved state @@ -364,8 +373,13 @@ class StateHandler(object): delta_ids=delta_ids, )) + room_version = explicit_room_version + if not room_version: + room_version = yield self.store.get_room_version(room_id) + result = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups_ids, None, self._state_map_factory, + room_id, room_version, state_groups_ids, None, + self._state_map_factory, ) defer.returnValue(result) @@ -374,7 +388,8 @@ class StateHandler(object): ev_ids, get_prev_content=False, check_redacted=False, ) - def resolve_events(self, state_sets, event): + @defer.inlineCallbacks + def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -389,14 +404,18 @@ class StateHandler(object): for ev in st } + room_version = yield self.store.get_room_version(event.room_id) + with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events_with_state_map(state_set_ids, state_map) + new_state = resolve_events_with_state_map( + room_version, state_set_ids, state_map, + ) new_state = { key: state_map[ev_id] for key, ev_id in iteritems(new_state) } - return new_state + defer.returnValue(new_state) class StateResolutionHandler(object): @@ -429,7 +448,7 @@ class StateResolutionHandler(object): @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, state_groups_ids, event_map, state_map_factory, + self, room_id, room_version, state_groups_ids, event_map, state_map_factory, ): """Resolves conflicts between a set of state groups @@ -438,6 +457,7 @@ class StateResolutionHandler(object): Args: room_id (str): room we are resolving for (used for logging) + room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group (where 'state' is a map from state key to event id) @@ -491,6 +511,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_factory( + room_version, list(itervalues(state_groups_ids)), event_map=event_map, state_map_factory=state_map_factory, @@ -572,3 +593,64 @@ def _make_state_cache_entry( prev_group=prev_group, delta_ids=delta_ids, ) + + +def resolve_events_with_state_map(room_version, state_sets, state_map): + """ + Args: + room_version(str): Version of the room + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + state_map(dict): a dict from event_id to event, for all events in + state_sets. + + Returns + dict[(str, str), str]: + a map from (type, state_key) to event_id. + """ + if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): + return v1.resolve_events_with_state_map( + state_sets, state_map, + ) + else: + # This should only happen if we added a version but forgot to add it to + # the list above. + raise Exception( + "No state resolution algorithm defined for version %r" % (room_version,) + ) + + +def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory): + """ + Args: + room_version(str): Version of the room + + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_map_factory. + + state_map_factory(func): will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): + return v1.resolve_events_with_factory( + state_sets, event_map, state_map_factory, + ) + else: + # This should only happen if we added a version but forgot to add it to + # the list above. + raise Exception( + "No state resolution algorithm defined for version %r" % (room_version,) + ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ce32e8fefd..dc68d365c2 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore } events_map = {ev.event_id: ev for ev, _ in events_context} + room_version = yield self.get_room_version(room_id) + logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups, events_map, get_events + room_id, room_version, state_groups, events_map, get_events ) defer.returnValue((res.state, None)) -- cgit 1.5.1 From 885ea9c602ca4002385a6d73c94c35e0958de809 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 9 Aug 2018 18:02:12 +0100 Subject: rename _user_last_seen_monthly_active --- synapse/api/auth.py | 2 +- synapse/storage/monthly_active_users.py | 10 +++++----- tests/storage/test_client_ips.py | 12 ++++++------ tests/storage/test_monthly_active_users.py | 13 +++++++------ 4 files changed, 19 insertions(+), 18 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index df6022ff69..c31c6a6a08 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -789,7 +789,7 @@ class Auth(object): if self.hs.config.limit_usage_by_mau is True: # If the user is already part of the MAU cohort if user_id: - timestamp = yield self.store._user_last_seen_monthly_active(user_id) + timestamp = yield self.store.user_last_seen_monthly_active(user_id) if timestamp: return # Else if there is no room in the MAU bucket, bail diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index d47dcef3a0..07211432af 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -113,7 +113,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): # is racy. # Have resolved to invalidate the whole cache for now and do # something about it if and when the perf becomes significant - self._user_last_seen_monthly_active.invalidate_all() + self.user_last_seen_monthly_active.invalidate_all() self.get_monthly_active_count.invalidate_all() @cached(num_args=0) @@ -152,11 +152,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): lock=False, ) if is_insert: - self._user_last_seen_monthly_active.invalidate((user_id,)) + self.user_last_seen_monthly_active.invalidate((user_id,)) self.get_monthly_active_count.invalidate(()) @cached(num_args=1) - def _user_last_seen_monthly_active(self, user_id): + def user_last_seen_monthly_active(self, user_id): """ Checks if a given user is part of the monthly active user group Arguments: @@ -173,7 +173,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): }, retcol="timestamp", allow_none=True, - desc="_user_last_seen_monthly_active", + desc="user_last_seen_monthly_active", )) @defer.inlineCallbacks @@ -185,7 +185,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): user_id(str): the user_id to query """ if self.hs.config.limit_usage_by_mau: - last_seen_timestamp = yield self._user_last_seen_monthly_active(user_id) + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() # We want to reduce to the total number of db writes, and are happy diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 7a58c6eb24..6d2bb3058e 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -64,7 +64,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertFalse(active) @defer.inlineCallbacks @@ -80,7 +80,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertFalse(active) @defer.inlineCallbacks @@ -88,13 +88,13 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertFalse(active) yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertTrue(active) @defer.inlineCallbacks @@ -103,7 +103,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertFalse(active) yield self.store.insert_client_ip( @@ -112,5 +112,5 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): yield self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id", ) - active = yield self.store._user_last_seen_monthly_active(user_id) + active = yield self.store.user_last_seen_monthly_active(user_id) self.assertTrue(active) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index cbd480cd42..be74c30218 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -66,9 +66,9 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): # Test user is marked as active - timestamp = yield self.store._user_last_seen_monthly_active(user1) + timestamp = yield self.store.user_last_seen_monthly_active(user1) self.assertTrue(timestamp) - timestamp = yield self.store._user_last_seen_monthly_active(user2) + timestamp = yield self.store.user_last_seen_monthly_active(user2) self.assertTrue(timestamp) # Test that users are never removed from the db. @@ -92,17 +92,18 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): self.assertEqual(1, count) @defer.inlineCallbacks - def test__user_last_seen_monthly_active(self): + def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = yield self.store._user_last_seen_monthly_active(user_id1) + + result = yield self.store.user_last_seen_monthly_active(user_id1) self.assertFalse(result == 0) yield self.store.upsert_monthly_active_user(user_id1) yield self.store.upsert_monthly_active_user(user_id2) - result = yield self.store._user_last_seen_monthly_active(user_id1) + result = yield self.store.user_last_seen_monthly_active(user_id1) self.assertTrue(result > 0) - result = yield self.store._user_last_seen_monthly_active(user_id3) + result = yield self.store.user_last_seen_monthly_active(user_id3) self.assertFalse(result == 0) @defer.inlineCallbacks -- cgit 1.5.1 From b37c4724191995eec4f5bc9d64f176b2a315f777 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 10 Aug 2018 23:50:21 +1000 Subject: Rename async to async_helpers because `async` is a keyword on Python 3.7 (#3678) --- changelog.d/3678.misc | 1 + synapse/app/federation_sender.py | 2 +- synapse/federation/federation_server.py | 8 +- synapse/handlers/device.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/handlers/read_marker.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/http/client.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/notifier.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/mailer.py | 2 +- synapse/rest/client/transactions.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/state.py | 2 +- synapse/storage/events.py | 2 +- synapse/storage/roommember.py | 2 +- synapse/util/async.py | 415 -------------------------- synapse/util/async_helpers.py | 415 ++++++++++++++++++++++++++ synapse/util/caches/descriptors.py | 2 +- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/snapshot_cache.py | 2 +- synapse/util/logcontext.py | 2 +- tests/storage/test__base.py | 2 +- tests/util/test_linearizer.py | 2 +- tests/util/test_rwlock.py | 2 +- 34 files changed, 450 insertions(+), 449 deletions(-) create mode 100644 changelog.d/3678.misc delete mode 100644 synapse/util/async.py create mode 100644 synapse/util/async_helpers.py (limited to 'synapse/storage') diff --git a/changelog.d/3678.misc b/changelog.d/3678.misc new file mode 100644 index 0000000000..0d7c8da64a --- /dev/null +++ b/changelog.d/3678.misc @@ -0,0 +1 @@ +Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7. diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index e082d45c94..7bbf0ad082 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -40,7 +40,7 @@ from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b62f687b6..a23136784a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -40,7 +40,7 @@ from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name from synapse.types import get_domain_from_id -from synapse.util import async +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.util.logutils import log_function @@ -67,8 +67,8 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler - self._server_linearizer = async.Linearizer("fed_server") - self._transaction_linearizer = async.Linearizer("fed_txn_handler") + self._server_linearizer = Linearizer("fed_server") + self._transaction_linearizer = Linearizer("fed_txn_handler") self.transaction_actions = TransactionActions(self.store) @@ -200,7 +200,7 @@ class FederationServer(FederationBase): event_id, f.getTraceback().rstrip(), ) - yield async.concurrently_execute( + yield concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2d44f15da3..9e017116a9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import FederationDeniedError from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import stringutils -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0dffd44e22..2380d17f4e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -52,7 +52,7 @@ from synapse.events.validator import EventValidator from synapse.state import resolve_events_with_factory from synapse.types import UserID, get_domain_from_id from synapse.util import logcontext, unwrapFirstError -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room from synapse.util.frozenutils import unfreeze from synapse.util.logutils import log_function diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 40e7580a61..1fb17fd9a5 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index bcb093ba3e..01a362360e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -32,7 +32,7 @@ from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.types import RoomAlias, UserID -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index b2849783ed..a97d43550f 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -22,7 +22,7 @@ from synapse.api.constants import Membership from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event from synapse.types import RoomStreamToken -from synapse.util.async import ReadWriteLock +from synapse.util.async_helpers import ReadWriteLock from synapse.util.logcontext import run_in_background from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3732830194..20fc3b0323 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -36,7 +36,7 @@ from synapse.api.errors import SynapseError from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.logcontext import run_in_background from synapse.util.logutils import log_function diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 995460f82a..32108568c6 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from ._base import BaseHandler diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0e16bbe0ee..3526b20d5a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -28,7 +28,7 @@ from synapse.api.errors import ( ) from synapse.http.client import CaptchaServerHttpClient from synapse.types import RoomAlias, RoomID, UserID, create_requester -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.threepids import check_3pid_allowed from ._base import BaseHandler diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 828229f5c3..37e41afd61 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.types import ThirdPartyInstanceID -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0d4a3f4677..fb94b5d7d4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -30,7 +30,7 @@ import synapse.types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.types import RoomID, UserID -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room logger = logging.getLogger(__name__) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dff1f67dcb..6393a9674b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.push.clientformat import format_push_rules_for_user from synapse.types import RoomStreamToken -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.response_cache import ResponseCache diff --git a/synapse/http/client.py b/synapse/http/client.py index 3771e0b3f6..ab4fbf59b2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -42,7 +42,7 @@ from twisted.web.http_headers import Headers from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import cancelled_to_request_timed_out_error, redact_uri from synapse.http.endpoint import SpiderEndpoint -from synapse.util.async import add_timeout_to_deferred +from synapse.util.async_helpers import add_timeout_to_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 762273f59b..44b61e70a4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -43,7 +43,7 @@ from synapse.api.errors import ( from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint from synapse.util import logcontext -from synapse.util.async import add_timeout_to_deferred +from synapse.util.async_helpers import add_timeout_to_deferred from synapse.util.logcontext import make_deferred_yieldable logger = logging.getLogger(__name__) diff --git a/synapse/notifier.py b/synapse/notifier.py index e650c3e494..82f391481c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -25,7 +25,7 @@ from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state from synapse.metrics import LaterGauge from synapse.types import StreamToken -from synapse.util.async import ( +from synapse.util.async_helpers import ( DeferredTimeoutError, ObservableDeferred, add_timeout_to_deferred, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 1d14d3639c..8f9a76147f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.event_auth import get_user_power_level from synapse.state import POWER_KEY -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import register_cache from synapse.util.caches.descriptors import cached diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 9d601208fd..bfa6df7b68 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -35,7 +35,7 @@ from synapse.push.presentable_names import ( name_from_member_event, ) from synapse.types import UserID -from synapse.util.async import concurrently_execute +from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 00b1b3066e..511e96ab00 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -17,7 +17,7 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 8fb413d825..4c589e05e0 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -36,7 +36,7 @@ from synapse.api.errors import ( ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.logcontext import make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import is_ascii, random_string diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 27aa0def2f..778ef97337 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -42,7 +42,7 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_integer, parse_string from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import is_ascii, random_string diff --git a/synapse/state.py b/synapse/state.py index e1092b97a9..8b92d4057a 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -28,7 +28,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ce32e8fefd..d4aa192a0a 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 10dce21cea..9b4e6d6aa8 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.stringutils import to_ascii diff --git a/synapse/util/async.py b/synapse/util/async.py deleted file mode 100644 index a7094e2fb4..0000000000 --- a/synapse/util/async.py +++ /dev/null @@ -1,415 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import collections -import logging -from contextlib import contextmanager - -from six.moves import range - -from twisted.internet import defer -from twisted.internet.defer import CancelledError -from twisted.python import failure - -from synapse.util import Clock, logcontext, unwrapFirstError - -from .logcontext import ( - PreserveLoggingContext, - make_deferred_yieldable, - run_in_background, -) - -logger = logging.getLogger(__name__) - - -class ObservableDeferred(object): - """Wraps a deferred object so that we can add observer deferreds. These - observer deferreds do not affect the callback chain of the original - deferred. - - If consumeErrors is true errors will be captured from the origin deferred. - - Cancelling or otherwise resolving an observer will not affect the original - ObservableDeferred. - - NB that it does not attempt to do anything with logcontexts; in general - you should probably make_deferred_yieldable the deferreds - returned by `observe`, and ensure that the original deferred runs its - callbacks in the sentinel logcontext. - """ - - __slots__ = ["_deferred", "_observers", "_result"] - - def __init__(self, deferred, consumeErrors=False): - object.__setattr__(self, "_deferred", deferred) - object.__setattr__(self, "_result", None) - object.__setattr__(self, "_observers", set()) - - def callback(r): - object.__setattr__(self, "_result", (True, r)) - while self._observers: - try: - # TODO: Handle errors here. - self._observers.pop().callback(r) - except Exception: - pass - return r - - def errback(f): - object.__setattr__(self, "_result", (False, f)) - while self._observers: - try: - # TODO: Handle errors here. - self._observers.pop().errback(f) - except Exception: - pass - - if consumeErrors: - return None - else: - return f - - deferred.addCallbacks(callback, errback) - - def observe(self): - """Observe the underlying deferred. - - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. - """ - if not self._result: - d = defer.Deferred() - - def remove(r): - self._observers.discard(d) - return r - d.addBoth(remove) - - self._observers.add(d) - return d - else: - success, res = self._result - return res if success else defer.fail(res) - - def observers(self): - return self._observers - - def has_called(self): - return self._result is not None - - def has_succeeded(self): - return self._result is not None and self._result[0] is True - - def get_result(self): - return self._result[1] - - def __getattr__(self, name): - return getattr(self._deferred, name) - - def __setattr__(self, name, value): - setattr(self._deferred, name, value) - - def __repr__(self): - return "" % ( - id(self), self._result, self._deferred, - ) - - -def concurrently_execute(func, args, limit): - """Executes the function with each argument conncurrently while limiting - the number of concurrent executions. - - Args: - func (func): Function to execute, should return a deferred. - args (list): List of arguments to pass to func, each invocation of func - gets a signle argument. - limit (int): Maximum number of conccurent executions. - - Returns: - deferred: Resolved when all function invocations have finished. - """ - it = iter(args) - - @defer.inlineCallbacks - def _concurrently_execute_inner(): - try: - while True: - yield func(next(it)) - except StopIteration: - pass - - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(_concurrently_execute_inner) - for _ in range(limit) - ], consumeErrors=True)).addErrback(unwrapFirstError) - - -class Linearizer(object): - """Limits concurrent access to resources based on a key. Useful to ensure - only a few things happen at a time on a given resource. - - Example: - - with (yield limiter.queue("test_key")): - # do some work. - - """ - def __init__(self, name=None, max_count=1, clock=None): - """ - Args: - max_count(int): The maximum number of concurrent accesses - """ - if name is None: - self.name = id(self) - else: - self.name = name - - if not clock: - from twisted.internet import reactor - clock = Clock(reactor) - self._clock = clock - self.max_count = max_count - - # key_to_defer is a map from the key to a 2 element list where - # the first element is the number of things executing, and - # the second element is an OrderedDict, where the keys are deferreds for the - # things blocked from executing. - self.key_to_defer = {} - - @defer.inlineCallbacks - def queue(self, key): - entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) - - # If the number of things executing is greater than the maximum - # then add a deferred to the list of blocked items - # When on of the things currently executing finishes it will callback - # this item so that it can continue executing. - if entry[0] >= self.max_count: - new_defer = defer.Deferred() - entry[1][new_defer] = 1 - - logger.info( - "Waiting to acquire linearizer lock %r for key %r", self.name, key, - ) - try: - yield make_deferred_yieldable(new_defer) - except Exception as e: - if isinstance(e, CancelledError): - logger.info( - "Cancelling wait for linearizer lock %r for key %r", - self.name, key, - ) - else: - logger.warn( - "Unexpected exception waiting for linearizer lock %r for key %r", - self.name, key, - ) - - # we just have to take ourselves back out of the queue. - del entry[1][new_defer] - raise - - logger.info("Acquired linearizer lock %r for key %r", self.name, key) - entry[0] += 1 - - # if the code holding the lock completes synchronously, then it - # will recursively run the next claimant on the list. That can - # relatively rapidly lead to stack exhaustion. This is essentially - # the same problem as http://twistedmatrix.com/trac/ticket/9304. - # - # In order to break the cycle, we add a cheeky sleep(0) here to - # ensure that we fall back to the reactor between each iteration. - # - # (This needs to happen while we hold the lock, and the context manager's exit - # code must be synchronous, so this is the only sensible place.) - yield self._clock.sleep(0) - - else: - logger.info( - "Acquired uncontended linearizer lock %r for key %r", self.name, key, - ) - entry[0] += 1 - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - logger.info("Releasing linearizer lock %r for key %r", self.name, key) - - # We've finished executing so check if there are any things - # blocked waiting to execute and start one of them - entry[0] -= 1 - - if entry[1]: - (next_def, _) = entry[1].popitem(last=False) - - # we need to run the next thing in the sentinel context. - with PreserveLoggingContext(): - next_def.callback(None) - elif entry[0] == 0: - # We were the last thing for this key: remove it from the - # map. - del self.key_to_defer[key] - - defer.returnValue(_ctx_manager()) - - -class ReadWriteLock(object): - """A deferred style read write lock. - - Example: - - with (yield read_write_lock.read("test_key")): - # do some work - """ - - # IMPLEMENTATION NOTES - # - # We track the most recent queued reader and writer deferreds (which get - # resolved when they release the lock). - # - # Read: We know its safe to acquire a read lock when the latest writer has - # been resolved. The new reader is appeneded to the list of latest readers. - # - # Write: We know its safe to acquire the write lock when both the latest - # writers and readers have been resolved. The new writer replaces the latest - # writer. - - def __init__(self): - # Latest readers queued - self.key_to_current_readers = {} - - # Latest writer queued - self.key_to_current_writer = {} - - @defer.inlineCallbacks - def read(self, key): - new_defer = defer.Deferred() - - curr_readers = self.key_to_current_readers.setdefault(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) - - curr_readers.add(new_defer) - - # We wait for the latest writer to finish writing. We can safely ignore - # any existing readers... as they're readers. - yield make_deferred_yieldable(curr_writer) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - new_defer.callback(None) - self.key_to_current_readers.get(key, set()).discard(new_defer) - - defer.returnValue(_ctx_manager()) - - @defer.inlineCallbacks - def write(self, key): - new_defer = defer.Deferred() - - curr_readers = self.key_to_current_readers.get(key, set()) - curr_writer = self.key_to_current_writer.get(key, None) - - # We wait on all latest readers and writer. - to_wait_on = list(curr_readers) - if curr_writer: - to_wait_on.append(curr_writer) - - # We can clear the list of current readers since the new writer waits - # for them to finish. - curr_readers.clear() - self.key_to_current_writer[key] = new_defer - - yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) - - @contextmanager - def _ctx_manager(): - try: - yield - finally: - new_defer.callback(None) - if self.key_to_current_writer[key] == new_defer: - self.key_to_current_writer.pop(key) - - defer.returnValue(_ctx_manager()) - - -class DeferredTimeoutError(Exception): - """ - This error is raised by default when a L{Deferred} times out. - """ - - -def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): - """ - Add a timeout to a deferred by scheduling it to be cancelled after - timeout seconds. - - This is essentially a backport of deferred.addTimeout, which was introduced - in twisted 16.5. - - If the deferred gets timed out, it errbacks with a DeferredTimeoutError, - unless a cancelable function was passed to its initialization or unless - a different on_timeout_cancel callable is provided. - - Args: - deferred (defer.Deferred): deferred to be timed out - timeout (Number): seconds to time out after - reactor (twisted.internet.reactor): the Twisted reactor to use - - on_timeout_cancel (callable): A callable which is called immediately - after the deferred times out, and not if this deferred is - otherwise cancelled before the timeout. - - It takes an arbitrary value, which is the value of the deferred at - that exact point in time (probably a CancelledError Failure), and - the timeout. - - The default callable (if none is provided) will translate a - CancelledError Failure into a DeferredTimeoutError. - """ - timed_out = [False] - - def time_it_out(): - timed_out[0] = True - deferred.cancel() - - delayed_call = reactor.callLater(timeout, time_it_out) - - def convert_cancelled(value): - if timed_out[0]: - to_call = on_timeout_cancel or _cancelled_to_timed_out_error - return to_call(value, timeout) - return value - - deferred.addBoth(convert_cancelled) - - def cancel_timeout(result): - # stop the pending call to cancel the deferred if it's been fired - if delayed_call.active(): - delayed_call.cancel() - return result - - deferred.addBoth(cancel_timeout) - - -def _cancelled_to_timed_out_error(value, timeout): - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise DeferredTimeoutError(timeout, "Deferred") - return value diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py new file mode 100644 index 0000000000..a7094e2fb4 --- /dev/null +++ b/synapse/util/async_helpers.py @@ -0,0 +1,415 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import logging +from contextlib import contextmanager + +from six.moves import range + +from twisted.internet import defer +from twisted.internet.defer import CancelledError +from twisted.python import failure + +from synapse.util import Clock, logcontext, unwrapFirstError + +from .logcontext import ( + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, +) + +logger = logging.getLogger(__name__) + + +class ObservableDeferred(object): + """Wraps a deferred object so that we can add observer deferreds. These + observer deferreds do not affect the callback chain of the original + deferred. + + If consumeErrors is true errors will be captured from the origin deferred. + + Cancelling or otherwise resolving an observer will not affect the original + ObservableDeferred. + + NB that it does not attempt to do anything with logcontexts; in general + you should probably make_deferred_yieldable the deferreds + returned by `observe`, and ensure that the original deferred runs its + callbacks in the sentinel logcontext. + """ + + __slots__ = ["_deferred", "_observers", "_result"] + + def __init__(self, deferred, consumeErrors=False): + object.__setattr__(self, "_deferred", deferred) + object.__setattr__(self, "_result", None) + object.__setattr__(self, "_observers", set()) + + def callback(r): + object.__setattr__(self, "_result", (True, r)) + while self._observers: + try: + # TODO: Handle errors here. + self._observers.pop().callback(r) + except Exception: + pass + return r + + def errback(f): + object.__setattr__(self, "_result", (False, f)) + while self._observers: + try: + # TODO: Handle errors here. + self._observers.pop().errback(f) + except Exception: + pass + + if consumeErrors: + return None + else: + return f + + deferred.addCallbacks(callback, errback) + + def observe(self): + """Observe the underlying deferred. + + Can return either a deferred if the underlying deferred is still pending + (or has failed), or the actual value. Callers may need to use maybeDeferred. + """ + if not self._result: + d = defer.Deferred() + + def remove(r): + self._observers.discard(d) + return r + d.addBoth(remove) + + self._observers.add(d) + return d + else: + success, res = self._result + return res if success else defer.fail(res) + + def observers(self): + return self._observers + + def has_called(self): + return self._result is not None + + def has_succeeded(self): + return self._result is not None and self._result[0] is True + + def get_result(self): + return self._result[1] + + def __getattr__(self, name): + return getattr(self._deferred, name) + + def __setattr__(self, name, value): + setattr(self._deferred, name, value) + + def __repr__(self): + return "" % ( + id(self), self._result, self._deferred, + ) + + +def concurrently_execute(func, args, limit): + """Executes the function with each argument conncurrently while limiting + the number of concurrent executions. + + Args: + func (func): Function to execute, should return a deferred. + args (list): List of arguments to pass to func, each invocation of func + gets a signle argument. + limit (int): Maximum number of conccurent executions. + + Returns: + deferred: Resolved when all function invocations have finished. + """ + it = iter(args) + + @defer.inlineCallbacks + def _concurrently_execute_inner(): + try: + while True: + yield func(next(it)) + except StopIteration: + pass + + return logcontext.make_deferred_yieldable(defer.gatherResults([ + run_in_background(_concurrently_execute_inner) + for _ in range(limit) + ], consumeErrors=True)).addErrback(unwrapFirstError) + + +class Linearizer(object): + """Limits concurrent access to resources based on a key. Useful to ensure + only a few things happen at a time on a given resource. + + Example: + + with (yield limiter.queue("test_key")): + # do some work. + + """ + def __init__(self, name=None, max_count=1, clock=None): + """ + Args: + max_count(int): The maximum number of concurrent accesses + """ + if name is None: + self.name = id(self) + else: + self.name = name + + if not clock: + from twisted.internet import reactor + clock = Clock(reactor) + self._clock = clock + self.max_count = max_count + + # key_to_defer is a map from the key to a 2 element list where + # the first element is the number of things executing, and + # the second element is an OrderedDict, where the keys are deferreds for the + # things blocked from executing. + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) + + # If the number of things executing is greater than the maximum + # then add a deferred to the list of blocked items + # When on of the things currently executing finishes it will callback + # this item so that it can continue executing. + if entry[0] >= self.max_count: + new_defer = defer.Deferred() + entry[1][new_defer] = 1 + + logger.info( + "Waiting to acquire linearizer lock %r for key %r", self.name, key, + ) + try: + yield make_deferred_yieldable(new_defer) + except Exception as e: + if isinstance(e, CancelledError): + logger.info( + "Cancelling wait for linearizer lock %r for key %r", + self.name, key, + ) + else: + logger.warn( + "Unexpected exception waiting for linearizer lock %r for key %r", + self.name, key, + ) + + # we just have to take ourselves back out of the queue. + del entry[1][new_defer] + raise + + logger.info("Acquired linearizer lock %r for key %r", self.name, key) + entry[0] += 1 + + # if the code holding the lock completes synchronously, then it + # will recursively run the next claimant on the list. That can + # relatively rapidly lead to stack exhaustion. This is essentially + # the same problem as http://twistedmatrix.com/trac/ticket/9304. + # + # In order to break the cycle, we add a cheeky sleep(0) here to + # ensure that we fall back to the reactor between each iteration. + # + # (This needs to happen while we hold the lock, and the context manager's exit + # code must be synchronous, so this is the only sensible place.) + yield self._clock.sleep(0) + + else: + logger.info( + "Acquired uncontended linearizer lock %r for key %r", self.name, key, + ) + entry[0] += 1 + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + logger.info("Releasing linearizer lock %r for key %r", self.name, key) + + # We've finished executing so check if there are any things + # blocked waiting to execute and start one of them + entry[0] -= 1 + + if entry[1]: + (next_def, _) = entry[1].popitem(last=False) + + # we need to run the next thing in the sentinel context. + with PreserveLoggingContext(): + next_def.callback(None) + elif entry[0] == 0: + # We were the last thing for this key: remove it from the + # map. + del self.key_to_defer[key] + + defer.returnValue(_ctx_manager()) + + +class ReadWriteLock(object): + """A deferred style read write lock. + + Example: + + with (yield read_write_lock.read("test_key")): + # do some work + """ + + # IMPLEMENTATION NOTES + # + # We track the most recent queued reader and writer deferreds (which get + # resolved when they release the lock). + # + # Read: We know its safe to acquire a read lock when the latest writer has + # been resolved. The new reader is appeneded to the list of latest readers. + # + # Write: We know its safe to acquire the write lock when both the latest + # writers and readers have been resolved. The new writer replaces the latest + # writer. + + def __init__(self): + # Latest readers queued + self.key_to_current_readers = {} + + # Latest writer queued + self.key_to_current_writer = {} + + @defer.inlineCallbacks + def read(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + curr_readers.add(new_defer) + + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + yield make_deferred_yieldable(curr_writer) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + self.key_to_current_readers.get(key, set()).discard(new_defer) + + defer.returnValue(_ctx_manager()) + + @defer.inlineCallbacks + def write(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) + + # We can clear the list of current readers since the new writer waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer + + yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + if self.key_to_current_writer[key] == new_defer: + self.key_to_current_writer.pop(key) + + defer.returnValue(_ctx_manager()) + + +class DeferredTimeoutError(Exception): + """ + This error is raised by default when a L{Deferred} times out. + """ + + +def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None): + """ + Add a timeout to a deferred by scheduling it to be cancelled after + timeout seconds. + + This is essentially a backport of deferred.addTimeout, which was introduced + in twisted 16.5. + + If the deferred gets timed out, it errbacks with a DeferredTimeoutError, + unless a cancelable function was passed to its initialization or unless + a different on_timeout_cancel callable is provided. + + Args: + deferred (defer.Deferred): deferred to be timed out + timeout (Number): seconds to time out after + reactor (twisted.internet.reactor): the Twisted reactor to use + + on_timeout_cancel (callable): A callable which is called immediately + after the deferred times out, and not if this deferred is + otherwise cancelled before the timeout. + + It takes an arbitrary value, which is the value of the deferred at + that exact point in time (probably a CancelledError Failure), and + the timeout. + + The default callable (if none is provided) will translate a + CancelledError Failure into a DeferredTimeoutError. + """ + timed_out = [False] + + def time_it_out(): + timed_out[0] = True + deferred.cancel() + + delayed_call = reactor.callLater(timeout, time_it_out) + + def convert_cancelled(value): + if timed_out[0]: + to_call = on_timeout_cancel or _cancelled_to_timed_out_error + return to_call(value, timeout) + return value + + deferred.addBoth(convert_cancelled) + + def cancel_timeout(result): + # stop the pending call to cancel the deferred if it's been fired + if delayed_call.active(): + delayed_call.cancel() + return result + + deferred.addBoth(cancel_timeout) + + +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise DeferredTimeoutError(timeout, "Deferred") + return value diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 861c24809c..187510576a 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,7 +25,7 @@ from six import itervalues, string_types from twisted.internet import defer from synapse.util import logcontext, unwrapFirstError -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index a8491b42d5..afb03b2e1b 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -16,7 +16,7 @@ import logging from twisted.internet import defer -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache from synapse.util.logcontext import make_deferred_yieldable, run_in_background diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py index d03678b8c8..8318db8d2c 100644 --- a/synapse/util/caches/snapshot_cache.py +++ b/synapse/util/caches/snapshot_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred class SnapshotCache(object): diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index 8dcae50b39..07e83fadda 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -526,7 +526,7 @@ _to_ignore = [ "synapse.util.logcontext", "synapse.http.server", "synapse.storage._base", - "synapse.util.async", + "synapse.util.async_helpers", ] diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 6d6f00c5c5..52eff2a104 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -18,7 +18,7 @@ from mock import Mock from twisted.internet import defer -from synapse.util.async import ObservableDeferred +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import Cache, cached from tests import unittest diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 4729bd5a0a..7f48a72de8 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -20,7 +20,7 @@ from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError from synapse.util import Clock, logcontext -from synapse.util.async import Linearizer +from synapse.util.async_helpers import Linearizer from tests import unittest diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 24194e3b25..7cd470be67 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -14,7 +14,7 @@ # limitations under the License. -from synapse.util.async import ReadWriteLock +from synapse.util.async_helpers import ReadWriteLock from tests import unittest -- cgit 1.5.1 From 31fa743567c3f70c66c7af2f1332560b04b78f7a Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Sat, 11 Aug 2018 22:38:34 +0100 Subject: fix sqlite/postgres incompatibility in reap_monthly_active_users --- synapse/storage/monthly_active_users.py | 44 +++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index d47dcef3a0..9a6e699386 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -64,23 +64,27 @@ class MonthlyActiveUsersStore(SQLBaseStore): Deferred[] """ def _reap_users(txn): + # Purge stale users thirty_days_ago = ( int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) ) - # Purge stale users - - # questionmarks is a hack to overcome sqlite not supporting - # tuples in 'WHERE IN %s' - questionmarks = '?' * len(self.reserved_users) query_args = [thirty_days_ago] - query_args.extend(self.reserved_users) - - sql = """ - DELETE FROM monthly_active_users - WHERE timestamp < ? - AND user_id NOT IN ({}) - """.format(','.join(questionmarks)) + base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" + + # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres + # when len(reserved_users) == 0. Works fine on sqlite. + if len(self.reserved_users) > 0: + # questionmarks is a hack to overcome sqlite not supporting + # tuples in 'WHERE IN %s' + questionmarks = '?' * len(self.reserved_users) + + query_args.extend(self.reserved_users) + sql = base_sql + """ AND user_id NOT IN ({})""".format( + ','.join(questionmarks) + ) + else: + sql = base_sql txn.execute(sql, query_args) @@ -93,16 +97,24 @@ class MonthlyActiveUsersStore(SQLBaseStore): # negative LIMIT values. So there is no way to write it that both can # support query_args = [self.hs.config.max_mau_value] - query_args.extend(self.reserved_users) - sql = """ + + base_sql = """ DELETE FROM monthly_active_users WHERE user_id NOT IN ( SELECT user_id FROM monthly_active_users ORDER BY timestamp DESC LIMIT ? ) - AND user_id NOT IN ({}) - """.format(','.join(questionmarks)) + """ + # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres + # when len(reserved_users) == 0. Works fine on sqlite. + if len(self.reserved_users) > 0: + query_args.extend(self.reserved_users) + sql = base_sql + """ AND user_id NOT IN ({})""".format( + ','.join(questionmarks) + ) + else: + sql = base_sql txn.execute(sql, query_args) yield self.runInteraction("reap_monthly_active_users", _reap_users) -- cgit 1.5.1 From 99dd975dae7baaaef2a3b0a92fa51965b121ae34 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Mon, 13 Aug 2018 16:47:46 +1000 Subject: Run tests under PostgreSQL (#3423) --- .travis.yml | 10 +++ changelog.d/3423.misc | 1 + synapse/handlers/presence.py | 5 ++ synapse/storage/client_ips.py | 5 ++ tests/__init__.py | 3 + tests/api/test_auth.py | 2 +- tests/api/test_filtering.py | 5 +- tests/crypto/test_keyring.py | 6 +- tests/handlers/test_auth.py | 2 +- tests/handlers/test_device.py | 2 +- tests/handlers/test_directory.py | 1 + tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_profile.py | 1 + tests/handlers/test_register.py | 1 + tests/handlers/test_typing.py | 1 + tests/replication/slave/storage/_base.py | 1 + tests/rest/client/v1/test_admin.py | 2 +- tests/rest/client/v1/test_events.py | 1 + tests/rest/client/v1/test_profile.py | 1 + tests/rest/client/v1/test_register.py | 2 +- tests/rest/client/v1/test_rooms.py | 1 + tests/rest/client/v1/test_typing.py | 1 + tests/rest/client/v2_alpha/test_filter.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 2 +- tests/rest/client/v2_alpha/test_sync.py | 2 +- tests/server.py | 10 ++- tests/storage/test_appservice.py | 13 ++- tests/storage/test_background_update.py | 4 +- tests/storage/test_client_ips.py | 2 +- tests/storage/test_devices.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_end_to_end_keys.py | 3 +- tests/storage/test_event_federation.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_keys.py | 2 +- tests/storage/test_monthly_active_users.py | 2 +- tests/storage/test_presence.py | 2 +- tests/storage/test_profile.py | 2 +- tests/storage/test_redaction.py | 2 +- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 4 +- tests/storage/test_roommember.py | 2 +- tests/storage/test_state.py | 2 +- tests/storage/test_user_directory.py | 2 +- tests/test_federation.py | 5 +- tests/test_server.py | 2 +- tests/test_visibility.py | 2 +- tests/utils.py | 133 ++++++++++++++++++++++++---- tox.ini | 20 ++++- 49 files changed, 227 insertions(+), 59 deletions(-) create mode 100644 changelog.d/3423.misc (limited to 'synapse/storage') diff --git a/.travis.yml b/.travis.yml index b34b17af75..318701c9f8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,6 +8,9 @@ before_script: - git remote set-branches --add origin develop - git fetch origin develop +services: + - postgresql + matrix: fast_finish: true include: @@ -20,6 +23,9 @@ matrix: - python: 2.7 env: TOX_ENV=py27 + - python: 2.7 + env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" + - python: 3.6 env: TOX_ENV=py36 @@ -29,6 +35,10 @@ matrix: - python: 3.6 env: TOX_ENV=check-newsfragment + allow_failures: + - python: 2.7 + env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" + install: - pip install tox diff --git a/changelog.d/3423.misc b/changelog.d/3423.misc new file mode 100644 index 0000000000..51768c6d14 --- /dev/null +++ b/changelog.d/3423.misc @@ -0,0 +1 @@ +The test suite now can run under PostgreSQL. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 20fc3b0323..3671d24f60 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -95,6 +95,7 @@ class PresenceHandler(object): Args: hs (synapse.server.HomeServer): """ + self.hs = hs self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.clock = hs.get_clock() @@ -230,6 +231,10 @@ class PresenceHandler(object): earlier than they should when synapse is restarted. This affect of this is some spurious presence changes that will self-correct. """ + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + logger.info( "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 2489527f2c..8fc678fa67 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -96,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) def _update_client_ips_batch(self): + + # If the DB pool has already terminated, don't try updating + if not self.hs.get_db_pool().running: + return + def update(): to_update = self._batch_row_update self._batch_row_update = {} diff --git a/tests/__init__.py b/tests/__init__.py index 24006c949e..9d9ca22829 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -15,4 +15,7 @@ from twisted.trial import util +from tests import utils + util.DEFAULT_TIMEOUT_DURATION = 10 +utils.setupdb() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index f8e28876bb..a65689ba89 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -39,7 +39,7 @@ class AuthTestCase(unittest.TestCase): self.state_handler = Mock() self.store = Mock() - self.hs = yield setup_test_homeserver(handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) self.hs.get_datastore = Mock(return_value=self.store) self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 1c2d71052c..48b2d3d663 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -46,7 +46,10 @@ class FilteringTestCase(unittest.TestCase): self.mock_http_client.put_json = DeferredMockCallable() hs = yield setup_test_homeserver( - handlers=None, http_client=self.mock_http_client, keyring=Mock() + self.addCleanup, + handlers=None, + http_client=self.mock_http_client, + keyring=Mock(), ) self.filtering = hs.get_filtering() diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 0c6f510d11..8299dc72c8 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -58,12 +58,10 @@ class KeyringTestCase(unittest.TestCase): self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() self.hs = yield utils.setup_test_homeserver( - handlers=None, http_client=self.http_client + self.addCleanup, handlers=None, http_client=self.http_client ) keys = self.mock_perspective_server.get_verify_keys() - self.hs.config.perspectives = { - self.mock_perspective_server.server_name: keys - } + self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys} def check_context(self, _, expected): self.assertEquals( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index ede01f8099..56c0f87fb7 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -35,7 +35,7 @@ class AuthHandlers(object): class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index d70d645504..56e7acd37c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -34,7 +34,7 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield utils.setup_test_homeserver() + hs = yield utils.setup_test_homeserver(self.addCleanup) self.handler = hs.get_device_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 06de9f5eca..ec7355688b 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -46,6 +46,7 @@ class DirectoryTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 57ab228455..8dccc6826e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - handlers=None, federation_client=mock.Mock() + self.addCleanup, handlers=None, federation_client=mock.Mock() ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 9268a6fe2b..62dc69003c 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -48,6 +48,7 @@ class ProfileTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, handlers=None, resource_for_federation=Mock(), diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index dbec81076f..d48d40c8dd 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -40,6 +40,7 @@ class RegistrationTestCase(unittest.TestCase): self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() self.hs = yield setup_test_homeserver( + self.addCleanup, handlers=None, http_client=None, expire_access_token=True, diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index becfa77bfa..ad58073a14 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -67,6 +67,7 @@ class TypingNotificationsTestCase(unittest.TestCase): self.state_handler = Mock() hs = yield setup_test_homeserver( + self.addCleanup, "test", auth=self.auth, clock=self.clock, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index c23b6e2cfd..65df116efc 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -54,6 +54,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver( + self.addCleanup, "blue", http_client=None, federation_client=Mock(), diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py index 67d9ab94e2..1a553fa3f9 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py @@ -51,7 +51,7 @@ class UserRegisterTestCase(unittest.TestCase): self.secrets = Mock() self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.config.registration_shared_secret = u"shared" diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 0316b74fa1..956f7fc4c4 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -41,6 +41,7 @@ class EventStreamPermissionsTestCase(RestTestCase): self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) hs = yield setup_test_homeserver( + self.addCleanup, http_client=None, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 9ba0ffc19f..1eab9c3bdb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase): ) hs = yield setup_test_homeserver( + self.addCleanup, "test", http_client=None, resource_for_client=self.mock_resource, diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 6f15d69ecd..4be88b8a39 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -49,7 +49,7 @@ class CreateUserServletTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_handlers = Mock(return_value=handlers) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 00fc796787..9fe0760496 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -50,6 +50,7 @@ class RoomBase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( + self.addCleanup, "red", http_client=None, clock=self.hs_clock, diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 7f1a435e7b..677265edf6 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -44,6 +44,7 @@ class RoomTypingTestCase(RestTestCase): self.auth_user_id = self.user_id hs = yield setup_test_homeserver( + self.addCleanup, "red", clock=self.clock, http_client=None, diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index de33b10a5f..8260c130f8 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -43,7 +43,7 @@ class FilterTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9487babac3..b72bd0fb7f 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -47,7 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase): login_handler=self.login_handler, ) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index bafc0d1df0..2e1d06c509 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -40,7 +40,7 @@ class FilterTestCase(unittest.TestCase): self.hs_clock = Clock(self.clock) self.hs = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.clock + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() diff --git a/tests/server.py b/tests/server.py index 05708be8b9..beb24cf032 100644 --- a/tests/server.py +++ b/tests/server.py @@ -147,12 +147,15 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return d -def setup_test_homeserver(*args, **kwargs): +def setup_test_homeserver(cleanup_func, *args, **kwargs): """ Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(*args, **kwargs).result + d = _sth(cleanup_func, *args, **kwargs).result + + if isinstance(d, Failure): + d.raiseException() # Make the thread pool synchronous. clock = d.get_clock() @@ -189,6 +192,9 @@ def setup_test_homeserver(*args, **kwargs): def start(self): pass + def stop(self): + pass + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): def _(res): if isinstance(res, Failure): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index fbb25a8844..c893990454 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, federation_sender=Mock(), federation_client=Mock() + self.addCleanup, + config=config, + federation_sender=Mock(), + federation_client=Mock(), ) self.as_token = "token1" @@ -108,7 +111,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): password_providers=[], ) hs = yield setup_test_homeserver( - config=config, federation_sender=Mock(), federation_client=Mock() + self.addCleanup, + config=config, + federation_sender=Mock(), + federation_client=Mock(), ) self.db_pool = hs.get_db_pool() @@ -392,6 +398,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), @@ -409,6 +416,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), @@ -432,6 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] ) hs = yield setup_test_homeserver( + self.addCleanup, config=config, datastore=Mock(), federation_sender=Mock(), diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index b4f6baf441..81403727c5 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,7 +9,9 @@ from tests.utils import setup_test_homeserver class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() # type: synapse.server.HomeServer + hs = yield setup_test_homeserver( + self.addCleanup + ) # type: synapse.server.HomeServer self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index ea00bbe84c..fa60d949ba 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -28,7 +28,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield tests.utils.setup_test_homeserver() + self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 63bc42d9e0..aef4dfaf57 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -28,7 +28,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 9a8ba2fcfe..b4510c1c8d 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.store = DirectoryStore(None, hs) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index d45c775c2d..8f0aaece40 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -26,8 +26,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() - + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 66eb119581..2fdf34fdf6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -22,7 +22,7 @@ import tests.utils class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 5e87b4530d..b114c6fb1d 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -33,7 +33,7 @@ HIGHLIGHT = [ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index ad0a55b324..47f4a8ceac 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -28,7 +28,7 @@ class KeyStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 22b1072d9f..0a2c859f26 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -28,7 +28,7 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.store = self.hs.get_datastore() @defer.inlineCallbacks diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index 12c540dfab..b5b58ff660 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -26,7 +26,7 @@ from tests.utils import MockClock, setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver(clock=MockClock()) + hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock()) self.store = PresenceStore(None, hs) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 5acbc8be0c..a1f6618bf9 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.store = ProfileStore(None, hs) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 85ce61e841..c4e9fb72bf 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -29,7 +29,7 @@ class RedactionTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + self.addCleanup, resource_for_federation=Mock(), http_client=None ) self.store = hs.get_datastore() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index bd96896bb3..4eda122edc 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -23,7 +23,7 @@ from tests.utils import setup_test_homeserver class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 84d49b55c1..a1ea23b068 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -26,7 +26,7 @@ from tests.utils import setup_test_homeserver class RoomStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver(self.addCleanup) # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table @@ -57,7 +57,7 @@ class RoomStoreTestCase(unittest.TestCase): class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = setup_test_homeserver() + hs = setup_test_homeserver(self.addCleanup) # Room events need the full datastore, for persist_event() and # get_room_state() diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 0d9908926a..c83ef60062 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -29,7 +29,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver( - resource_for_federation=Mock(), http_client=None + self.addCleanup, resource_for_federation=Mock(), http_client=None ) # We can't test the RoomMemberStore on its own without the other event # storage logic diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index ed5b41644a..6168c46248 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -33,7 +33,7 @@ class StateStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield tests.utils.setup_test_homeserver() + hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7a273eab48..b46e0ea7e2 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -29,7 +29,7 @@ BOBBY = "@bobby:a" class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.store = UserDirectoryStore(None, self.hs) # alice and bob are both in !room_id. bobby is not but shares diff --git a/tests/test_federation.py b/tests/test_federation.py index f40ff29b52..2540604fcc 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -18,7 +18,10 @@ class MessageAcceptTests(unittest.TestCase): self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, + http_client=self.http_client, + clock=self.hs_clock, + reactor=self.reactor, ) user_id = UserID("us", "test") diff --git a/tests/test_server.py b/tests/test_server.py index fc396226ea..895e490406 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -16,7 +16,7 @@ class JsonResourceTests(unittest.TestCase): self.reactor = MemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - http_client=None, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor ) def test_handler_for_request(self): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 8643d63125..45a78338d6 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM" class FilterEventsForServerTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver() + self.hs = yield setup_test_homeserver(self.addCleanup) self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index 8668b5478f..90378326f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit import hashlib +import os +import uuid from inspect import getcallargs from mock import Mock, patch @@ -27,23 +30,80 @@ from synapse.http.server import HttpServer from synapse.server import HomeServer from synapse.storage import PostgresEngine from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database +from synapse.storage.prepare_database import ( + _get_or_create_schema_state, + _setup_new_database, + prepare_database, +) from synapse.util.logcontext import LoggingContext from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. -# It requires you to have a local postgres database called synapse_test, within -# which ALL TABLES WILL BE DROPPED -USE_POSTGRES_FOR_TESTS = False +USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) +POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") +POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) + + +def setupdb(): + + # If we're using PostgreSQL, set up the db once + if USE_POSTGRES_FOR_TESTS: + pgconfig = { + "name": "psycopg2", + "args": { + "database": POSTGRES_BASE_DB, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + config = Mock() + config.password_providers = [] + config.database_config = pgconfig + db_engine = create_engine(pgconfig) + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + # Set up in the db + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + cur = db_conn.cursor() + _get_or_create_schema_state(cur, db_engine) + _setup_new_database(cur, db_engine) + db_conn.commit() + cur.close() + db_conn.close() + + def _cleanup(): + db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.close() + db_conn.close() + + atexit.register(_cleanup) @defer.inlineCallbacks def setup_test_homeserver( - name="test", datastore=None, config=None, reactor=None, **kargs + cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs ): - """Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. If no datastore is supplied a - datastore backed by an in-memory sqlite db will be given to the HS. + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. """ if reactor is None: from twisted.internet import reactor @@ -95,9 +155,11 @@ def setup_test_homeserver( kargs["clock"] = MockClock() if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + config.database_config = { "name": "psycopg2", - "args": {"database": "synapse_test", "cp_min": 1, "cp_max": 5}, + "args": {"database": test_db, "cp_min": 1, "cp_max": 5}, } else: config.database_config = { @@ -107,6 +169,21 @@ def setup_test_homeserver( db_engine = create_engine(config.database_config) + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if datastore is None and isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + # we need to configure the connection pool to run the on_new_connection # function, so that we can test code that uses custom sqlite functions # (like rank). @@ -125,15 +202,35 @@ def setup_test_homeserver( reactor=reactor, **kargs ) - db_conn = hs.get_db_conn() - # make sure that the database is empty - if isinstance(db_engine, PostgresEngine): - cur = db_conn.cursor() - cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") - rows = cur.fetchall() - for r in rows: - cur.execute("DROP TABLE %s CASCADE" % r[0]) - yield prepare_database(db_conn, db_engine, config) + + # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to + # date db + if not isinstance(db_engine, PostgresEngine): + db_conn = hs.get_db_conn() + yield prepare_database(db_conn, db_engine, config) + db_conn.commit() + db_conn.close() + + else: + # We need to do cleanup on PostgreSQL + def cleanup(): + # Close all the db pools + hs.get_db_pool().close() + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + cur.close() + db_conn.close() + + # Register the cleanup hook + cleanup_func(cleanup) + hs.setup() else: hs = HomeServer( diff --git a/tox.ini b/tox.ini index ed26644bd9..085f438989 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = packaging, py27, py36, pep8, check_isort -[testenv] +[base] deps = coverage Twisted>=15.1 @@ -15,6 +15,15 @@ deps = setenv = PYTHONDONTWRITEBYTECODE = no_byte_code +[testenv] +deps = + {[base]deps} + +setenv = + {[base]setenv} + +passenv = * + commands = /usr/bin/find "{toxinidir}" -name '*.pyc' -delete coverage run {env:COVERAGE_OPTS:} --source="{toxinidir}/synapse" \ @@ -46,6 +55,15 @@ commands = # ) usedevelop=true +[testenv:py27-postgres] +usedevelop=true +deps = + {[base]deps} + psycopg2 +setenv = + {[base]setenv} + SYNAPSE_POSTGRES = 1 + [testenv:py36] usedevelop=true commands = -- cgit 1.5.1 From 9ecbaf8ba8b12537862e3045a650fffd438ba8a5 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Tue, 14 Aug 2018 16:55:28 +0100 Subject: adding missing yield --- synapse/storage/monthly_active_users.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 7b36accdb6..7e417f811e 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -46,7 +46,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): tp["medium"], tp["address"] ) if user_id: - self.upsert_monthly_active_user(user_id) + yield self.upsert_monthly_active_user(user_id) reserved_user_list.append(user_id) else: logger.warning( -- cgit 1.5.1 From 2f78f432c421702e3756d029b3c669db837d8bcd Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 15 Aug 2018 17:35:22 +0200 Subject: speed up /members and add at= and membership params (#3568) --- changelog.d/3331.feature | 1 + changelog.d/3567.feature | 1 + changelog.d/3568.feature | 1 + synapse/handlers/message.py | 88 ++++++++++++++++++++++++++++++++++++------ synapse/rest/client/v1/room.py | 32 +++++++++++++-- synapse/storage/events.py | 2 +- synapse/storage/state.py | 66 ++++++++++++++++++++++++++++++- tests/storage/test_state.py | 2 +- 8 files changed, 174 insertions(+), 19 deletions(-) create mode 100644 changelog.d/3331.feature create mode 100644 changelog.d/3567.feature create mode 100644 changelog.d/3568.feature (limited to 'synapse/storage') diff --git a/changelog.d/3331.feature b/changelog.d/3331.feature new file mode 100644 index 0000000000..e574b9bcc3 --- /dev/null +++ b/changelog.d/3331.feature @@ -0,0 +1 @@ +add support for the include_redundant_members filter param as per MSC1227 diff --git a/changelog.d/3567.feature b/changelog.d/3567.feature new file mode 100644 index 0000000000..c74c1f57a9 --- /dev/null +++ b/changelog.d/3567.feature @@ -0,0 +1 @@ +make the /context API filter & lazy-load aware as per MSC1227 diff --git a/changelog.d/3568.feature b/changelog.d/3568.feature new file mode 100644 index 0000000000..247f02ba4e --- /dev/null +++ b/changelog.d/3568.feature @@ -0,0 +1 @@ +speed up /members API and add `at` and `membership` params as per MSC1227 diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 01a362360e..893c9bcdc4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -25,7 +25,13 @@ from twisted.internet import defer from twisted.internet.defer import succeed from synapse.api.constants import MAX_DEPTH, EventTypes, Membership -from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + ConsentNotGivenError, + NotFoundError, + SynapseError, +) from synapse.api.urls import ConsentURIBuilder from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event @@ -36,6 +42,7 @@ from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func +from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -82,28 +89,85 @@ class MessageHandler(object): defer.returnValue(data) @defer.inlineCallbacks - def get_state_events(self, user_id, room_id, is_guest=False): + def get_state_events( + self, user_id, room_id, types=None, filtered_types=None, + at_token=None, is_guest=False, + ): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has - left the room return the state events from when they left. + left the room return the state events from when they left. If an explicit + 'at' parameter is passed, return the state events as of that event, if + visible. Args: user_id(str): The user requesting state events. room_id(str): The room ID to get all state events from. + types(list[(str, str|None)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. If `state_key` is None, + all events are returned of the given type. + May be None, which matches any key. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. + at_token(StreamToken|None): the stream token of the at which we are requesting + the stats. If the user is not allowed to view the state as of that + stream token, we raise a 403 SynapseError. If None, returns the current + state based on the current_state_events table. + is_guest(bool): whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] + Raises: + NotFoundError (404) if the at token does not yield an event + + AuthError (403) if the user doesn't have permission to view + members of this room. """ - membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + if at_token: + # FIXME this claims to get the state at a stream position, but + # get_recent_events_for_room operates by topo ordering. This therefore + # does not reliably give you the state at the given stream position. + # (https://github.com/matrix-org/synapse/issues/3305) + last_events, _ = yield self.store.get_recent_events_for_room( + room_id, end_token=at_token.room_key, limit=1, + ) - if membership == Membership.JOIN: - room_state = yield self.state.get_current_state(room_id) - elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( - [membership_event_id], None + if not last_events: + raise NotFoundError("Can't find event for token %s" % (at_token, )) + + visible_events = yield filter_events_for_client( + self.store, user_id, last_events, + ) + + event = last_events[0] + if visible_events: + room_state = yield self.store.get_state_for_events( + [event.event_id], types, filtered_types=filtered_types, + ) + room_state = room_state[event.event_id] + else: + raise AuthError( + 403, + "User %s not allowed to view events in room %s at token %s" % ( + user_id, room_id, at_token, + ) + ) + else: + membership, membership_event_id = ( + yield self.auth.check_in_room_or_world_readable( + room_id, user_id, + ) ) - room_state = room_state[membership_event_id] + + if membership == Membership.JOIN: + state_ids = yield self.store.get_filtered_current_state_ids( + room_id, types, filtered_types=filtered_types, + ) + room_state = yield self.store.get_events(state_ids.values()) + elif membership == Membership.LEAVE: + room_state = yield self.store.get_state_for_events( + [membership_event_id], types, filtered_types=filtered_types, + ) + room_state = room_state[membership_event_id] now = self.clock.time_msec() defer.returnValue( diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index fa5989e74e..fcc1091760 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -34,7 +34,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.streams.config import PaginationConfig -from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID +from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from .base import ClientV1RestServlet, client_path_patterns @@ -384,15 +384,39 @@ class RoomMemberListRestServlet(ClientV1RestServlet): def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) requester = yield self.auth.get_user_by_req(request) - events = yield self.message_handler.get_state_events( + handler = self.message_handler + + # request the state as of a given event, as identified by a stream token, + # for consistency with /messages etc. + # useful for getting the membership in retrospect as of a given /sync + # response. + at_token_string = parse_string(request, "at") + if at_token_string is None: + at_token = None + else: + at_token = StreamToken.from_string(at_token_string) + + # let you filter down on particular memberships. + # XXX: this may not be the best shape for this API - we could pass in a filter + # instead, except filters aren't currently aware of memberships. + # See https://github.com/matrix-org/matrix-doc/issues/1337 for more details. + membership = parse_string(request, "membership") + not_membership = parse_string(request, "not_membership") + + events = yield handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), + at_token=at_token, + types=[(EventTypes.Member, None)], ) chunk = [] for event in events: - if event["type"] != EventTypes.Member: + if ( + (membership and event['content'].get("membership") != membership) or + (not_membership and event['content'].get("membership") == not_membership) + ): continue chunk.append(event) @@ -401,6 +425,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet): })) +# deprecated in favour of /members?membership=join? +# except it does custom AS logic and has a simpler return format class JoinedRoomMemberListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/joined_members$") diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 135af54fa9..025a7fb6d9 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1911,7 +1911,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore max_depth = max(row[0] for row in rows) if max_depth <= token.topological: - # We need to ensure we don't delete all the events from the datanase + # We need to ensure we don't delete all the events from the database # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) raise SynapseError( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 17b14d464b..754dfa6973 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -116,6 +116,69 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): _get_current_state_ids_txn, ) + # FIXME: how should this be cached? + def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): + """Get the current state event of a given type for a room based on the + current_state_events table. This may not be as up-to-date as the result + of doing a fresh state resolution as per state_handler.get_current_state + Args: + room_id (str) + types (list[(Str, (Str|None))]): List of (type, state_key) tuples + which are used to filter the state fetched. `state_key` may be + None, which matches any `state_key` + filtered_types (list[Str]|None): List of types to apply the above filter to. + Returns: + deferred: dict of (type, state_key) -> event + """ + + include_other_types = False if filtered_types is None else True + + def _get_filtered_current_state_ids_txn(txn): + results = {} + sql = """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? %s""" + # Turns out that postgres doesn't like doing a list of OR's and + # is about 1000x slower, so we just issue a query for each specific + # type seperately. + if types: + clause_to_args = [ + ( + "AND type = ? AND state_key = ?", + (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) + ) + for etype, state_key in types + ] + + if include_other_types: + unique_types = set(filtered_types) + clause_to_args.append( + ( + "AND type <> ? " * len(unique_types), + list(unique_types) + ) + ) + else: + # If types is None we fetch all the state, and so just use an + # empty where clause with no extra args. + clause_to_args = [("", [])] + for where_clause, where_args in clause_to_args: + args = [room_id] + args.extend(where_args) + txn.execute(sql % (where_clause,), args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + return results + + return self.runInteraction( + "get_filtered_current_state_ids", + _get_filtered_current_state_ids_txn, + ) + @cached(max_entries=10000, iterable=True) def get_state_group_delta(self, state_group): """Given a state group try to return a previous group and a delta between @@ -389,8 +452,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - deferred: A list of dicts corresponding to the event_ids given. - The dicts are mappings from (type, state_key) -> state_events + deferred: A dict of (event_id) -> (type, state_key) -> [state_events] """ event_to_groups = yield self._get_state_group_for_events( event_ids, diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 6168c46248..ebfd969b36 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -148,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase): {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state ) - # check we can use filter_types to grab a specific room member + # check we can use filtered_types to grab a specific room member # without filtering out the other event types state = yield self.store.get_state_for_event( e5.event_id, -- cgit 1.5.1 From 3f543dc021dd9456c8ed2da7a1e4769a68c07729 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 16 Aug 2018 10:46:50 +0200 Subject: initial cut at a room summary API (#3574) --- changelog.d/3574.feature | 1 + synapse/handlers/sync.py | 159 ++++++++++++++++++++++++++++++++--- synapse/rest/client/v2_alpha/sync.py | 1 + synapse/storage/_base.py | 7 +- synapse/storage/state.py | 5 +- synapse/storage/stream.py | 4 +- 6 files changed, 159 insertions(+), 18 deletions(-) create mode 100644 changelog.d/3574.feature (limited to 'synapse/storage') diff --git a/changelog.d/3574.feature b/changelog.d/3574.feature new file mode 100644 index 0000000000..87ac32df72 --- /dev/null +++ b/changelog.d/3574.feature @@ -0,0 +1 @@ +implement `summary` block in /sync response as per MSC688 diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3b21a04a5d..ac3edf0cc9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -75,6 +75,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ "ephemeral", "account_data", "unread_notifications", + "summary", ])): __slots__ = [] @@ -503,10 +504,142 @@ class SyncHandler(object): state = {} defer.returnValue(state) + @defer.inlineCallbacks + def compute_summary(self, room_id, sync_config, batch, state, now_token): + """ Works out a room summary block for this room, summarising the number + of joined members in the room, and providing the 'hero' members if the + room has no name so clients can consistently name rooms. Also adds + state events to 'state' if needed to describe the heroes. + + Args: + room_id(str): + sync_config(synapse.handlers.sync.SyncConfig): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + state(dict): dict of (type, state_key) -> Event as returned by + compute_state_delta + now_token(str): Token of the end of the current batch. + + Returns: + A deferred dict describing the room summary + """ + + # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 + last_events, _ = yield self.store.get_recent_event_ids_for_room( + room_id, end_token=now_token.room_key, limit=1, + ) + + if not last_events: + defer.returnValue(None) + return + + last_event = last_events[-1] + state_ids = yield self.store.get_state_ids_for_event( + last_event.event_id, [ + (EventTypes.Member, None), + (EventTypes.Name, ''), + (EventTypes.CanonicalAlias, ''), + ] + ) + + member_ids = { + state_key: event_id + for (t, state_key), event_id in state_ids.iteritems() + if t == EventTypes.Member + } + name_id = state_ids.get((EventTypes.Name, '')) + canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) + + summary = {} + + # FIXME: it feels very heavy to load up every single membership event + # just to calculate the counts. + member_events = yield self.store.get_events(member_ids.values()) + + joined_user_ids = [] + invited_user_ids = [] + + for ev in member_events.values(): + if ev.content.get("membership") == Membership.JOIN: + joined_user_ids.append(ev.state_key) + elif ev.content.get("membership") == Membership.INVITE: + invited_user_ids.append(ev.state_key) + + # TODO: only send these when they change. + summary["m.joined_member_count"] = len(joined_user_ids) + summary["m.invited_member_count"] = len(invited_user_ids) + + if name_id or canonical_alias_id: + defer.returnValue(summary) + + # FIXME: order by stream ordering, not alphabetic + + me = sync_config.user.to_string() + if (joined_user_ids or invited_user_ids): + summary['m.heroes'] = sorted( + [ + user_id + for user_id in (joined_user_ids + invited_user_ids) + if user_id != me + ] + )[0:5] + else: + summary['m.heroes'] = sorted( + [user_id for user_id in member_ids.keys() if user_id != me] + )[0:5] + + if not sync_config.filter_collection.lazy_load_members(): + defer.returnValue(summary) + + # ensure we send membership events for heroes if needed + cache_key = (sync_config.user.to_string(), sync_config.device_id) + cache = self.get_lazy_loaded_members_cache(cache_key) + + # track which members the client should already know about via LL: + # Ones which are already in state... + existing_members = set( + user_id for (typ, user_id) in state.keys() + if typ == EventTypes.Member + ) + + # ...or ones which are in the timeline... + for ev in batch.events: + if ev.type == EventTypes.Member: + existing_members.add(ev.state_key) + + # ...and then ensure any missing ones get included in state. + missing_hero_event_ids = [ + member_ids[hero_id] + for hero_id in summary['m.heroes'] + if ( + cache.get(hero_id) != member_ids[hero_id] and + hero_id not in existing_members + ) + ] + + missing_hero_state = yield self.store.get_events(missing_hero_event_ids) + missing_hero_state = missing_hero_state.values() + + for s in missing_hero_state: + cache.set(s.state_key, s.event_id) + state[(EventTypes.Member, s.state_key)] = s + + defer.returnValue(summary) + + def get_lazy_loaded_members_cache(self, cache_key): + cache = self.lazy_loaded_members_cache.get(cache_key) + if cache is None: + logger.debug("creating LruCache for %r", cache_key) + cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) + self.lazy_loaded_members_cache[cache_key] = cache + else: + logger.debug("found LruCache for %r", cache_key) + return cache + @defer.inlineCallbacks def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token, full_state): - """ Works out the differnce in state between the start of the timeline + """ Works out the difference in state between the start of the timeline and the previous sync. Args: @@ -520,7 +653,7 @@ class SyncHandler(object): full_state(bool): Whether to force returning the full state. Returns: - A deferred new event dictionary + A deferred dict of (type, state_key) -> Event """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -618,13 +751,7 @@ class SyncHandler(object): if lazy_load_members and not include_redundant_members: cache_key = (sync_config.user.to_string(), sync_config.device_id) - cache = self.lazy_loaded_members_cache.get(cache_key) - if cache is None: - logger.debug("creating LruCache for %r", cache_key) - cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) - self.lazy_loaded_members_cache[cache_key] = cache - else: - logger.debug("found LruCache for %r", cache_key) + cache = self.get_lazy_loaded_members_cache(cache_key) # if it's a new sync sequence, then assume the client has had # amnesia and doesn't want any recent lazy-loaded members @@ -1425,7 +1552,6 @@ class SyncHandler(object): if events == [] and tags is None: return - since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config @@ -1468,6 +1594,18 @@ class SyncHandler(object): full_state=full_state ) + summary = {} + if ( + sync_config.filter_collection.lazy_load_members() and + ( + any(ev.type == EventTypes.Member for ev in batch.events) or + since_token is None + ) + ): + summary = yield self.compute_summary( + room_id, sync_config, batch, state, now_token + ) + if room_builder.rtype == "joined": unread_notifications = {} room_sync = JoinedSyncResult( @@ -1477,6 +1615,7 @@ class SyncHandler(object): ephemeral=ephemeral, account_data=account_data_events, unread_notifications=unread_notifications, + summary=summary, ) if room_sync or always_include: diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 8aa06faf23..1275baa1ba 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -370,6 +370,7 @@ class SyncRestServlet(RestServlet): ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + result["summary"] = room.summary return result diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 44f37b4c1e..08dffd774f 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -1150,17 +1150,16 @@ class SQLBaseStore(object): defer.returnValue(retval) def get_user_count_txn(self, txn): - """Get a total number of registerd users in the users list. + """Get a total number of registered users in the users list. Args: txn : Transaction object Returns: - defer.Deferred: resolves to int + int : number of users """ sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" txn.execute(sql_count) - count = txn.fetchone()[0] - defer.returnValue(count) + return txn.fetchone()[0] def _simple_search_list(self, table, term, col, retcols, desc="_simple_search_list"): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 754dfa6973..dd03c4168b 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -480,7 +480,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): """ - Get the state dicts corresponding to a list of events + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) Args: event_ids(list(str)): events whose state should be returned @@ -493,7 +494,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): If None, `types` filtering is applied to all events. Returns: - A deferred dict from event_id -> (type, state_key) -> state_event + A deferred dict from event_id -> (type, state_key) -> event_id """ event_to_groups = yield self._get_state_group_for_events( event_ids, diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index b9f2b74ac6..4c296d72c0 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -348,7 +348,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token (str): The stream token representing now. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of + Deferred[tuple[list[FrozenEvent], str]]: Returns a list of events and a token pointing to the start of the returned events. The events returned are in ascending order. @@ -379,7 +379,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token (str): The stream token representing now. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of + Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of _EventDictReturn and a token pointing to the start of the returned events. The events returned are in ascending order. -- cgit 1.5.1 From bcfeb44afe750dadd4199e7c02302b0157bdab11 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 16 Aug 2018 22:55:32 +0100 Subject: call reap on start up and fix under reaping bug --- synapse/api/auth.py | 2 +- synapse/app/homeserver.py | 1 + synapse/storage/monthly_active_users.py | 5 ++++- tests/storage/test_monthly_active_users.py | 13 +++++++++++++ 4 files changed, 19 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3b2a2ab77a..ab1e3a4e35 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -799,7 +799,7 @@ class Auth(object): current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise AuthError( - 403, "Monthly Active User Limits AU Limit Exceeded", + 403, "Monthly Active User Limit Exceeded", admin_uri=self.hs.config.admin_uri, errcode=Codes.RESOURCE_LIMIT_EXCEED ) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index a98bb506e5..800b9c0e3c 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -525,6 +525,7 @@ def run(hs): clock.looping_call( hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60 ) + yield hs.get_datastore().reap_monthly_active_users() @defer.inlineCallbacks def generate_monthly_active_users(): diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 7e417f811e..06f9a75a97 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -96,7 +96,10 @@ class MonthlyActiveUsersStore(SQLBaseStore): # While Postgres does not require 'LIMIT', but also does not support # negative LIMIT values. So there is no way to write it that both can # support - query_args = [self.hs.config.max_mau_value] + safe_guard = self.hs.config.max_mau_value - len(self.reserved_users) + # Must be greater than zero for postgres + safe_guard = safe_guard if safe_guard > 0 else 0 + query_args = [safe_guard] base_sql = """ DELETE FROM monthly_active_users diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 511acbde9b..f2ed866ae7 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -75,6 +75,19 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): active_count = yield self.store.get_monthly_active_count() self.assertEquals(active_count, user_num) + # Test that regalar users are removed from the db + ru_count = 2 + yield self.store.upsert_monthly_active_user("@ru1:server") + yield self.store.upsert_monthly_active_user("@ru2:server") + active_count = yield self.store.get_monthly_active_count() + + self.assertEqual(active_count, user_num + ru_count) + self.hs.config.max_mau_value = user_num + yield self.store.reap_monthly_active_users() + + active_count = yield self.store.get_monthly_active_count() + self.assertEquals(active_count, user_num) + @defer.inlineCallbacks def test_can_insert_and_count_mau(self): count = yield self.store.get_monthly_active_count() -- cgit 1.5.1 From 38f708a2bbee13ebd34f1595910dd0e4fd532818 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Aug 2018 11:37:42 +0100 Subject: Remote profile cache should remain in master worker --- synapse/storage/profile.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 246ab836bd..88b50f33b5 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -94,6 +94,8 @@ class ProfileWorkerStore(SQLBaseStore): desc="set_profile_avatar_url", ) + +class ProfileStore(ProfileWorkerStore): def add_remote_profile_cache(self, user_id, displayname, avatar_url): """Ensure we are caching the remote user's profiles. @@ -180,7 +182,3 @@ class ProfileWorkerStore(SQLBaseStore): if res: defer.returnValue(True) - - -class ProfileStore(ProfileWorkerStore): - pass -- cgit 1.5.1 From bb81e78ec6c05edc95b25a954bdf1cf688d4d652 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 22 Aug 2018 00:56:37 +0200 Subject: Split the state_group_cache in two (#3726) Splits the state_group_cache in two. One half contains normal state events; the other contains member events. The idea is that the lazyloading common case of: "I want a subset of member events plus all of the other state" can be accomplished efficiently by splitting the cache into two, and asking for "all events" from the non-members cache, and "just these keys" from the members cache. This means we can avoid having to make DictionaryCache aware of these sort of complicated queries, whilst letting LL requests benefit from the caching. Previously we were unable to sensibly use the caching and had to pull all state from the DB irrespective of the filtering, which made things slow. Hopefully fixes https://github.com/matrix-org/synapse/issues/3720. --- changelog.d/3726.misc | 1 + synapse/storage/state.py | 158 +++++++++++++++++++++++++++++++++++++++----- tests/storage/test_state.py | 105 ++++++++++++++++++++++++++--- 3 files changed, 236 insertions(+), 28 deletions(-) create mode 100644 changelog.d/3726.misc (limited to 'synapse/storage') diff --git a/changelog.d/3726.misc b/changelog.d/3726.misc new file mode 100644 index 0000000000..c4f66ec998 --- /dev/null +++ b/changelog.d/3726.misc @@ -0,0 +1 @@ +Split the state_group_cache into member and non-member state events (and so speed up LL /sync) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dd03c4168b..4b971efdba 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -60,8 +60,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, db_conn, hs): super(StateGroupWorkerStore, self).__init__(db_conn, hs) + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache") + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache") ) @defer.inlineCallbacks @@ -275,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): }) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, types): + def _get_state_groups_from_groups(self, groups, types, members=None): """Returns the state groups for a given set of groups, filtering on types of state events. @@ -284,6 +319,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): types (Iterable[str, str|None]|None): list of 2-tuples of the form (`type`, `state_key`), where a `state_key` of `None` matches all state_keys for the `type`. If None, all types are returned. + members (bool|None): If not None, then, in addition to any filtering + implied by types, the results are also filtered to only include + member events (if True), or to exclude member events (if False) Returns: dictionary state_group -> (dict of (type, state_key) -> event id) @@ -294,14 +332,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, types, + self._get_state_groups_from_groups_txn, chunk, types, members, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, types=None, + self, txn, groups, types=None, members=None, ): results = {group: {} for group in groups} @@ -339,6 +377,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): %s """) + if members is True: + sql += " AND type = '%s'" % (EventTypes.Member,) + elif members is False: + sql += " AND type <> '%s'" % (EventTypes.Member,) + # Turns out that postgres doesn't like doing a list of OR's and # is about 1000x slower, so we just issue a query for each specific # type seperately. @@ -386,6 +429,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): else: where_clause = "" + if members is True: + where_clause += " AND type = '%s'" % EventTypes.Member + elif members is False: + where_clause += " AND type <> '%s'" % EventTypes.Member + # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: @@ -580,10 +628,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, group, types, filtered_types=None): + def _get_some_state_from_cache(self, cache, group, types, filtered_types=None): """Checks if group is in cache. See `_get_state_for_groups` Args: + cache(DictionaryCache): the state group cache to use group(int): The state group to lookup types(list[str, str|None]): List of 2-tuples of the form (`type`, `state_key`), where a `state_key` of `None` matches all @@ -597,11 +646,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): requests state from the cache, if False we need to query the DB for the missing state. """ - is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = cache.get(group) type_to_key = {} - # tracks whether any of ourrequested types are missing from the cache + # tracks whether any of our requested types are missing from the cache missing_types = False for typ, state_key in types: @@ -648,7 +697,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if include(k[0], k[1]) }, got_all - def _get_all_state_from_cache(self, group): + def _get_all_state_from_cache(self, cache, group): """Checks if group is in cache. See `_get_state_for_groups` Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool @@ -656,9 +705,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cache, if False we need to query the DB for the missing state. Args: + cache(DictionaryCache): the state group cache to use group: The state group to lookup """ - is_all, _, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = cache.get(group) return state_dict_ids, is_all @@ -681,6 +731,62 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): list of event types. Other types of events are returned unfiltered. If None, `types` filtering is applied to all events. + Returns: + Deferred[dict[int, dict[(type, state_key), EventBase]]] + a dictionary mapping from state group to state dictionary. + """ + if types is not None: + non_member_types = [t for t in types if t[0] != EventTypes.Member] + + if filtered_types is not None and EventTypes.Member not in filtered_types: + # we want all of the membership events + member_types = None + else: + member_types = [t for t in types if t[0] == EventTypes.Member] + + else: + non_member_types = None + member_types = None + + non_member_state = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, non_member_types, filtered_types, + ) + # XXX: we could skip this entirely if member_types is [] + member_state = yield self._get_state_for_groups_using_cache( + # we set filtered_types=None as member_state only ever contain members. + groups, self._state_group_members_cache, member_types, None, + ) + + state = non_member_state + for group in groups: + state[group].update(member_state[group]) + + defer.returnValue(state) + + @defer.inlineCallbacks + def _get_state_for_groups_using_cache( + self, groups, cache, types=None, filtered_types=None + ): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + types (None|iterable[(str, None|str)]): + indicates the state type/keys required. If None, the whole + state is fetched and returned. + + Otherwise, each entry should be a `(type, state_key)` tuple to + include in the response. A `state_key` of None is a wildcard + meaning that we require all state with that type. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. + Returns: Deferred[dict[int, dict[(type, state_key), EventBase]]] a dictionary mapping from state group to state dictionary. @@ -692,7 +798,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if types is not None: for group in set(groups): state_dict_ids, got_all = self._get_some_state_from_cache( - group, types, filtered_types + cache, group, types, filtered_types ) results[group] = state_dict_ids @@ -701,7 +807,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): else: for group in set(groups): state_dict_ids, got_all = self._get_all_state_from_cache( - group + cache, group ) results[group] = state_dict_ids @@ -710,8 +816,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): missing_groups.append(group) if missing_groups: - # Okay, so we have some missing_types, lets fetch them. - cache_seq_num = self._state_group_cache.sequence + # Okay, so we have some missing_types, let's fetch them. + cache_seq_num = cache.sequence # the DictionaryCache knows if it has *all* the state, but # does not know if it has all of the keys of a particular type, @@ -725,7 +831,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): types_to_fetch = types group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch + missing_groups, types_to_fetch, cache == self._state_group_members_cache, ) for group, group_state_dict in iteritems(group_to_state_dict): @@ -745,7 +851,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # update the cache with all the things we fetched from the # database. - self._state_group_cache.update( + cache.update( cache_seq_num, key=group, value=group_state_dict, @@ -847,15 +953,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ], ) - # Prefill the state group cache with this group. + # Prefill the state group caches with this group. # It's fine to use the sequence like this as the state group map # is immutable. (If the map wasn't immutable then this prefill could # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } txn.call_after( self._state_group_cache.update, self._state_group_cache.sequence, key=state_group, - value=dict(current_state_ids), + value=dict(current_non_member_state_ids), ) return state_group diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index ebfd969b36..d717b9f94e 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -185,6 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters out members with types=[] (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] ) @@ -197,8 +198,20 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + {}, + state_dict, + ) + # test _get_some_state_from_cache correctly filters in members with wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] ) @@ -207,6 +220,18 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e3.type, e3.state_key): e3.event_id, # e4 is overwritten by e5 (e5.type, e5.state_key): e5.event_id, @@ -216,6 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member], @@ -226,6 +252,20 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=[EventTypes.Member], + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e5.type, e5.state_key): e5.event_id, }, state_dict, @@ -234,6 +274,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=None ) @@ -254,9 +295,6 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, - (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 - (e5.type, e5.state_key): e5.event_id, }, ) @@ -269,8 +307,6 @@ class StateStoreTestCase(tests.unittest.TestCase): # list fetched keys so it knows it's partial fetched_keys=( (e1.type, e1.state_key), - (e3.type, e3.state_key), - (e5.type, e5.state_key), ), ) @@ -284,8 +320,6 @@ class StateStoreTestCase(tests.unittest.TestCase): set( [ (e1.type, e1.state_key), - (e3.type, e3.state_key), - (e5.type, e5.state_key), ] ), ) @@ -293,8 +327,6 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict_ids, { (e1.type, e1.state_key): e1.event_id, - (e3.type, e3.state_key): e3.event_id, - (e5.type, e5.state_key): e5.event_id, }, ) @@ -304,14 +336,25 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters out members with types=[] room_id = self.room.to_string() (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + room_id = self.room.to_string() + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({}, state_dict) + # test _get_some_state_from_cache correctly filters in members wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] ) @@ -319,8 +362,19 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual( { (e1.type, e1.state_key): e1.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 (e5.type, e5.state_key): e5.event_id, }, state_dict, @@ -328,6 +382,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member], @@ -337,6 +392,20 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual( { (e1.type, e1.state_key): e1.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=[EventTypes.Member], + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e5.type, e5.state_key): e5.event_id, }, state_dict, @@ -345,8 +414,22 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, + group, [(EventTypes.Member, e5.state_key)], filtered_types=None + ) + + self.assertEqual(is_all, False) + self.assertDictEqual({}, state_dict) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=None ) self.assertEqual(is_all, True) - self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + self.assertDictEqual( + { + (e5.type, e5.state_key): e5.event_id, + }, + state_dict, + ) -- cgit 1.5.1 From 7a0da69eee49ce84ddd8a433288d7ce9cb76e14e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 23 Aug 2018 10:28:10 +0100 Subject: Add missing yield --- synapse/storage/monthly_active_users.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 06f9a75a97..fd3b630bd2 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -147,6 +147,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): return count return self.runInteraction("count_users", _count_users) + @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): """ Updates or inserts monthly active user member @@ -155,7 +156,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): Deferred[bool]: True if a new entry was created, False if an existing one was updated. """ - is_insert = self._simple_upsert( + is_insert = yield self._simple_upsert( desc="upsert_monthly_active_user", table="monthly_active_users", keyvalues={ -- cgit 1.5.1 From cd77270a669fa827912efd86a0723d85e9e388ed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 23 Aug 2018 19:17:08 +0100 Subject: Implement trail users --- synapse/api/auth.py | 6 +++++- synapse/config/server.py | 6 ++++++ synapse/storage/monthly_active_users.py | 5 +++++ synapse/storage/registration.py | 28 +++++++++++++++++++++++++++- tests/storage/test_registration.py | 1 + 5 files changed, 44 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 4ca40a0f71..8d2aa5870a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -797,11 +797,15 @@ class Auth(object): limit_type=self.hs.config.hs_disabled_limit_type ) if self.hs.config.limit_usage_by_mau is True: - # If the user is already part of the MAU cohort + # If the user is already part of the MAU cohort or a trial user if user_id: timestamp = yield self.store.user_last_seen_monthly_active(user_id) if timestamp: return + + is_trial = yield self.store.is_trial_user(user_id) + if is_trial: + return # Else if there is no room in the MAU bucket, bail current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: diff --git a/synapse/config/server.py b/synapse/config/server.py index 68a612e594..8eecd28e7d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -77,10 +77,15 @@ class ServerConfig(Config): self.max_mau_value = config.get( "max_mau_value", 0, ) + self.mau_limits_reserved_threepids = config.get( "mau_limit_reserved_threepids", [] ) + self.mau_trial_days = config.get( + "mau_trial_days", 0, + ) + # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") @@ -365,6 +370,7 @@ class ServerConfig(Config): # Enables monthly active user checking # limit_usage_by_mau: False # max_mau_value: 50 + # mau_trial_days: 2 # # Sometimes the server admin will want to ensure certain accounts are # never blocked by mau checking. These accounts are specified here. diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index fd3b630bd2..d178f5c5ba 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -201,6 +201,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): user_id(str): the user_id to query """ if self.hs.config.limit_usage_by_mau: + is_trial = yield self.is_trial_user(user_id) + if is_trial: + # we don't track trial users in the MAU table. + return + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 07333f777d..26b429e307 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -26,6 +26,11 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks class RegistrationWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(db_conn, hs) + + self.config = hs.config + @cached() def get_user_by_id(self, user_id): return self._simple_select_one( @@ -36,12 +41,33 @@ class RegistrationWorkerStore(SQLBaseStore): retcols=[ "name", "password_hash", "is_guest", "consent_version", "consent_server_notice_sent", - "appservice_id", + "appservice_id", "creation_ts", ], allow_none=True, desc="get_user_by_id", ) + @defer.inlineCallbacks + def is_trial_user(self, user_id): + """Checks if user is in the "trial" period, i.e. within the first + N days of registration defined by `mau_trial_days` config + + Args: + user_id (str) + + Returns: + Deferred[bool] + """ + + info = yield self.get_user_by_id(user_id) + if not info: + defer.returnValue(False) + + now = self.clock.time_msec() + trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + defer.returnValue(is_trial) + @cached() def get_user_by_access_token(self, token): """Get a user from the given access token. diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4eda122edc..3dfb7b903a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -46,6 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "consent_version": None, "consent_server_notice_sent": None, "appservice_id": None, + "creation_ts": 1000, }, (yield self.store.get_user_by_id(self.user_id)), ) -- cgit 1.5.1