diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 69cb28268a..53c685c173 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -40,6 +40,7 @@ from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
from .media_repository import MediaRepositoryStore
+from .monthly_active_users import MonthlyActiveUsersStore
from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
@@ -89,6 +90,7 @@ class DataStore(RoomMemberStore, RoomStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
+ MonthlyActiveUsersStore,
):
def __init__(self, db_conn, hs):
@@ -96,7 +98,6 @@ class DataStore(RoomMemberStore, RoomStore,
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
- self.db_conn = db_conn
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")]
@@ -269,31 +270,6 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_users", _count_users)
- def count_monthly_users(self):
- """Counts the number of users who used this homeserver in the last 30 days
-
- This method should be refactored with count_daily_users - the only
- reason not to is waiting on definition of mau
-
- Returns:
- Defered[int]
- """
- def _count_monthly_users(txn):
- thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT user_id FROM user_ips
- WHERE last_seen > ?
- GROUP BY user_id
- ) u
- """
-
- txn.execute(sql, (thirty_days_ago,))
- count, = txn.fetchone()
- return count
-
- return self.runInteraction("count_monthly_users", _count_monthly_users)
-
def count_r30_users(self):
"""
Counts the number of 30 day retained users, defined as:-
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 44f37b4c1e..08dffd774f 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -1150,17 +1150,16 @@ class SQLBaseStore(object):
defer.returnValue(retval)
def get_user_count_txn(self, txn):
- """Get a total number of registerd users in the users list.
+ """Get a total number of registered users in the users list.
Args:
txn : Transaction object
Returns:
- defer.Deferred: resolves to int
+ int : number of users
"""
sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
txn.execute(sql_count)
- count = txn.fetchone()[0]
- defer.returnValue(count)
+ return txn.fetchone()[0]
def _simple_search_list(self, table, term, col, retcols,
desc="_simple_search_list"):
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index b8cefd43d6..8fc678fa67 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -35,6 +35,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
+
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
@@ -74,6 +75,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch
)
+ @defer.inlineCallbacks
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None):
if not now:
@@ -84,7 +86,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
-
+ yield self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
@@ -94,6 +96,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
def _update_client_ips_batch(self):
+
+ # If the DB pool has already terminated, don't try updating
+ if not self.hs.get_db_pool().running:
+ return
+
def update():
to_update = self._batch_row_update
self._batch_row_update = {}
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index e8e5a0fe44..f39c8c8461 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -38,7 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
@@ -485,9 +485,14 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(chunk))
- synapse.metrics.event_persisted_position.set(
- chunk[-1][0].internal_metadata.stream_ordering,
- )
+
+ if not backfilled:
+ # backfilled events have negative stream orderings, so we don't
+ # want to set the event_persisted_position to that.
+ synapse.metrics.event_persisted_position.set(
+ chunk[-1][0].internal_metadata.stream_ordering,
+ )
+
for event, context in chunk:
if context.app_service:
origin_type = "local"
@@ -700,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
}
events_map = {ev.event_id: ev for ev, _ in events_context}
+ room_version = yield self.get_room_version(room_id)
+
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
- room_id, state_groups, events_map, get_events
+ room_id, room_version, state_groups, events_map, get_events
)
defer.returnValue((res.state, None))
@@ -1431,88 +1438,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
)
@defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
- """Given a list of event ids, check if we have already processed and
- stored them as non outliers.
- """
- rows = yield self._simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
- )
-
- defer.returnValue(set(r["event_id"] for r in rows))
-
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
- """Given a list of event ids, check if we have already processed them.
-
- Args:
- event_ids (iterable[str]):
-
- Returns:
- Deferred[set[str]]: The events we have already seen.
- """
- results = set()
-
- def have_seen_events_txn(txn, chunk):
- sql = (
- "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
- % (",".join("?" * len(chunk)), )
- )
- txn.execute(sql, chunk)
- for (event_id, ) in txn:
- results.add(event_id)
-
- # break the input up into chunks of 100
- input_iterator = iter(event_ids)
- for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
- []):
- yield self.runInteraction(
- "have_seen_events",
- have_seen_events_txn,
- chunk,
- )
- defer.returnValue(results)
-
- def get_seen_events_with_rejections(self, event_ids):
- """Given a list of event ids, check if we rejected them.
-
- Args:
- event_ids (list[str])
-
- Returns:
- Deferred[dict[str, str|None):
- Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps
- to None.
- """
- if not event_ids:
- return defer.succeed({})
-
- def f(txn):
- sql = (
- "SELECT e.event_id, reason FROM events as e "
- "LEFT JOIN rejections as r ON e.event_id = r.event_id "
- "WHERE e.event_id = ?"
- )
-
- res = {}
- for event_id in event_ids:
- txn.execute(sql, (event_id,))
- row = txn.fetchone()
- if row:
- _, rejected = row
- res[event_id] = rejected
-
- return res
-
- return self.runInteraction("get_rejection_reasons", f)
-
- @defer.inlineCallbacks
def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
@@ -1988,7 +1913,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
max_depth = max(row[0] for row in rows)
if max_depth <= token.topological:
- # We need to ensure we don't delete all the events from the datanase
+ # We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
raise SynapseError(
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 9b4cfeb899..59822178ff 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
from collections import namedtuple
@@ -442,3 +443,85 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)
+
+ @defer.inlineCallbacks
+ def have_events_in_timeline(self, event_ids):
+ """Given a list of event ids, check if we have already processed and
+ stored them as non outliers.
+ """
+ rows = yield self._simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ )
+
+ defer.returnValue(set(r["event_id"] for r in rows))
+
+ @defer.inlineCallbacks
+ def have_seen_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Args:
+ event_ids (iterable[str]):
+
+ Returns:
+ Deferred[set[str]]: The events we have already seen.
+ """
+ results = set()
+
+ def have_seen_events_txn(txn, chunk):
+ sql = (
+ "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+ % (",".join("?" * len(chunk)), )
+ )
+ txn.execute(sql, chunk)
+ for (event_id, ) in txn:
+ results.add(event_id)
+
+ # break the input up into chunks of 100
+ input_iterator = iter(event_ids)
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+ []):
+ yield self.runInteraction(
+ "have_seen_events",
+ have_seen_events_txn,
+ chunk,
+ )
+ defer.returnValue(results)
+
+ def get_seen_events_with_rejections(self, event_ids):
+ """Given a list of event ids, check if we rejected them.
+
+ Args:
+ event_ids (list[str])
+
+ Returns:
+ Deferred[dict[str, str|None):
+ Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps
+ to None.
+ """
+ if not event_ids:
+ return defer.succeed({})
+
+ def f(txn):
+ sql = (
+ "SELECT e.event_id, reason FROM events as e "
+ "LEFT JOIN rejections as r ON e.event_id = r.event_id "
+ "WHERE e.event_id = ?"
+ )
+
+ res = {}
+ for event_id in event_ids:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ _, rejected = row
+ res[event_id] = rejected
+
+ return res
+
+ return self.runInteraction("get_rejection_reasons", f)
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
new file mode 100644
index 0000000000..d178f5c5ba
--- /dev/null
+++ b/synapse/storage/monthly_active_users.py
@@ -0,0 +1,222 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.util.caches.descriptors import cached
+
+from ._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+# Number of msec of granularity to store the monthly_active_user timestamp
+# This means it is not necessary to update the table on every request
+LAST_SEEN_GRANULARITY = 60 * 60 * 1000
+
+
+class MonthlyActiveUsersStore(SQLBaseStore):
+ def __init__(self, dbconn, hs):
+ super(MonthlyActiveUsersStore, self).__init__(None, hs)
+ self._clock = hs.get_clock()
+ self.hs = hs
+ self.reserved_users = ()
+
+ @defer.inlineCallbacks
+ def initialise_reserved_users(self, threepids):
+ # TODO Why can't I do this in init?
+ store = self.hs.get_datastore()
+ reserved_user_list = []
+
+ # Do not add more reserved users than the total allowable number
+ for tp in threepids[:self.hs.config.max_mau_value]:
+ user_id = yield store.get_user_id_by_threepid(
+ tp["medium"], tp["address"]
+ )
+ if user_id:
+ yield self.upsert_monthly_active_user(user_id)
+ reserved_user_list.append(user_id)
+ else:
+ logger.warning(
+ "mau limit reserved threepid %s not found in db" % tp
+ )
+ self.reserved_users = tuple(reserved_user_list)
+
+ @defer.inlineCallbacks
+ def reap_monthly_active_users(self):
+ """
+ Cleans out monthly active user table to ensure that no stale
+ entries exist.
+
+ Returns:
+ Deferred[]
+ """
+ def _reap_users(txn):
+ # Purge stale users
+
+ thirty_days_ago = (
+ int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ )
+ query_args = [thirty_days_ago]
+ base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
+
+ # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
+ # when len(reserved_users) == 0. Works fine on sqlite.
+ if len(self.reserved_users) > 0:
+ # questionmarks is a hack to overcome sqlite not supporting
+ # tuples in 'WHERE IN %s'
+ questionmarks = '?' * len(self.reserved_users)
+
+ query_args.extend(self.reserved_users)
+ sql = base_sql + """ AND user_id NOT IN ({})""".format(
+ ','.join(questionmarks)
+ )
+ else:
+ sql = base_sql
+
+ txn.execute(sql, query_args)
+
+ # If MAU user count still exceeds the MAU threshold, then delete on
+ # a least recently active basis.
+ # Note it is not possible to write this query using OFFSET due to
+ # incompatibilities in how sqlite and postgres support the feature.
+ # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
+ # While Postgres does not require 'LIMIT', but also does not support
+ # negative LIMIT values. So there is no way to write it that both can
+ # support
+ safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
+ # Must be greater than zero for postgres
+ safe_guard = safe_guard if safe_guard > 0 else 0
+ query_args = [safe_guard]
+
+ base_sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ ORDER BY timestamp DESC
+ LIMIT ?
+ )
+ """
+ # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
+ # when len(reserved_users) == 0. Works fine on sqlite.
+ if len(self.reserved_users) > 0:
+ query_args.extend(self.reserved_users)
+ sql = base_sql + """ AND user_id NOT IN ({})""".format(
+ ','.join(questionmarks)
+ )
+ else:
+ sql = base_sql
+ txn.execute(sql, query_args)
+
+ yield self.runInteraction("reap_monthly_active_users", _reap_users)
+ # It seems poor to invalidate the whole cache, Postgres supports
+ # 'Returning' which would allow me to invalidate only the
+ # specific users, but sqlite has no way to do this and instead
+ # I would need to SELECT and the DELETE which without locking
+ # is racy.
+ # Have resolved to invalidate the whole cache for now and do
+ # something about it if and when the perf becomes significant
+ self.user_last_seen_monthly_active.invalidate_all()
+ self.get_monthly_active_count.invalidate_all()
+
+ @cached(num_args=0)
+ def get_monthly_active_count(self):
+ """Generates current count of monthly active users
+
+ Returns:
+ Defered[int]: Number of current monthly active users
+ """
+
+ def _count_users(txn):
+ sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+
+ txn.execute(sql)
+ count, = txn.fetchone()
+ return count
+ return self.runInteraction("count_users", _count_users)
+
+ @defer.inlineCallbacks
+ def upsert_monthly_active_user(self, user_id):
+ """
+ Updates or inserts monthly active user member
+ Arguments:
+ user_id (str): user to add/update
+ Deferred[bool]: True if a new entry was created, False if an
+ existing one was updated.
+ """
+ is_insert = yield self._simple_upsert(
+ desc="upsert_monthly_active_user",
+ table="monthly_active_users",
+ keyvalues={
+ "user_id": user_id,
+ },
+ values={
+ "timestamp": int(self._clock.time_msec()),
+ },
+ lock=False,
+ )
+ if is_insert:
+ self.user_last_seen_monthly_active.invalidate((user_id,))
+ self.get_monthly_active_count.invalidate(())
+
+ @cached(num_args=1)
+ def user_last_seen_monthly_active(self, user_id):
+ """
+ Checks if a given user is part of the monthly active user group
+ Arguments:
+ user_id (str): user to add/update
+ Return:
+ Deferred[int] : timestamp since last seen, None if never seen
+
+ """
+
+ return(self._simple_select_one_onecol(
+ table="monthly_active_users",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcol="timestamp",
+ allow_none=True,
+ desc="user_last_seen_monthly_active",
+ ))
+
+ @defer.inlineCallbacks
+ def populate_monthly_active_users(self, user_id):
+ """Checks on the state of monthly active user limits and optionally
+ add the user to the monthly active tables
+
+ Args:
+ user_id(str): the user_id to query
+ """
+ if self.hs.config.limit_usage_by_mau:
+ is_trial = yield self.is_trial_user(user_id)
+ if is_trial:
+ # we don't track trial users in the MAU table.
+ return
+
+ last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
+ now = self.hs.get_clock().time_msec()
+
+ # We want to reduce to the total number of db writes, and are happy
+ # to trade accuracy of timestamp in order to lighten load. This means
+ # We always insert new users (where MAU threshold has not been reached),
+ # but only update if we have not previously seen the user for
+ # LAST_SEEN_GRANULARITY ms
+ if last_seen_timestamp is None:
+ count = yield self.get_monthly_active_count()
+ if count < self.hs.config.max_mau_value:
+ yield self.upsert_monthly_active_user(user_id)
+ elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
+ yield self.upsert_monthly_active_user(user_id)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index b290f834b3..b364719312 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 50
+SCHEMA_VERSION = 51
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 60295da254..88b50f33b5 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
-
-class ProfileStore(ProfileWorkerStore):
def create_profile(self, user_localpart):
return self._simple_insert(
table="profiles",
@@ -96,6 +94,8 @@ class ProfileStore(ProfileWorkerStore):
desc="set_profile_avatar_url",
)
+
+class ProfileStore(ProfileWorkerStore):
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 07333f777d..26b429e307 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -26,6 +26,11 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationWorkerStore(SQLBaseStore):
+ def __init__(self, db_conn, hs):
+ super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+
+ self.config = hs.config
+
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
@@ -36,12 +41,33 @@ class RegistrationWorkerStore(SQLBaseStore):
retcols=[
"name", "password_hash", "is_guest",
"consent_version", "consent_server_notice_sent",
- "appservice_id",
+ "appservice_id", "creation_ts",
],
allow_none=True,
desc="get_user_by_id",
)
+ @defer.inlineCallbacks
+ def is_trial_user(self, user_id):
+ """Checks if user is in the "trial" period, i.e. within the first
+ N days of registration defined by `mau_trial_days` config
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[bool]
+ """
+
+ info = yield self.get_user_by_id(user_id)
+ if not info:
+ defer.returnValue(False)
+
+ now = self.clock.time_msec()
+ trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
+ is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
+ defer.returnValue(is_trial)
+
@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 3147fb6827..61013b8919 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -41,6 +41,22 @@ RatelimitOverride = collections.namedtuple(
class RoomWorkerStore(SQLBaseStore):
+ def get_room(self, room_id):
+ """Retrieve a room.
+
+ Args:
+ room_id (str): The ID of the room to retrieve.
+ Returns:
+ A namedtuple containing the room information, or an empty list.
+ """
+ return self._simple_select_one(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("room_id", "is_public", "creator"),
+ desc="get_room",
+ allow_none=True,
+ )
+
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
@@ -170,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ @cachedInlineCallbacks(max_entries=10000)
+ def get_ratelimit_for_user(self, user_id):
+ """Check if there are any overrides for ratelimiting for the given
+ user
+
+ Args:
+ user_id (str)
+
+ Returns:
+ RatelimitOverride if there is an override, else None. If the contents
+ of RatelimitOverride are None or 0 then ratelimitng has been
+ disabled for that user entirely.
+ """
+ row = yield self._simple_select_one(
+ table="ratelimit_override",
+ keyvalues={"user_id": user_id},
+ retcols=("messages_per_second", "burst_count"),
+ allow_none=True,
+ desc="get_ratelimit_for_user",
+ )
+
+ if row:
+ defer.returnValue(RatelimitOverride(
+ messages_per_second=row["messages_per_second"],
+ burst_count=row["burst_count"],
+ ))
+ else:
+ defer.returnValue(None)
+
class RoomStore(RoomWorkerStore, SearchStore):
@@ -215,22 +260,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- def get_room(self, room_id):
- """Retrieve a room.
-
- Args:
- room_id (str): The ID of the room to retrieve.
- Returns:
- A namedtuple containing the room information, or an empty list.
- """
- return self._simple_select_one(
- table="rooms",
- keyvalues={"room_id": room_id},
- retcols=("room_id", "is_public", "creator"),
- desc="get_room",
- allow_none=True,
- )
-
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
@@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
"get_all_new_public_rooms", get_all_new_public_rooms
)
- @cachedInlineCallbacks(max_entries=10000)
- def get_ratelimit_for_user(self, user_id):
- """Check if there are any overrides for ratelimiting for the given
- user
-
- Args:
- user_id (str)
-
- Returns:
- RatelimitOverride if there is an override, else None. If the contents
- of RatelimitOverride are None or 0 then ratelimitng has been
- disabled for that user entirely.
- """
- row = yield self._simple_select_one(
- table="ratelimit_override",
- keyvalues={"user_id": user_id},
- retcols=("messages_per_second", "burst_count"),
- allow_none=True,
- desc="get_ratelimit_for_user",
- )
-
- if row:
- defer.returnValue(RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- ))
- else:
- defer.returnValue(None)
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
yield self._simple_insert(
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 10dce21cea..9b4e6d6aa8 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id
-from synapse.util.async import Linearizer
+from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii
diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql
new file mode 100644
index 0000000000..c9d537d5a3
--- /dev/null
+++ b/synapse/storage/schema/delta/51/monthly_active_users.sql
@@ -0,0 +1,27 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a table of monthly active users, for use where blocking based on mau limits
+CREATE TABLE monthly_active_users (
+ user_id TEXT NOT NULL,
+ -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting
+ -- on updates, Granularity of updates governed by
+ -- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY
+ -- Measured in ms since epoch.
+ timestamp BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id);
+CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b27b3ae144..4b971efdba 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -21,15 +21,17 @@ from six.moves import range
from twisted.internet import defer
+from synapse.api.constants import EventTypes
+from synapse.api.errors import NotFoundError
+from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.caches import get_cache_factor_for, intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
-from ._base import SQLBaseStore
-
logger = logging.getLogger(__name__)
@@ -46,7 +48,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
-class StateGroupWorkerStore(SQLBaseStore):
+# this inherits from EventsWorkerStore because it calls self.get_events
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
@@ -57,10 +60,69 @@ class StateGroupWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+ # Originally the state store used a single DictionaryCache to cache the
+ # event IDs for the state types in a given state group to avoid hammering
+ # on the state_group* tables.
+ #
+ # The point of using a DictionaryCache is that it can cache a subset
+ # of the state events for a given state group (i.e. a subset of the keys for a
+ # given dict which is an entry in the cache for a given state group ID).
+ #
+ # However, this poses problems when performing complicated queries
+ # on the store - for instance: "give me all the state for this group, but
+ # limit members to this subset of users", as DictionaryCache's API isn't
+ # rich enough to say "please cache any of these fields, apart from this subset".
+ # This is problematic when lazy loading members, which requires this behaviour,
+ # as without it the cache has no choice but to speculatively load all
+ # state events for the group, which negates the efficiency being sought.
+ #
+ # Rather than overcomplicating DictionaryCache's API, we instead split the
+ # state_group_cache into two halves - one for tracking non-member events,
+ # and the other for tracking member_events. This means that lazy loading
+ # queries can be made in a cache-friendly manner by querying both caches
+ # separately and then merging the result. So for the example above, you
+ # would query the members cache for a specific subset of state keys
+ # (which DictionaryCache will handle efficiently and fine) and the non-members
+ # cache for all state (which DictionaryCache will similarly handle fine)
+ # and then just merge the results together.
+ #
+ # We size the non-members cache to be smaller than the members cache as the
+ # vast majority of state in Matrix (today) is member events.
+
self._state_group_cache = DictionaryCache(
- "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
+ "*stateGroupCache*",
+ # TODO: this hasn't been tuned yet
+ 50000 * get_cache_factor_for("stateGroupCache")
+ )
+ self._state_group_members_cache = DictionaryCache(
+ "*stateGroupMembersCache*",
+ 500000 * get_cache_factor_for("stateGroupMembersCache")
)
+ @defer.inlineCallbacks
+ def get_room_version(self, room_id):
+ """Get the room_version of a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[str]
+
+ Raises:
+ NotFoundError if the room is unknown
+ """
+ # for now we do this by looking at the create event. We may want to cache this
+ # more intelligently in future.
+ state_ids = yield self.get_current_state_ids(room_id)
+ create_id = state_ids.get((EventTypes.Create, ""))
+
+ if not create_id:
+ raise NotFoundError("Unknown room")
+
+ create_event = yield self.get_event(create_id)
+ defer.returnValue(create_event.content.get("room_version", "1"))
+
@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
"""Get the current state event ids for a room based on the
@@ -89,6 +151,69 @@ class StateGroupWorkerStore(SQLBaseStore):
_get_current_state_ids_txn,
)
+ # FIXME: how should this be cached?
+ def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
+ """Get the current state event of a given type for a room based on the
+ current_state_events table. This may not be as up-to-date as the result
+ of doing a fresh state resolution as per state_handler.get_current_state
+ Args:
+ room_id (str)
+ types (list[(Str, (Str|None))]): List of (type, state_key) tuples
+ which are used to filter the state fetched. `state_key` may be
+ None, which matches any `state_key`
+ filtered_types (list[Str]|None): List of types to apply the above filter to.
+ Returns:
+ deferred: dict of (type, state_key) -> event
+ """
+
+ include_other_types = False if filtered_types is None else True
+
+ def _get_filtered_current_state_ids_txn(txn):
+ results = {}
+ sql = """SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ? %s"""
+ # Turns out that postgres doesn't like doing a list of OR's and
+ # is about 1000x slower, so we just issue a query for each specific
+ # type seperately.
+ if types:
+ clause_to_args = [
+ (
+ "AND type = ? AND state_key = ?",
+ (etype, state_key)
+ ) if state_key is not None else (
+ "AND type = ?",
+ (etype,)
+ )
+ for etype, state_key in types
+ ]
+
+ if include_other_types:
+ unique_types = set(filtered_types)
+ clause_to_args.append(
+ (
+ "AND type <> ? " * len(unique_types),
+ list(unique_types)
+ )
+ )
+ else:
+ # If types is None we fetch all the state, and so just use an
+ # empty where clause with no extra args.
+ clause_to_args = [("", [])]
+ for where_clause, where_args in clause_to_args:
+ args = [room_id]
+ args.extend(where_args)
+ txn.execute(sql % (where_clause,), args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (intern_string(typ), intern_string(state_key))
+ results[key] = event_id
+ return results
+
+ return self.runInteraction(
+ "get_filtered_current_state_ids",
+ _get_filtered_current_state_ids_txn,
+ )
+
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@@ -185,7 +310,7 @@ class StateGroupWorkerStore(SQLBaseStore):
})
@defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, types):
+ def _get_state_groups_from_groups(self, groups, types, members=None):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
@@ -194,6 +319,9 @@ class StateGroupWorkerStore(SQLBaseStore):
types (Iterable[str, str|None]|None): list of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`. If None, all types are returned.
+ members (bool|None): If not None, then, in addition to any filtering
+ implied by types, the results are also filtered to only include
+ member events (if True), or to exclude member events (if False)
Returns:
dictionary state_group -> (dict of (type, state_key) -> event id)
@@ -204,14 +332,14 @@ class StateGroupWorkerStore(SQLBaseStore):
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn, chunk, types,
+ self._get_state_groups_from_groups_txn, chunk, types, members,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(
- self, txn, groups, types=None,
+ self, txn, groups, types=None, members=None,
):
results = {group: {} for group in groups}
@@ -249,6 +377,11 @@ class StateGroupWorkerStore(SQLBaseStore):
%s
""")
+ if members is True:
+ sql += " AND type = '%s'" % (EventTypes.Member,)
+ elif members is False:
+ sql += " AND type <> '%s'" % (EventTypes.Member,)
+
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type seperately.
@@ -296,6 +429,11 @@ class StateGroupWorkerStore(SQLBaseStore):
else:
where_clause = ""
+ if members is True:
+ where_clause += " AND type = '%s'" % EventTypes.Member
+ elif members is False:
+ where_clause += " AND type <> '%s'" % EventTypes.Member
+
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
@@ -362,8 +500,7 @@ class StateGroupWorkerStore(SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- deferred: A list of dicts corresponding to the event_ids given.
- The dicts are mappings from (type, state_key) -> state_events
+ deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
@@ -391,7 +528,8 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
"""
- Get the state dicts corresponding to a list of events
+ Get the state dicts corresponding to a list of events, containing the event_ids
+ of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
@@ -404,7 +542,7 @@ class StateGroupWorkerStore(SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- A deferred dict from event_id -> (type, state_key) -> state_event
+ A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
@@ -490,10 +628,11 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
- def _get_some_state_from_cache(self, group, types, filtered_types=None):
+ def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
+ cache(DictionaryCache): the state group cache to use
group(int): The state group to lookup
types(list[str, str|None]): List of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
@@ -507,11 +646,11 @@ class StateGroupWorkerStore(SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
- is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
+ is_all, known_absent, state_dict_ids = cache.get(group)
type_to_key = {}
- # tracks whether any of ourrequested types are missing from the cache
+ # tracks whether any of our requested types are missing from the cache
missing_types = False
for typ, state_key in types:
@@ -558,7 +697,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if include(k[0], k[1])
}, got_all
- def _get_all_state_from_cache(self, group):
+ def _get_all_state_from_cache(self, cache, group):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
@@ -566,9 +705,10 @@ class StateGroupWorkerStore(SQLBaseStore):
cache, if False we need to query the DB for the missing state.
Args:
+ cache(DictionaryCache): the state group cache to use
group: The state group to lookup
"""
- is_all, _, state_dict_ids = self._state_group_cache.get(group)
+ is_all, _, state_dict_ids = cache.get(group)
return state_dict_ids, is_all
@@ -595,6 +735,62 @@ class StateGroupWorkerStore(SQLBaseStore):
Deferred[dict[int, dict[(type, state_key), EventBase]]]
a dictionary mapping from state group to state dictionary.
"""
+ if types is not None:
+ non_member_types = [t for t in types if t[0] != EventTypes.Member]
+
+ if filtered_types is not None and EventTypes.Member not in filtered_types:
+ # we want all of the membership events
+ member_types = None
+ else:
+ member_types = [t for t in types if t[0] == EventTypes.Member]
+
+ else:
+ non_member_types = None
+ member_types = None
+
+ non_member_state = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, non_member_types, filtered_types,
+ )
+ # XXX: we could skip this entirely if member_types is []
+ member_state = yield self._get_state_for_groups_using_cache(
+ # we set filtered_types=None as member_state only ever contain members.
+ groups, self._state_group_members_cache, member_types, None,
+ )
+
+ state = non_member_state
+ for group in groups:
+ state[group].update(member_state[group])
+
+ defer.returnValue(state)
+
+ @defer.inlineCallbacks
+ def _get_state_for_groups_using_cache(
+ self, groups, cache, types=None, filtered_types=None
+ ):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key, querying from a specific cache.
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ cache (DictionaryCache): the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the specific
+ members state cache.
+ types (None|iterable[(str, None|str)]):
+ indicates the state type/keys required. If None, the whole
+ state is fetched and returned.
+
+ Otherwise, each entry should be a `(type, state_key)` tuple to
+ include in the response. A `state_key` of None is a wildcard
+ meaning that we require all state with that type.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
+
+ Returns:
+ Deferred[dict[int, dict[(type, state_key), EventBase]]]
+ a dictionary mapping from state group to state dictionary.
+ """
if types:
types = frozenset(types)
results = {}
@@ -602,7 +798,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if types is not None:
for group in set(groups):
state_dict_ids, got_all = self._get_some_state_from_cache(
- group, types, filtered_types
+ cache, group, types, filtered_types
)
results[group] = state_dict_ids
@@ -611,7 +807,7 @@ class StateGroupWorkerStore(SQLBaseStore):
else:
for group in set(groups):
state_dict_ids, got_all = self._get_all_state_from_cache(
- group
+ cache, group
)
results[group] = state_dict_ids
@@ -620,8 +816,8 @@ class StateGroupWorkerStore(SQLBaseStore):
missing_groups.append(group)
if missing_groups:
- # Okay, so we have some missing_types, lets fetch them.
- cache_seq_num = self._state_group_cache.sequence
+ # Okay, so we have some missing_types, let's fetch them.
+ cache_seq_num = cache.sequence
# the DictionaryCache knows if it has *all* the state, but
# does not know if it has all of the keys of a particular type,
@@ -635,7 +831,7 @@ class StateGroupWorkerStore(SQLBaseStore):
types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups(
- missing_groups, types_to_fetch
+ missing_groups, types_to_fetch, cache == self._state_group_members_cache,
)
for group, group_state_dict in iteritems(group_to_state_dict):
@@ -655,7 +851,7 @@ class StateGroupWorkerStore(SQLBaseStore):
# update the cache with all the things we fetched from the
# database.
- self._state_group_cache.update(
+ cache.update(
cache_seq_num,
key=group,
value=group_state_dict,
@@ -757,15 +953,33 @@ class StateGroupWorkerStore(SQLBaseStore):
],
)
- # Prefill the state group cache with this group.
+ # Prefill the state group caches with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
+
+ current_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] == EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_members_cache.update,
+ self._state_group_members_cache.sequence,
+ key=state_group,
+ value=dict(current_member_state_ids),
+ )
+
+ current_non_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] != EventTypes.Member
+ }
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
- value=dict(current_state_ids),
+ value=dict(current_non_member_state_ids),
)
return state_group
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index b9f2b74ac6..4c296d72c0 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -348,7 +348,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token (str): The stream token representing now.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
+ Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
events and a token pointing to the start of the returned
events.
The events returned are in ascending order.
@@ -379,7 +379,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token (str): The stream token representing now.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
+ Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
_EventDictReturn and a token pointing to the start of the returned
events.
The events returned are in ascending order.
|