diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index be9934c66f..3725c9795d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -23,7 +23,7 @@ from synapse.util.lrucache import LruCache
from twisted.internet import defer
-import collections
+from collections import namedtuple, OrderedDict
import simplejson as json
import sys
import time
@@ -35,6 +35,52 @@ sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
+# TODO(paul):
+# * more generic key management
+# * export monitoring stats
+# * consider other eviction strategies - LRU?
+def cached(max_entries=1000):
+ """ A method decorator that applies a memoizing cache around the function.
+
+ The function is presumed to take one additional argument, which is used as
+ the key for the cache. Cache hits are served directly from the cache;
+ misses use the function body to generate the value.
+
+ The wrapped function has an additional member, a callable called
+ "invalidate". This can be used to remove individual entries from the cache.
+
+ The wrapped function has another additional callable, called "prefill",
+ which can be used to insert values into the cache specifically, without
+ calling the calculation function.
+ """
+ def wrap(orig):
+ cache = OrderedDict()
+
+ def prefill(key, value):
+ while len(cache) > max_entries:
+ cache.popitem(last=False)
+
+ cache[key] = value
+
+ @defer.inlineCallbacks
+ def wrapped(self, key):
+ if key in cache:
+ defer.returnValue(cache[key])
+
+ ret = yield orig(self, key)
+ prefill(key, ret)
+ defer.returnValue(ret)
+
+ def invalidate(key):
+ cache.pop(key, None)
+
+ wrapped.invalidate = invalidate
+ wrapped.prefill = prefill
+ return wrapped
+
+ return wrap
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
@@ -404,7 +450,8 @@ class SQLBaseStore(object):
Args:
table : string giving the table name
- keyvalues : dict of column names and values to select the rows with
+ keyvalues : dict of column names and values to select the rows with,
+ or None to not apply a WHERE clause.
retcols : list of strings giving the names of the columns to return
"""
return self.runInteraction(
@@ -423,13 +470,20 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
- sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k, ) for k in keyvalues)
- )
+ if keyvalues:
+ sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ )
+ txn.execute(sql, keyvalues.values())
+ else:
+ sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+ ", ".join(retcols),
+ table
+ )
+ txn.execute(sql)
- txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
@@ -586,8 +640,9 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
update_counter = self._get_event_counters.update
+ cache = self._get_event_cache.setdefault(event_id, {})
+
try:
- cache = self._get_event_cache.setdefault(event_id, {})
# Separate cache entries for each way to invoke _get_event_txn
return cache[(check_redacted, get_prev_content, allow_rejected)]
except KeyError:
@@ -786,7 +841,7 @@ class JoinHelper(object):
for table in self.tables:
res += [f for f in table.fields if f not in res]
- self.EntryType = collections.namedtuple("JoinHelperEntry", res)
+ self.EntryType = namedtuple("JoinHelperEntry", res)
def get_fields(self, **prefixes):
"""Get a string representing a list of fields for use in SELECT
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index d941b1f387..97481d113b 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -15,31 +15,21 @@
import logging
from twisted.internet import defer
+from synapse.api.constants import Membership
from synapse.api.errors import StoreError
from synapse.appservice import ApplicationService
+from synapse.storage.roommember import RoomsForUser
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
-class ApplicationServiceCache(object):
- """Caches ApplicationServices and provides utility functions on top.
-
- This class is designed to be invoked on incoming events in order to avoid
- hammering the database every time to extract a list of application service
- regexes.
- """
-
- def __init__(self):
- self.services = []
-
-
class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs)
- self.cache = ApplicationServiceCache()
+ self.services_cache = []
self.cache_defer = self._populate_cache()
@defer.inlineCallbacks
@@ -56,7 +46,7 @@ class ApplicationServiceStore(SQLBaseStore):
token,
)
# update cache TODO: Should this be in the txn?
- for service in self.cache.services:
+ for service in self.services_cache:
if service.token == token:
service.url = None
service.namespaces = None
@@ -110,13 +100,13 @@ class ApplicationServiceStore(SQLBaseStore):
)
# update cache TODO: Should this be in the txn?
- for (index, cache_service) in enumerate(self.cache.services):
+ for (index, cache_service) in enumerate(self.services_cache):
if service.token == cache_service.token:
- self.cache.services[index] = service
+ self.services_cache[index] = service
logger.info("Updated: %s", service)
return
# new entry
- self.cache.services.append(service)
+ self.services_cache.append(service)
logger.info("Updated(new): %s", service)
def _update_app_service_txn(self, txn, service):
@@ -160,11 +150,34 @@ class ApplicationServiceStore(SQLBaseStore):
@defer.inlineCallbacks
def get_app_services(self):
yield self.cache_defer # make sure the cache is ready
- defer.returnValue(self.cache.services)
+ defer.returnValue(self.services_cache)
+
+ @defer.inlineCallbacks
+ def get_app_service_by_user_id(self, user_id):
+ """Retrieve an application service from their user ID.
+
+ All application services have associated with them a particular user ID.
+ There is no distinguishing feature on the user ID which indicates it
+ represents an application service. This function allows you to map from
+ a user ID to an application service.
+
+ Args:
+ user_id(str): The user ID to see if it is an application service.
+ Returns:
+ synapse.appservice.ApplicationService or None.
+ """
+
+ yield self.cache_defer # make sure the cache is ready
+
+ for service in self.services_cache:
+ if service.sender == user_id:
+ defer.returnValue(service)
+ return
+ defer.returnValue(None)
@defer.inlineCallbacks
def get_app_service_by_token(self, token, from_cache=True):
- """Get the application service with the given token.
+ """Get the application service with the given appservice token.
Args:
token (str): The application service token.
@@ -176,7 +189,7 @@ class ApplicationServiceStore(SQLBaseStore):
yield self.cache_defer # make sure the cache is ready
if from_cache:
- for service in self.cache.services:
+ for service in self.services_cache:
if service.token == token:
defer.returnValue(service)
return
@@ -185,6 +198,77 @@ class ApplicationServiceStore(SQLBaseStore):
# TODO: The from_cache=False impl
# TODO: This should be JOINed with the application_services_regex table.
+ def get_app_service_rooms(self, service):
+ """Get a list of RoomsForUser for this application service.
+
+ Application services may be "interested" in lots of rooms depending on
+ the room ID, the room aliases, or the members in the room. This function
+ takes all of these into account and returns a list of RoomsForUser which
+ represent the entire list of room IDs that this application service
+ wants to know about.
+
+ Args:
+ service: The application service to get a room list for.
+ Returns:
+ A list of RoomsForUser.
+ """
+ return self.runInteraction(
+ "get_app_service_rooms",
+ self._get_app_service_rooms_txn,
+ service,
+ )
+
+ def _get_app_service_rooms_txn(self, txn, service):
+ # get all rooms matching the room ID regex.
+ room_entries = self._simple_select_list_txn(
+ txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
+ )
+ matching_room_list = set([
+ r["room_id"] for r in room_entries if
+ service.is_interested_in_room(r["room_id"])
+ ])
+
+ # resolve room IDs for matching room alias regex.
+ room_alias_mappings = self._simple_select_list_txn(
+ txn=txn, table="room_aliases", keyvalues=None,
+ retcols=["room_id", "room_alias"]
+ )
+ matching_room_list |= set([
+ r["room_id"] for r in room_alias_mappings if
+ service.is_interested_in_alias(r["room_alias"])
+ ])
+
+ # get all rooms for every user for this AS. This is scoped to users on
+ # this HS only.
+ user_list = self._simple_select_list_txn(
+ txn=txn, table="users", keyvalues=None, retcols=["name"]
+ )
+ user_list = [
+ u["name"] for u in user_list if
+ service.is_interested_in_user(u["name"])
+ ]
+ rooms_for_user_matching_user_id = set() # RoomsForUser list
+ for user_id in user_list:
+ # FIXME: This assumes this store is linked with RoomMemberStore :(
+ rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
+ txn=txn,
+ user_id=user_id,
+ membership_list=[Membership.JOIN]
+ )
+ rooms_for_user_matching_user_id |= set(rooms_for_user)
+
+ # make RoomsForUser tuples for room ids and aliases which are not in the
+ # main rooms_for_user_list - e.g. they are rooms which do not have AS
+ # registered users in it.
+ known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
+ missing_rooms_for_user = [
+ RoomsForUser(r, service.sender, "join") for r in
+ matching_room_list if r not in known_room_ids
+ ]
+ rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
+
+ return rooms_for_user_matching_user_id
+
@defer.inlineCallbacks
def _populate_cache(self):
"""Populates the ApplicationServiceCache from the database."""
@@ -235,7 +319,7 @@ class ApplicationServiceStore(SQLBaseStore):
# TODO get last successful txn id f.e. service
for service in services.values():
logger.info("Found application service: %s", service)
- self.cache.services.append(ApplicationService(
+ self.services_cache.append(ApplicationService(
token=service["token"],
url=service["url"],
namespaces=service["namespaces"],
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9bf608bc90..65ffb4627f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
from collections import namedtuple
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
from synapse.api.constants import Membership
from synapse.types import UserID
@@ -35,11 +35,6 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
- def __init__(self, *args, **kw):
- super(RoomMemberStore, self).__init__(*args, **kw)
-
- self._user_rooms_cache = {}
-
def _store_room_member_txn(self, txn, event):
"""Store a room member in the database.
"""
@@ -103,7 +98,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (event.room_id, domain))
- self.invalidate_rooms_for_user(target_user_id)
+ self.get_rooms_for_user.invalidate(target_user_id)
@defer.inlineCallbacks
def get_room_member(self, user_id, room_id):
@@ -185,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
if not membership_list:
return defer.succeed(None)
+ return self.runInteraction(
+ "get_rooms_for_user_where_membership_is",
+ self._get_rooms_for_user_where_membership_is_txn,
+ user_id, membership_list
+ )
+
+ def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
+ membership_list):
where_clause = "user_id = ? AND (%s)" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
@@ -192,24 +195,18 @@ class RoomMemberStore(SQLBaseStore):
args = [user_id]
args.extend(membership_list)
- def f(txn):
- sql = (
- "SELECT m.room_id, m.sender, m.membership"
- " FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id"
- " WHERE %s"
- ) % (where_clause,)
-
- txn.execute(sql, args)
- return [
- RoomsForUser(**r) for r in self.cursor_to_dict(txn)
- ]
+ sql = (
+ "SELECT m.room_id, m.sender, m.membership"
+ " FROM room_memberships as m"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id"
+ " WHERE %s"
+ ) % (where_clause,)
- return self.runInteraction(
- "get_rooms_for_user_where_membership_is",
- f
- )
+ txn.execute(sql, args)
+ return [
+ RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+ ]
def get_joined_hosts_for_room(self, room_id):
return self._simple_select_onecol(
@@ -247,33 +244,12 @@ class RoomMemberStore(SQLBaseStore):
results = self._parse_events_txn(txn, rows)
return results
- # TODO(paul): Create a nice @cached decorator to do this
- # @cached
- # def get_foo(...)
- # ...
- # invalidate_foo = get_foo.invalidator
-
- @defer.inlineCallbacks
+ @cached()
def get_rooms_for_user(self, user_id):
- # TODO(paul): put some performance counters in here so we can easily
- # track what impact this cache is having
- if user_id in self._user_rooms_cache:
- defer.returnValue(self._user_rooms_cache[user_id])
-
- rooms = yield self.get_rooms_for_user_where_membership_is(
+ return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
- # TODO(paul): Consider applying a maximum size; just evict things at
- # random, or consider LRU?
-
- self._user_rooms_cache[user_id] = rooms
- defer.returnValue(rooms)
-
- def invalidate_rooms_for_user(self, user_id):
- if user_id in self._user_rooms_cache:
- del self._user_rooms_cache[user_id]
-
@defer.inlineCallbacks
def user_rooms_intersect(self, user_id_list):
""" Checks whether all the users whose IDs are given in a list share a
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 3ccb6f8a61..09bc522210 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -36,6 +36,7 @@ what sort order was used:
from twisted.internet import defer
from ._base import SQLBaseStore
+from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function
@@ -127,6 +128,85 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
class StreamStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
+ # NB this lives here instead of appservice.py so we can reuse the
+ # 'private' StreamToken class in this file.
+ if limit:
+ limit = max(limit, MAX_STREAM_SIZE)
+ else:
+ limit = MAX_STREAM_SIZE
+
+ # From and to keys should be integers from ordering.
+ from_id = _StreamToken.parse_stream_token(from_key)
+ to_id = _StreamToken.parse_stream_token(to_key)
+
+ if from_key == to_key:
+ defer.returnValue(([], to_key))
+ return
+
+ # select all the events between from/to with a sensible limit
+ sql = (
+ "SELECT e.event_id, e.room_id, e.type, s.state_key, "
+ "e.stream_ordering FROM events AS e LEFT JOIN state_events as s ON "
+ "e.event_id = s.event_id "
+ "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
+ "ORDER BY stream_ordering ASC LIMIT %(limit)d "
+ ) % {
+ "limit": limit
+ }
+
+ def f(txn):
+ # pull out all the events between the tokens
+ txn.execute(sql, (from_id.stream, to_id.stream,))
+ rows = self.cursor_to_dict(txn)
+
+ # Logic:
+ # - We want ALL events which match the AS room_id regex
+ # - We want ALL events which match the rooms represented by the AS
+ # room_alias regex
+ # - We want ALL events for rooms that AS users have joined.
+ # This is currently supported via get_app_service_rooms (which is
+ # used for the Notifier listener rooms). We can't reasonably make a
+ # SQL query for these room IDs, so we'll pull all the events between
+ # from/to and filter in python.
+ rooms_for_as = self._get_app_service_rooms_txn(txn, service)
+ room_ids_for_as = [r.room_id for r in rooms_for_as]
+
+ def app_service_interested(row):
+ if row["room_id"] in room_ids_for_as:
+ return True
+
+ if row["type"] == EventTypes.Member:
+ if service.is_interested_in_user(row.get("state_key")):
+ return True
+ return False
+
+ ret = self._get_events_txn(
+ txn,
+ # apply the filter on the room id list
+ [
+ r["event_id"] for r in rows
+ if app_service_interested(r)
+ ],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(ret, rows)
+
+ if rows:
+ key = "s%d" % max(r["stream_ordering"] for r in rows)
+ else:
+ # Assume we didn't get anything because there was nothing to
+ # get.
+ key = to_key
+
+ return ret, key
+
+ results = yield self.runInteraction("get_appservice_room_stream", f)
+ defer.returnValue(results)
+
@log_function
def get_room_events_stream(self, user_id, from_key, to_key, room_id,
limit=0, with_feedback=False):
@@ -184,8 +264,7 @@ class StreamStore(SQLBaseStore):
self._set_before_and_after(ret, rows)
if rows:
- key = "s%d" % max([r["stream_ordering"] for r in rows])
-
+ key = "s%d" % max(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index e06ef35690..0b8a3b7a07 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore, Table
+from ._base import SQLBaseStore, Table, cached
from collections import namedtuple
-from twisted.internet import defer
-
import logging
logger = logging.getLogger(__name__)
@@ -28,10 +26,6 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
- # a write-through cache of DestinationsTable.EntryType indexed by
- # destination string
- destination_retry_cache = {}
-
def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
@@ -211,6 +205,7 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall())
+ @cached()
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
@@ -221,9 +216,6 @@ class TransactionStore(SQLBaseStore):
None if not retrying
Otherwise a DestinationsTable.EntryType for the retry scheme
"""
- if destination in self.destination_retry_cache:
- return defer.succeed(self.destination_retry_cache[destination])
-
return self.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings, destination)
@@ -250,7 +242,9 @@ class TransactionStore(SQLBaseStore):
retry_interval (int) - how long until next retry in ms
"""
- self.destination_retry_cache[destination] = (
+ # As this is the new value, we might as well prefill the cache
+ self.get_destination_retry_timings.prefill(
+ destination,
DestinationsTable.EntryType(
destination,
retry_last_ts,
|