diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 08dffd774f..be61147b9b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,9 +17,10 @@ import sys
import threading
import time
-from six import iteritems, iterkeys, itervalues
+from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range
+from canonicaljson import json
from prometheus_client import Histogram
from twisted.internet import defer
@@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass
+
+
+def db_to_json(db_content):
+ """
+ Take some data from a database row and return a JSON-decoded object.
+
+ Args:
+ db_content (memoryview|buffer|bytes|bytearray|unicode)
+ """
+ # psycopg2 on Python 3 returns memoryview objects, which we need to
+ # cast to bytes to decode
+ if isinstance(db_content, memoryview):
+ db_content = db_content.tobytes()
+
+ # psycopg2 on Python 2 returns buffer objects, which we need to cast to
+ # bytes to decode
+ if PY2 and isinstance(db_content, buffer):
+ db_content = bytes(db_content)
+
+ # Decode it to a Unicode string before feeding it to json.loads, so we
+ # consistenty get a Unicode-containing object out.
+ if isinstance(db_content, (bytes, bytearray)):
+ db_content = db_content.decode('utf8')
+
+ try:
+ return json.loads(db_content)
+ except Exception:
+ logging.warning("Tried to decode '%r' as JSON and failed", db_content)
+ raise
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 73646da025..e06b0bc56d 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
- devices = messages_by_device.keys()
+ devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index c0943ecf91..d10ff9e4b9 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from ._base import Cache, SQLBaseStore
+from ._base import Cache, SQLBaseStore, db_to_json
logger = logging.getLogger(__name__)
@@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
if device is not None:
key_json = device.get("key_json", None)
if key_json:
- result["keys"] = json.loads(key_json)
+ result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
retcol="content",
desc="_get_cached_user_device",
)
- defer.returnValue(json.loads(content))
+ defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
@@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
desc="_get_cached_devices_for_user",
)
defer.returnValue({
- device["device_id"]: json.loads(device["content"])
+ device["device_id"]: db_to_json(device["content"])
for device in devices
})
@@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
- result["keys"] = json.loads(key_json)
+ result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 808194236a..cfb687cb53 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -75,7 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore):
},
retcol="creator",
desc="get_room_alias_creator",
- allow_none=True
)
@cached(max_entries=5000)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 523b4360c3..1f1721e820 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,13 +14,13 @@
# limitations under the License.
from six import iteritems
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore):
@@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys):
- device_info["keys"] = json.loads(device_info.pop("key_json"))
+ device_info["keys"] = db_to_json(device_info.pop("key_json"))
defer.returnValue(results)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 8a0386c1a4..42225f8a2a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -41,13 +41,18 @@ class PostgresEngine(object):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+
+ # Set the bytea output to escape, vs the default of hex
+ cursor = db_conn.cursor()
+ cursor.execute("SET bytea_output TO escape")
+
# Asynchronous commit, don't wait for the server to call fsync before
# ending the transaction.
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
if not self.synchronous_commit:
- cursor = db_conn.cursor()
cursor.execute("SET synchronous_commit TO OFF")
- cursor.close()
+
+ cursor.close()
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index f39c8c8461..e7487311ce 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,7 @@ import logging
from collections import OrderedDict, deque, namedtuple
from functools import wraps
-from six import iteritems
+from six import iteritems, text_type
from six.moves import range
from canonicaljson import json
@@ -930,6 +930,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
)
self._invalidate_cache_and_stream(
+ txn, self.get_room_summary, (room_id,)
+ )
+
+ self._invalidate_cache_and_stream(
txn, self.get_current_state_ids, (room_id,)
)
@@ -1220,7 +1224,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
"sender": event.sender,
"contains_url": (
"url" in event.content
- and isinstance(event.content["url"], basestring)
+ and isinstance(event.content["url"], text_type)
),
}
for event, _ in events_and_contexts
@@ -1529,7 +1533,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
contains_url = "url" in content
if contains_url:
- contains_url &= isinstance(content["url"], basestring)
+ contains_url &= isinstance(content["url"], text_type)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -1886,20 +1890,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
")"
)
- # create an index on should_delete because later we'll be looking for
- # the should_delete / shouldn't_delete subsets
- txn.execute(
- "CREATE INDEX events_to_purge_should_delete"
- " ON events_to_purge(should_delete)",
- )
-
- # We do joins against events_to_purge for e.g. calculating state
- # groups to purge, etc., so lets make an index.
- txn.execute(
- "CREATE INDEX events_to_purge_id"
- " ON events_to_purge(event_id)",
- )
-
# First ensure that we're not about to delete all the forward extremeties
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
@@ -1910,9 +1900,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
(room_id,)
)
rows = txn.fetchall()
- max_depth = max(row[0] for row in rows)
+ max_depth = max(row[1] for row in rows)
- if max_depth <= token.topological:
+ if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
@@ -1926,19 +1916,45 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
should_delete_params = ()
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"
- should_delete_params += ("%:" + self.hs.hostname, )
+
+ # We include the parameter twice since we use the expression twice
+ should_delete_params += (
+ "%:" + self.hs.hostname,
+ "%:" + self.hs.hostname,
+ )
should_delete_params += (room_id, token.topological)
+ # Note that we insert events that are outliers and aren't going to be
+ # deleted, as nothing will happen to them.
txn.execute(
"INSERT INTO events_to_purge"
" SELECT event_id, %s"
" FROM events AS e LEFT JOIN state_events USING (event_id)"
- " WHERE e.room_id = ? AND topological_ordering < ?" % (
+ " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
+ % (
+ should_delete_expr,
should_delete_expr,
),
should_delete_params,
)
+
+ # We create the indices *after* insertion as that's a lot faster.
+
+ # create an index on should_delete because later we'll be looking for
+ # the should_delete / shouldn't_delete subsets
+ txn.execute(
+ "CREATE INDEX events_to_purge_should_delete"
+ " ON events_to_purge(should_delete)",
+ )
+
+ # We do joins against events_to_purge for e.g. calculating state
+ # groups to purge, etc., so lets make an index.
+ txn.execute(
+ "CREATE INDEX events_to_purge_id"
+ " ON events_to_purge(event_id)",
+ )
+
txn.execute(
"SELECT event_id, should_delete FROM events_to_purge"
)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 59822178ff..a8326f5296 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import itertools
import logging
from collections import namedtuple
@@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
with Measure(self._clock, "_fetch_event_list"):
try:
- event_id_lists = zip(*event_list)[0]
+ event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
@@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs):
+ def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
- d.errback(e)
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list)
+ self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 2d5896c5b4..6ddcc909bf 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
+ defer.returnValue(db_to_json(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index f547977600..a1331c1a61 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -134,6 +134,7 @@ class KeyStore(SQLBaseStore):
"""
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+ # XXX fix this to not need a lock (#3819)
def _txn(txn):
self._simple_upsert_txn(
txn,
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index d178f5c5ba..59580949f1 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -36,7 +36,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks
def initialise_reserved_users(self, threepids):
- # TODO Why can't I do this in init?
store = self.hs.get_datastore()
reserved_user_list = []
@@ -148,6 +147,23 @@ class MonthlyActiveUsersStore(SQLBaseStore):
return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
+ def get_registered_reserved_users_count(self):
+ """Of the reserved threepids defined in config, how many are associated
+ with registered users?
+
+ Returns:
+ Defered[int]: Number of real reserved users
+ """
+ count = 0
+ for tp in self.hs.config.mau_limits_reserved_threepids:
+ user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ tp["medium"], tp["address"]
+ )
+ if user_id:
+ count = count + 1
+ defer.returnValue(count)
+
+ @defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
"""
Updates or inserts monthly active user member
@@ -200,10 +216,14 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Args:
user_id(str): the user_id to query
"""
+
if self.hs.config.limit_usage_by_mau:
+ # Trial users and guests should not be included as part of MAU group
+ is_guest = yield self.is_guest(user_id)
+ if is_guest:
+ return
is_trial = yield self.is_trial_user(user_id)
if is_trial:
- # we don't track trial users in the MAU table.
return
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8443bd4c1b..c7987bfcdd 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -15,7 +15,8 @@
# limitations under the License.
import logging
-import types
+
+import six
from canonicaljson import encode_canonical_json, json
@@ -27,6 +28,11 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
+
class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
@@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
dataJson = r['data']
r['data'] = None
try:
- if isinstance(dataJson, types.BufferType):
+ if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8")
r['data'] = json.loads(dataJson)
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
- r['id'], dataJson, e.message,
+ r['id'], dataJson, e.args[0],
)
pass
- if isinstance(r['pushkey'], types.BufferType):
+ if isinstance(r['pushkey'], db_binary_type):
r['pushkey'] = str(r['pushkey']).decode("UTF8")
return rows
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9b4e6d6aa8..0707f9a86a 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -51,6 +51,12 @@ ProfileInfo = namedtuple(
"ProfileInfo", ("avatar_url", "display_name")
)
+# "members" points to a truncated list of (user_id, event_id) tuples for users of
+# a given membership type, suitable for use in calculating heroes for a room.
+# "count" points to the total numberr of users of a given membership type.
+MemberSummary = namedtuple(
+ "MemberSummary", ("members", "count")
+)
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
@@ -82,6 +88,65 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [to_ascii(r[0]) for r in txn]
return self.runInteraction("get_users_in_room", f)
+ @cached(max_entries=100000)
+ def get_room_summary(self, room_id):
+ """ Get the details of a room roughly suitable for use by the room
+ summary extension to /sync. Useful when lazy loading room members.
+ Args:
+ room_id (str): The room ID to query
+ Returns:
+ Deferred[dict[str, MemberSummary]:
+ dict of membership states, pointing to a MemberSummary named tuple.
+ """
+
+ def _get_room_summary_txn(txn):
+ # first get counts.
+ # We do this all in one transaction to keep the cache small.
+ # FIXME: get rid of this when we have room_stats
+ sql = """
+ SELECT count(*), m.membership FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ res = {}
+ for count, membership in txn:
+ summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+
+ # we order by membership and then fairly arbitrarily by event_id so
+ # heroes are consistent
+ sql = """
+ SELECT m.user_id, m.membership, m.event_id
+ FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ ORDER BY
+ CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ m.event_id ASC
+ LIMIT ?
+ """
+
+ # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+ txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+ for user_id, membership, event_id in txn:
+ summary = res[to_ascii(membership)]
+ # we will always have a summary for this membership type at this
+ # point given the summary currently contains the counts.
+ members = summary.members
+ members.append((to_ascii(user_id), to_ascii(event_id)))
+
+ return res
+
+ return self.runInteraction("get_room_summary", _get_room_summary_txn)
+
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 428e7fa36e..0c42bd3322 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -18,14 +18,14 @@ from collections import namedtuple
import six
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
@@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
- return result["response_code"], json.loads(str(result["response_json"]))
+ return result["response_code"], db_to_json(result["response_json"])
+
else:
return None
|