diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 08dffd774f..be61147b9b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,9 +17,10 @@ import sys
import threading
import time
-from six import iteritems, iterkeys, itervalues
+from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range
+from canonicaljson import json
from prometheus_client import Histogram
from twisted.internet import defer
@@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass
+
+
+def db_to_json(db_content):
+ """
+ Take some data from a database row and return a JSON-decoded object.
+
+ Args:
+ db_content (memoryview|buffer|bytes|bytearray|unicode)
+ """
+ # psycopg2 on Python 3 returns memoryview objects, which we need to
+ # cast to bytes to decode
+ if isinstance(db_content, memoryview):
+ db_content = db_content.tobytes()
+
+ # psycopg2 on Python 2 returns buffer objects, which we need to cast to
+ # bytes to decode
+ if PY2 and isinstance(db_content, buffer):
+ db_content = bytes(db_content)
+
+ # Decode it to a Unicode string before feeding it to json.loads, so we
+ # consistenty get a Unicode-containing object out.
+ if isinstance(db_content, (bytes, bytearray)):
+ db_content = db_content.decode('utf8')
+
+ try:
+ return json.loads(db_content)
+ except Exception:
+ logging.warning("Tried to decode '%r' as JSON and failed", db_content)
+ raise
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 8fc678fa67..9ad17b7c25 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -119,21 +119,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
- self._simple_upsert_txn(
- txn,
- table="user_ips",
- keyvalues={
- "user_id": user_id,
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": device_id,
- },
- values={
- "last_seen": last_seen,
- },
- lock=False,
- )
+ try:
+ self._simple_upsert_txn(
+ txn,
+ table="user_ips",
+ keyvalues={
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": device_id,
+ },
+ values={
+ "last_seen": last_seen,
+ },
+ lock=False,
+ )
+ except Exception as e:
+ # Failed to upsert, log and continue
+ logger.error("Failed to insert client IP %r: %r", entry, e)
@defer.inlineCallbacks
def get_last_client_ip_by_device(self, user_id, device_id):
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 73646da025..e06b0bc56d 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
- devices = messages_by_device.keys()
+ devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index c0943ecf91..d10ff9e4b9 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from ._base import Cache, SQLBaseStore
+from ._base import Cache, SQLBaseStore, db_to_json
logger = logging.getLogger(__name__)
@@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
if device is not None:
key_json = device.get("key_json", None)
if key_json:
- result["keys"] = json.loads(key_json)
+ result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
retcol="content",
desc="_get_cached_user_device",
)
- defer.returnValue(json.loads(content))
+ defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
@@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
desc="_get_cached_devices_for_user",
)
defer.returnValue({
- device["device_id"]: json.loads(device["content"])
+ device["device_id"]: db_to_json(device["content"])
for device in devices
})
@@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
- result["keys"] = json.loads(key_json)
+ result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 808194236a..cfb687cb53 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -75,7 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore):
},
retcol="creator",
desc="get_room_alias_creator",
- allow_none=True
)
@cached(max_entries=5000)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 523b4360c3..1f1721e820 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,13 +14,13 @@
# limitations under the License.
from six import iteritems
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore):
@@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys):
- device_info["keys"] = json.loads(device_info.pop("key_json"))
+ device_info["keys"] = db_to_json(device_info.pop("key_json"))
defer.returnValue(results)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 8a0386c1a4..42225f8a2a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -41,13 +41,18 @@ class PostgresEngine(object):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+
+ # Set the bytea output to escape, vs the default of hex
+ cursor = db_conn.cursor()
+ cursor.execute("SET bytea_output TO escape")
+
# Asynchronous commit, don't wait for the server to call fsync before
# ending the transaction.
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
if not self.synchronous_commit:
- cursor = db_conn.cursor()
cursor.execute("SET synchronous_commit TO OFF")
- cursor.close()
+
+ cursor.close()
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index f39c8c8461..03cedf3a75 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,7 @@ import logging
from collections import OrderedDict, deque, namedtuple
from functools import wraps
-from six import iteritems
+from six import iteritems, text_type
from six.moves import range
from canonicaljson import json
@@ -38,6 +38,7 @@ from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.util import batch_iter
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
@@ -386,12 +387,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
)
for room_id, ev_ctx_rm in iteritems(events_by_room):
- # Work out new extremities by recursively adding and removing
- # the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
room_id
)
- new_latest_event_ids = yield self._calculate_new_extremeties(
+ new_latest_event_ids = yield self._calculate_new_extremities(
room_id, ev_ctx_rm, latest_event_ids
)
@@ -400,6 +399,12 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
# No change in extremities, so no change in state
continue
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
new_forward_extremeties[room_id] = new_latest_event_ids
len_1 = (
@@ -517,44 +522,79 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
)
@defer.inlineCallbacks
- def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
- """Calculates the new forward extremeties for a room given events to
+ def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+ """Calculates the new forward extremities for a room given events to
persist.
Assumes that we are only persisting events for one room at a time.
"""
- new_latest_event_ids = set(latest_event_ids)
- # First, add all the new events to the list
- new_latest_event_ids.update(
- event.event_id for event, ctx in event_contexts
+
+ # we're only interested in new events which aren't outliers and which aren't
+ # being rejected.
+ new_events = [
+ event for event, ctx in event_contexts
if not event.internal_metadata.is_outlier() and not ctx.rejected
+ ]
+
+ # start with the existing forward extremities
+ result = set(latest_event_ids)
+
+ # add all the new events to the list
+ result.update(
+ event.event_id for event in new_events
)
- # Now remove all events that are referenced by the to-be-added events
- new_latest_event_ids.difference_update(
+
+ # Now remove all events which are prev_events of any of the new events
+ result.difference_update(
e_id
- for event, ctx in event_contexts
+ for event in new_events
for e_id, _ in event.prev_events
- if not event.internal_metadata.is_outlier() and not ctx.rejected
)
- # And finally remove any events that are referenced by previously added
- # events.
- rows = yield self._simple_select_many_batch(
- table="event_edges",
- column="prev_event_id",
- iterable=list(new_latest_event_ids),
- retcols=["prev_event_id"],
- keyvalues={
- "is_state": False,
- },
- desc="_calculate_new_extremeties",
- )
+ # Finally, remove any events which are prev_events of any existing events.
+ existing_prevs = yield self._get_events_which_are_prevs(result)
+ result.difference_update(existing_prevs)
- new_latest_event_ids.difference_update(
- row["prev_event_id"] for row in rows
- )
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _get_events_which_are_prevs(self, event_ids):
+ """Filter the supplied list of event_ids to get those which are prev_events of
+ existing (non-outlier/rejected) events.
+
+ Args:
+ event_ids (Iterable[str]): event ids to filter
+
+ Returns:
+ Deferred[List[str]]: filtered event ids
+ """
+ results = []
+
+ def _get_events(txn, batch):
+ sql = """
+ SELECT prev_event_id
+ FROM event_edges
+ INNER JOIN events USING (event_id)
+ LEFT JOIN rejections USING (event_id)
+ WHERE
+ prev_event_id IN (%s)
+ AND NOT events.outlier
+ AND rejections.event_id IS NULL
+ """ % (
+ ",".join("?" for _ in batch),
+ )
+
+ txn.execute(sql, batch)
+ results.extend(r[0] for r in txn)
+
+ for chunk in batch_iter(event_ids, 100):
+ yield self.runInteraction(
+ "_get_events_which_are_prevs",
+ _get_events,
+ chunk,
+ )
- defer.returnValue(new_latest_event_ids)
+ defer.returnValue(results)
@defer.inlineCallbacks
def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
@@ -586,10 +626,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
the new current state is only returned if we've already calculated
it.
"""
-
- if not new_latest_event_ids:
- return
-
# map from state_group to ((type, key) -> event_id) state map
state_groups_map = {}
@@ -930,6 +966,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
)
self._invalidate_cache_and_stream(
+ txn, self.get_room_summary, (room_id,)
+ )
+
+ self._invalidate_cache_and_stream(
txn, self.get_current_state_ids, (room_id,)
)
@@ -1220,7 +1260,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
"sender": event.sender,
"contains_url": (
"url" in event.content
- and isinstance(event.content["url"], basestring)
+ and isinstance(event.content["url"], text_type)
),
}
for event, _ in events_and_contexts
@@ -1529,7 +1569,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
contains_url = "url" in content
if contains_url:
- contains_url &= isinstance(content["url"], basestring)
+ contains_url &= isinstance(content["url"], text_type)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -1886,20 +1926,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
")"
)
- # create an index on should_delete because later we'll be looking for
- # the should_delete / shouldn't_delete subsets
- txn.execute(
- "CREATE INDEX events_to_purge_should_delete"
- " ON events_to_purge(should_delete)",
- )
-
- # We do joins against events_to_purge for e.g. calculating state
- # groups to purge, etc., so lets make an index.
- txn.execute(
- "CREATE INDEX events_to_purge_id"
- " ON events_to_purge(event_id)",
- )
-
# First ensure that we're not about to delete all the forward extremeties
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
@@ -1910,9 +1936,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
(room_id,)
)
rows = txn.fetchall()
- max_depth = max(row[0] for row in rows)
+ max_depth = max(row[1] for row in rows)
- if max_depth <= token.topological:
+ if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
@@ -1926,19 +1952,45 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
should_delete_params = ()
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"
- should_delete_params += ("%:" + self.hs.hostname, )
+
+ # We include the parameter twice since we use the expression twice
+ should_delete_params += (
+ "%:" + self.hs.hostname,
+ "%:" + self.hs.hostname,
+ )
should_delete_params += (room_id, token.topological)
+ # Note that we insert events that are outliers and aren't going to be
+ # deleted, as nothing will happen to them.
txn.execute(
"INSERT INTO events_to_purge"
" SELECT event_id, %s"
" FROM events AS e LEFT JOIN state_events USING (event_id)"
- " WHERE e.room_id = ? AND topological_ordering < ?" % (
+ " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
+ % (
+ should_delete_expr,
should_delete_expr,
),
should_delete_params,
)
+
+ # We create the indices *after* insertion as that's a lot faster.
+
+ # create an index on should_delete because later we'll be looking for
+ # the should_delete / shouldn't_delete subsets
+ txn.execute(
+ "CREATE INDEX events_to_purge_should_delete"
+ " ON events_to_purge(should_delete)",
+ )
+
+ # We do joins against events_to_purge for e.g. calculating state
+ # groups to purge, etc., so lets make an index.
+ txn.execute(
+ "CREATE INDEX events_to_purge_id"
+ " ON events_to_purge(event_id)",
+ )
+
txn.execute(
"SELECT event_id, should_delete FROM events_to_purge"
)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 59822178ff..a8326f5296 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import itertools
import logging
from collections import namedtuple
@@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
with Measure(self._clock, "_fetch_event_list"):
try:
- event_id_lists = zip(*event_list)[0]
+ event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
@@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs):
+ def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
- d.errback(e)
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list)
+ self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 2d5896c5b4..6ddcc909bf 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
+ defer.returnValue(db_to_json(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index f547977600..a1331c1a61 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -134,6 +134,7 @@ class KeyStore(SQLBaseStore):
"""
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+ # XXX fix this to not need a lock (#3819)
def _txn(txn):
self._simple_upsert_txn(
txn,
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index d178f5c5ba..0fe8c8e24c 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -36,7 +36,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks
def initialise_reserved_users(self, threepids):
- # TODO Why can't I do this in init?
store = self.hs.get_datastore()
reserved_user_list = []
@@ -148,6 +147,23 @@ class MonthlyActiveUsersStore(SQLBaseStore):
return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
+ def get_registered_reserved_users_count(self):
+ """Of the reserved threepids defined in config, how many are associated
+ with registered users?
+
+ Returns:
+ Defered[int]: Number of real reserved users
+ """
+ count = 0
+ for tp in self.hs.config.mau_limits_reserved_threepids:
+ user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ tp["medium"], tp["address"]
+ )
+ if user_id:
+ count = count + 1
+ defer.returnValue(count)
+
+ @defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
"""
Updates or inserts monthly active user member
@@ -156,6 +172,10 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Deferred[bool]: True if a new entry was created, False if an
existing one was updated.
"""
+ # Am consciously deciding to lock the table on the basis that is ought
+ # never be a big table and alternative approaches (batching multiple
+ # upserts into a single txn) introduced a lot of extra complexity.
+ # See https://github.com/matrix-org/synapse/issues/3854 for more
is_insert = yield self._simple_upsert(
desc="upsert_monthly_active_user",
table="monthly_active_users",
@@ -165,7 +185,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
values={
"timestamp": int(self._clock.time_msec()),
},
- lock=False,
)
if is_insert:
self.user_last_seen_monthly_active.invalidate((user_id,))
@@ -200,10 +219,14 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Args:
user_id(str): the user_id to query
"""
+
if self.hs.config.limit_usage_by_mau:
+ # Trial users and guests should not be included as part of MAU group
+ is_guest = yield self.is_guest(user_id)
+ if is_guest:
+ return
is_trial = yield self.is_trial_user(user_id)
if is_trial:
- # we don't track trial users in the MAU table.
return
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8443bd4c1b..c7987bfcdd 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -15,7 +15,8 @@
# limitations under the License.
import logging
-import types
+
+import six
from canonicaljson import encode_canonical_json, json
@@ -27,6 +28,11 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
+
class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
@@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
dataJson = r['data']
r['data'] = None
try:
- if isinstance(dataJson, types.BufferType):
+ if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8")
r['data'] = json.loads(dataJson)
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
- r['id'], dataJson, e.message,
+ r['id'], dataJson, e.args[0],
)
pass
- if isinstance(r['pushkey'], types.BufferType):
+ if isinstance(r['pushkey'], db_binary_type):
r['pushkey'] = str(r['pushkey']).decode("UTF8")
return rows
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9b4e6d6aa8..0707f9a86a 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -51,6 +51,12 @@ ProfileInfo = namedtuple(
"ProfileInfo", ("avatar_url", "display_name")
)
+# "members" points to a truncated list of (user_id, event_id) tuples for users of
+# a given membership type, suitable for use in calculating heroes for a room.
+# "count" points to the total numberr of users of a given membership type.
+MemberSummary = namedtuple(
+ "MemberSummary", ("members", "count")
+)
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
@@ -82,6 +88,65 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [to_ascii(r[0]) for r in txn]
return self.runInteraction("get_users_in_room", f)
+ @cached(max_entries=100000)
+ def get_room_summary(self, room_id):
+ """ Get the details of a room roughly suitable for use by the room
+ summary extension to /sync. Useful when lazy loading room members.
+ Args:
+ room_id (str): The room ID to query
+ Returns:
+ Deferred[dict[str, MemberSummary]:
+ dict of membership states, pointing to a MemberSummary named tuple.
+ """
+
+ def _get_room_summary_txn(txn):
+ # first get counts.
+ # We do this all in one transaction to keep the cache small.
+ # FIXME: get rid of this when we have room_stats
+ sql = """
+ SELECT count(*), m.membership FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ res = {}
+ for count, membership in txn:
+ summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+
+ # we order by membership and then fairly arbitrarily by event_id so
+ # heroes are consistent
+ sql = """
+ SELECT m.user_id, m.membership, m.event_id
+ FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ ORDER BY
+ CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ m.event_id ASC
+ LIMIT ?
+ """
+
+ # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+ txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+ for user_id, membership, event_id in txn:
+ summary = res[to_ascii(membership)]
+ # we will always have a summary for this membership type at this
+ # point given the summary currently contains the counts.
+ members = summary.members
+ members.append((to_ascii(user_id), to_ascii(event_id)))
+
+ return res
+
+ return self.runInteraction("get_room_summary", _get_room_summary_txn)
+
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 4b971efdba..3f4cbd61c4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -255,7 +255,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
@defer.inlineCallbacks
- def get_state_groups_ids(self, room_id, event_ids):
+ def get_state_groups_ids(self, _room_id, event_ids):
+ """Get the event IDs of all the state for the state groups for the given events
+
+ Args:
+ _room_id (str): id of the room for these events
+ event_ids (iterable[str]): ids of the events
+
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
if not event_ids:
defer.returnValue({})
@@ -270,7 +280,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
- """Get the state IDs for the given state group
+ """Get the event IDs of all the state in the given state group
Args:
state_group (int)
@@ -286,7 +296,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
- The return value is a dict mapping group names to lists of events.
+ Returns:
+ Deferred[dict[int, list[EventBase]]]:
+ dict of state_group_id -> list of state events.
"""
if not event_ids:
defer.returnValue({})
@@ -324,7 +336,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
member events (if True), or to exclude member events (if False)
Returns:
- dictionary state_group -> (dict of (type, state_key) -> event id)
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
results = {}
@@ -732,8 +746,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- Deferred[dict[int, dict[(type, state_key), EventBase]]]
- a dictionary mapping from state group to state dictionary.
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if types is not None:
non_member_types = [t for t in types if t[0] != EventTypes.Member]
@@ -788,8 +802,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
- Deferred[dict[int, dict[(type, state_key), EventBase]]]
- a dictionary mapping from state group to state dictionary.
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if types:
types = frozenset(types)
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 428e7fa36e..a3032cdce9 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -18,14 +18,14 @@ from collections import namedtuple
import six
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.expiringcache import ExpiringCache
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
@@ -50,6 +50,8 @@ _UpdateTransactionRow = namedtuple(
)
)
+SENTINEL = object()
+
class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
@@ -60,6 +62,12 @@ class TransactionStore(SQLBaseStore):
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
+ self._destination_retry_cache = ExpiringCache(
+ cache_name="get_destination_retry_timings",
+ clock=self._clock,
+ expiry_ms=5 * 60 * 1000,
+ )
+
def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
@@ -95,7 +103,8 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
- return result["response_code"], json.loads(str(result["response_json"]))
+ return result["response_code"], db_to_json(result["response_json"])
+
else:
return None
@@ -155,7 +164,7 @@ class TransactionStore(SQLBaseStore):
"""
pass
- @cached(max_entries=10000)
+ @defer.inlineCallbacks
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
@@ -166,10 +175,20 @@ class TransactionStore(SQLBaseStore):
None if not retrying
Otherwise a dict for the retry scheme
"""
- return self.runInteraction(
+
+ result = self._destination_retry_cache.get(destination, SENTINEL)
+ if result is not SENTINEL:
+ defer.returnValue(result)
+
+ result = yield self.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings, destination)
+ # We don't hugely care about race conditions between getting and
+ # invalidating the cache, since we time out fairly quickly anyway.
+ self._destination_retry_cache[destination] = result
+ defer.returnValue(result)
+
def _get_destination_retry_timings(self, txn, destination):
result = self._simple_select_one_txn(
txn,
@@ -197,8 +216,7 @@ class TransactionStore(SQLBaseStore):
retry_interval (int) - how long until next retry in ms
"""
- # XXX: we could chose to not bother persisting this if our cache thinks
- # this is a NOOP
+ self._destination_retry_cache.pop(destination, None)
return self.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
@@ -211,10 +229,6 @@ class TransactionStore(SQLBaseStore):
retry_last_ts, retry_interval):
self.database_engine.lock_table(txn, "destinations")
- self._invalidate_cache_and_stream(
- txn, self.get_destination_retry_timings, (destination,)
- )
-
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
|