diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 979fa22438..23b4a8d76d 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -15,50 +15,49 @@
# limitations under the License.
import datetime
-from dateutil import tz
-import time
import logging
+import time
+from dateutil import tz
+
+from synapse.api.constants import PresenceState
from synapse.storage.devices import DeviceStore
-from .appservice import (
- ApplicationServiceStore, ApplicationServiceTransactionStore
-)
+from synapse.storage.user_erasure_store import UserErasureStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from .account_data import AccountDataStore
+from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .client_ips import ClientIpStore
+from .deviceinbox import DeviceInboxStore
from .directory import DirectoryStore
+from .end_to_end_keys import EndToEndKeyStore
+from .engines import PostgresEngine
+from .event_federation import EventFederationStore
+from .event_push_actions import EventPushActionsStore
from .events import EventsStore
+from .filtering import FilteringStore
+from .group_server import GroupServerStore
+from .keys import KeyStore
+from .media_repository import MediaRepositoryStore
+from .monthly_active_users import MonthlyActiveUsersStore
+from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
+from .push_rule import PushRuleStore
+from .pusher import PusherStore
+from .receipts import ReceiptsStore
from .registration import RegistrationStore
+from .rejections import RejectionsStore
from .room import RoomStore
from .roommember import RoomMemberStore
-from .stream import StreamStore
-from .transactions import TransactionStore
-from .keys import KeyStore
-from .event_federation import EventFederationStore
-from .pusher import PusherStore
-from .push_rule import PushRuleStore
-from .media_repository import MediaRepositoryStore
-from .rejections import RejectionsStore
-from .event_push_actions import EventPushActionsStore
-from .deviceinbox import DeviceInboxStore
-from .group_server import GroupServerStore
-from .state import StateStore
-from .signatures import SignatureStore
-from .filtering import FilteringStore
-from .end_to_end_keys import EndToEndKeyStore
-
-from .receipts import ReceiptsStore
from .search import SearchStore
+from .signatures import SignatureStore
+from .state import StateStore
+from .stream import StreamStore
from .tags import TagsStore
-from .account_data import AccountDataStore
-from .openid import OpenIdStore
-from .client_ips import ClientIpStore
+from .transactions import TransactionStore
from .user_directory import UserDirectoryStore
-
-from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
-from .engines import PostgresEngine
-
-from synapse.api.constants import PresenceState
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator
logger = logging.getLogger(__name__)
@@ -68,6 +67,7 @@ class DataStore(RoomMemberStore, RoomStore,
PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore,
+ EventsStore,
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
@@ -75,7 +75,6 @@ class DataStore(RoomMemberStore, RoomStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
- EventsStore,
ReceiptsStore,
EndToEndKeyStore,
SearchStore,
@@ -88,6 +87,8 @@ class DataStore(RoomMemberStore, RoomStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
+ UserErasureStore,
+ MonthlyActiveUsersStore,
):
def __init__(self, db_conn, hs):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 22d6257a9f..be61147b9b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import sys
+import threading
+import time
-from synapse.api.errors import StoreError
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches.descriptors import Cache
-from synapse.storage.engines import PostgresEngine
+from six import PY2, iteritems, iterkeys, itervalues
+from six.moves import intern, range
+from canonicaljson import json
from prometheus_client import Histogram
from twisted.internet import defer
-import sys
-import time
-import threading
-
-from six import itervalues, iterkeys, iteritems
-from six.moves import intern, range
+from synapse.api.errors import StoreError
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import Cache
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
@@ -221,7 +221,7 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
- logging_context, func, *args, **kwargs):
+ func, *args, **kwargs):
start = time.time()
txn_id = self._TXN_ID
@@ -285,8 +285,7 @@ class SQLBaseStore(object):
end = time.time()
duration = end - start
- if logging_context is not None:
- logging_context.add_database_transaction(duration)
+ LoggingContext.current_context().add_database_transaction(duration)
transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
@@ -310,19 +309,21 @@ class SQLBaseStore(object):
Returns:
Deferred: The result of func
"""
- current_context = LoggingContext.current_context()
-
after_callbacks = []
exception_callbacks = []
- def inner_func(conn, *args, **kwargs):
- return self._new_transaction(
- conn, desc, after_callbacks, exception_callbacks, current_context,
- func, *args, **kwargs
+ if LoggingContext.current_context() == LoggingContext.sentinel:
+ logger.warn(
+ "Starting db txn '%s' from sentinel context",
+ desc,
)
try:
- result = yield self.runWithConnection(inner_func, *args, **kwargs)
+ result = yield self.runWithConnection(
+ self._new_transaction,
+ desc, after_callbacks, exception_callbacks, func,
+ *args, **kwargs
+ )
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
@@ -347,22 +348,25 @@ class SQLBaseStore(object):
Returns:
Deferred: The result of func
"""
- current_context = LoggingContext.current_context()
+ parent_context = LoggingContext.current_context()
+ if parent_context == LoggingContext.sentinel:
+ logger.warn(
+ "Starting db connection from sentinel context: metrics will be lost",
+ )
+ parent_context = None
start_time = time.time()
def inner_func(conn, *args, **kwargs):
- with LoggingContext("runWithConnection") as context:
+ with LoggingContext("runWithConnection", parent_context) as context:
sched_duration_sec = time.time() - start_time
sql_scheduling_timer.observe(sched_duration_sec)
- current_context.add_database_scheduled(sched_duration_sec)
+ context.add_database_scheduled(sched_duration_sec)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
- current_context.copy_to(context)
-
return func(conn, *args, **kwargs)
with PreserveLoggingContext():
@@ -1147,17 +1151,16 @@ class SQLBaseStore(object):
defer.returnValue(retval)
def get_user_count_txn(self, txn):
- """Get a total number of registerd users in the users list.
+ """Get a total number of registered users in the users list.
Args:
txn : Transaction object
Returns:
- defer.Deferred: resolves to int
+ int : number of users
"""
sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
txn.execute(sql_count)
- count = txn.fetchone()[0]
- defer.returnValue(count)
+ return txn.fetchone()[0]
def _simple_search_list(self, table, term, col, retcols,
desc="_simple_search_list"):
@@ -1214,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass
+
+
+def db_to_json(db_content):
+ """
+ Take some data from a database row and return a JSON-decoded object.
+
+ Args:
+ db_content (memoryview|buffer|bytes|bytearray|unicode)
+ """
+ # psycopg2 on Python 3 returns memoryview objects, which we need to
+ # cast to bytes to decode
+ if isinstance(db_content, memoryview):
+ db_content = db_content.tobytes()
+
+ # psycopg2 on Python 2 returns buffer objects, which we need to cast to
+ # bytes to decode
+ if PY2 and isinstance(db_content, buffer):
+ db_content = bytes(db_content)
+
+ # Decode it to a Unicode string before feeding it to json.loads, so we
+ # consistently get a Unicode-containing object out.
+ if isinstance(db_content, (bytes, bytearray)):
+ db_content = db_content.decode('utf8')
+
+ try:
+ return json.loads(db_content)
+ except Exception:
+ logging.warning("Tried to decode '%r' as JSON and failed", db_content)
+ raise
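
A minimal usage sketch of the new db_to_json helper (illustrative literal values; assumes the synapse package is importable):

from synapse.storage._base import db_to_json

# bytes (e.g. from SQLite) and memoryview (psycopg2 on Python 3) decode identically
assert db_to_json(b'{"foo": 1}') == {"foo": 1}
assert db_to_json(memoryview(b'{"foo": 1}')) == {"foo": 1}
assert db_to_json(u'{"foo": 1}') == {"foo": 1}
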
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index f83ff0454a..bbc3355c73 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -14,17 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
+import logging
+
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.util.id_generators import StreamIdGenerator
-
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
-
-import abc
-import simplejson as json
-import logging
logger = logging.getLogger(__name__)
@@ -114,25 +114,6 @@ class AccountDataWorkerStore(SQLBaseStore):
else:
defer.returnValue(None)
- @cachedList(cached_method_name="get_global_account_data_by_type_for_user",
- num_args=2, list_name="user_ids", inlineCallbacks=True)
- def get_global_account_data_by_type_for_users(self, data_type, user_ids):
- rows = yield self._simple_select_many_batch(
- table="account_data",
- column="user_id",
- iterable=user_ids,
- keyvalues={
- "account_data_type": data_type,
- },
- retcols=("user_id", "content",),
- desc="get_global_account_data_by_type_for_users",
- )
-
- defer.returnValue({
- row["user_id"]: json.loads(row["content"]) if row["content"] else None
- for row in rows
- })
-
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room.
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 12ea8a158c..31248d5e06 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -15,14 +15,16 @@
# limitations under the License.
import logging
import re
-import simplejson as json
+
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage.events import EventsWorkerStore
-from ._base import SQLBaseStore
+from synapse.storage.events_worker import EventsWorkerStore
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 8af325a9f5..5fe1ca2de7 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,15 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.util.async
-from ._base import SQLBaseStore
-from . import engines
+import logging
+
+from canonicaljson import json
from twisted.internet import defer
-import simplejson as json
-import logging
+from synapse.metrics.background_process_metrics import run_as_background_process
+
+from . import engines
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -87,12 +89,16 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_handlers = {}
self._all_done = False
- @defer.inlineCallbacks
def start_doing_background_updates(self):
- logger.info("Starting background schema updates")
+ run_as_background_process(
+ "background_updates", self._run_background_updates,
+ )
+ @defer.inlineCallbacks
+ def _run_background_updates(self):
+ logger.info("Starting background schema updates")
while True:
- yield synapse.util.async.sleep(
+ yield self.hs.get_clock().sleep(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
try:
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index ce338514e8..8fc678fa67 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,15 +15,15 @@
import logging
-from twisted.internet import defer, reactor
+from six import iteritems
-from ._base import Cache
-from . import background_updates
+from twisted.internet import defer
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import CACHE_SIZE_FACTOR
-from six import iteritems
-
+from . import background_updates
+from ._base import Cache
logger = logging.getLogger(__name__)
@@ -35,6 +35,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
+
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
@@ -70,8 +71,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
- reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
+ @defer.inlineCallbacks
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None):
if not now:
@@ -82,7 +86,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
-
+ yield self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
@@ -92,10 +96,21 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
def _update_client_ips_batch(self):
- to_update = self._batch_row_update
- self._batch_row_update = {}
- return self.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+
+ # If the DB pool has already terminated, don't try updating
+ if not self.hs.get_db_pool().running:
+ return
+
+ def update():
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
+ return self.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn,
+ to_update,
+ )
+
+ return run_as_background_process(
+ "update_client_ips", update,
)
def _update_client_ips_batch_txn(self, txn, to_update):
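
The update() closure above swaps the pending batch out before writing it, so rows arriving during a flush land in a fresh batch; a standalone sketch of that idiom with made-up values (the class is hypothetical, not a Synapse API):

class BatchSketch(object):
    """Toy stand-in for the _batch_row_update handling above."""

    def __init__(self):
        self.pending = {}

    def add(self, key, value):
        self.pending[key] = value

    def flush(self):
        # Swap the dict out first, so later adds accumulate in a new batch
        # rather than the one currently being written to the database.
        to_update = self.pending
        self.pending = {}
        return to_update

batch = BatchSketch()
batch.add(("@user:hs", "token", "10.0.0.1"), ("Mozilla/5.0", "DEVICE1", 1000))
assert len(batch.flush()) == 1
assert batch.flush() == {}
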
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index a879e5bfc1..e06b0bc56d 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -14,14 +14,14 @@
# limitations under the License.
import logging
-import simplejson
-from twisted.internet import defer
+from canonicaljson import json
-from .background_updates import BackgroundUpdateStore
+from twisted.internet import defer
from synapse.util.caches.expiringcache import ExpiringCache
+from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
@@ -85,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
)
rows = []
for destination, edu in remote_messages_by_destination.items():
- edu_json = simplejson.dumps(edu)
+ edu_json = json.dumps(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
@@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
- devices = messages_by_device.keys()
+ devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (
@@ -177,7 +177,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" WHERE user_id = ?"
)
txn.execute(sql, (user_id,))
- message_json = simplejson.dumps(messages_by_device["*"])
+ message_json = json.dumps(messages_by_device["*"])
for row in txn:
# Add the message for all devices for this user on this
# server.
@@ -199,7 +199,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
- message_json = simplejson.dumps(messages_by_device[device])
+ message_json = json.dumps(messages_by_device[device])
messages_json_for_user[device] = message_json
if messages_json_for_user:
@@ -253,7 +253,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(simplejson.loads(row[1]))
+ messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
@@ -389,7 +389,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(simplejson.loads(row[1]))
+ messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d149d8392e..d10ff9e4b9 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,15 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import simplejson as json
+
+from six import iteritems, itervalues
+
+from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, Cache
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from six import itervalues, iteritems
+from ._base import Cache, SQLBaseStore, db_to_json
logger = logging.getLogger(__name__)
@@ -246,17 +249,31 @@ class DeviceStore(SQLBaseStore):
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
- self._simple_upsert_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- values={
- "content": json.dumps(content),
- }
- )
+ if content.get("deleted"):
+ self._simple_delete_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ )
+
+ txn.call_after(
+ self.device_id_exists_cache.invalidate, (user_id, device_id,)
+ )
+ else:
+ self._simple_upsert_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={
+ "content": json.dumps(content),
+ }
+ )
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -364,7 +381,7 @@ class DeviceStore(SQLBaseStore):
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
- txn, query_map.keys(), include_all_devices=True
+ txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
@@ -391,12 +408,15 @@ class DeviceStore(SQLBaseStore):
prev_id = stream_id
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = json.loads(key_json)
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
+ if device is not None:
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+ else:
+ result["deleted"] = True
results.append(result)
@@ -446,7 +466,7 @@ class DeviceStore(SQLBaseStore):
retcol="content",
desc="_get_cached_user_device",
)
- defer.returnValue(json.loads(content))
+ defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
@@ -459,7 +479,7 @@ class DeviceStore(SQLBaseStore):
desc="_get_cached_devices_for_user",
)
defer.returnValue({
- device["device_id"]: json.loads(device["content"])
+ device["device_id"]: db_to_json(device["content"])
for device in devices
})
@@ -491,7 +511,7 @@ class DeviceStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
- result["keys"] = json.loads(key_json)
+ result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -692,6 +712,9 @@ class DeviceStore(SQLBaseStore):
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
- return self.runInteraction(
- "_prune_old_outbound_device_pokes", _prune_txn
+ return run_as_background_process(
+ "prune_old_outbound_device_pokes",
+ self.runInteraction,
+ "_prune_old_outbound_device_pokes",
+ _prune_txn,
)
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index d0c0059757..808194236a 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
-
-from synapse.api.errors import SynapseError
+from collections import namedtuple
from twisted.internet import defer
-from collections import namedtuple
+from synapse.api.errors import SynapseError
+from synapse.util.caches.descriptors import cached
+from ._base import SQLBaseStore
RoomAliasMapping = namedtuple(
"RoomAliasMapping",
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index b146487943..1f1721e820 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,16 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from synapse.util.caches.descriptors import cached
+from six import iteritems
from canonicaljson import encode_canonical_json
-import simplejson as json
-from ._base import SQLBaseStore
+from twisted.internet import defer
-from six import iteritems
+from synapse.util.caches.descriptors import cached
+
+from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore):
@@ -65,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def get_e2e_device_keys(self, query_list, include_all_devices=False):
+ def get_e2e_device_keys(
+ self, query_list, include_all_devices=False,
+ include_deleted_devices=False,
+ ):
"""Fetch a list of device keys.
Args:
query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
+ include_deleted_devices (bool): whether to include null entries for
+ devices which no longer exist (but were in the query_list).
+ This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name".
@@ -80,19 +85,28 @@ class EndToEndKeyStore(SQLBaseStore):
results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
- query_list, include_all_devices,
+ query_list, include_all_devices, include_deleted_devices,
)
for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys):
- device_info["keys"] = json.loads(device_info.pop("key_json"))
+ device_info["keys"] = db_to_json(device_info.pop("key_json"))
defer.returnValue(results)
- def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
+ def _get_e2e_device_keys_txn(
+ self, txn, query_list, include_all_devices=False,
+ include_deleted_devices=False,
+ ):
query_clauses = []
query_params = []
+ if include_all_devices is False:
+ include_deleted_devices = False
+
+ if include_deleted_devices:
+ deleted_devices = set(query_list)
+
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
@@ -120,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore):
result = {}
for row in rows:
+ if include_deleted_devices:
+ deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row
+ if include_deleted_devices:
+ for user_id, device_id in deleted_devices:
+ result.setdefault(user_id, {})[device_id] = None
+
return result
@defer.inlineCallbacks
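
A standalone sketch of the deleted_devices bookkeeping above, using hypothetical users and devices (plain Python, not Synapse code):

query_list = [("@user:hs", "DEV1"), ("@user:hs", "DEV2")]

# Start by assuming every queried device has been deleted...
deleted_devices = set(query_list)

# ...then strike out each device for which the database returned a row.
rows = [{"user_id": "@user:hs", "device_id": "DEV1", "key_json": "{}"}]
result = {}
for row in rows:
    deleted_devices.remove((row["user_id"], row["device_id"]))
    result.setdefault(row["user_id"], {})[row["device_id"]] = row

# Anything left over had no row, so it is reported as a null entry.
for user_id, device_id in deleted_devices:
    result.setdefault(user_id, {})[device_id] = None

assert result["@user:hs"]["DEV2"] is None
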
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 8c868ece75..e2f9de8451 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import IncorrectDatabaseSetup
-from .postgres import PostgresEngine
-from .sqlite3 import Sqlite3Engine
-
import importlib
import platform
+from ._base import IncorrectDatabaseSetup
+from .postgres import PostgresEngine
+from .sqlite3 import Sqlite3Engine
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 8a0386c1a4..42225f8a2a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -41,13 +41,18 @@ class PostgresEngine(object):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+
+ # Set the bytea output to escape, vs the default of hex
+ cursor = db_conn.cursor()
+ cursor.execute("SET bytea_output TO escape")
+
# Asynchronous commit, don't wait for the server to call fsync before
# ending the transaction.
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
if not self.synchronous_commit:
- cursor = db_conn.cursor()
cursor.execute("SET synchronous_commit TO OFF")
- cursor.close()
+
+ cursor.close()
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 60f0fa7fb3..19949fc474 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import prepare_database
-
import struct
import threading
+from synapse.storage.prepare_database import prepare_database
+
class Sqlite3Engine(object):
single_threaded = True
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 8fbf7ffba7..24345b20a6 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -12,23 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import random
+from six.moves import range
+from six.moves.queue import Empty, PriorityQueue
+
+from unpaddedbase64 import encode_base64
+
from twisted.internet import defer
+from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
-
-from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached
-from unpaddedbase64 import encode_base64
-
-import logging
-from six.moves.queue import PriorityQueue, Empty
-
-from six.moves import range
-
logger = logging.getLogger(__name__)
@@ -115,9 +114,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
sql = (
"SELECT b.event_id, MAX(e.depth) FROM events as e"
" INNER JOIN event_edges as g"
- " ON g.event_id = e.event_id AND g.room_id = e.room_id"
+ " ON g.event_id = e.event_id"
" INNER JOIN event_backward_extremities as b"
- " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
+ " ON g.prev_event_id = b.event_id"
" WHERE b.room_id = ? AND g.is_state is ?"
" GROUP BY b.event_id"
)
@@ -331,8 +330,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
"SELECT depth, prev_event_id FROM event_edges"
" INNER JOIN events"
" ON prev_event_id = events.event_id"
- " AND event_edges.room_id = events.room_id"
- " WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
+ " WHERE event_edges.event_id = ?"
" AND event_edges.is_state = ?"
" LIMIT ?"
)
@@ -345,6 +343,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
table="events",
keyvalues={
"event_id": event_id,
+ "room_id": room_id,
},
retcol="depth",
allow_none=True,
@@ -366,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
txn.execute(
query,
- (room_id, event_id, False, limit - len(event_results))
+ (event_id, False, limit - len(event_results))
)
for row in txn:
@@ -403,7 +402,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
query = (
"SELECT prev_event_id FROM event_edges "
- "WHERE room_id = ? AND event_id = ? AND is_state = ? "
+ "WHERE event_id = ? AND is_state = ? "
"LIMIT ?"
)
@@ -412,7 +411,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
for event_id in front:
txn.execute(
query,
- (room_id, event_id, False, limit - len(event_results))
+ (event_id, False, limit - len(event_results))
)
for e_id, in txn:
@@ -448,7 +447,7 @@ class EventFederationStore(EventFederationWorkerStore):
)
hs.get_clock().looping_call(
- self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+ self._delete_old_forward_extrem_cache, 60 * 60 * 1000,
)
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
@@ -550,9 +549,11 @@ class EventFederationStore(EventFederationWorkerStore):
sql,
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
)
- return self.runInteraction(
+ return run_as_background_process(
+ "delete_old_forward_extrem_cache",
+ self.runInteraction,
"_delete_old_forward_extrem_cache",
- _delete_old_forward_extrem_cache_txn
+ _delete_old_forward_extrem_cache_txn,
)
def clean_room_for_join(self, room_id):
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index d0350ee5fe..6840320641 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -14,16 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage._base import SQLBaseStore, LoggingTransaction
-from twisted.internet import defer
-from synapse.util.async import sleep
-from synapse.util.caches.descriptors import cachedInlineCallbacks
-
import logging
-import simplejson as json
from six import iteritems
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+
logger = logging.getLogger(__name__)
@@ -84,6 +86,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self.find_stream_orderings_looping_call = self._clock.looping_call(
self._find_stream_orderings_for_times, 10 * 60 * 1000
)
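+ # How long to pause between notification-rotation batches (in seconds),
+ # and how many stream orderings to rotate per batch; both are used by
+ # the notification rotation below.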
+ self._rotate_delay = 3
+ self._rotate_count = 10000
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
@@ -455,11 +459,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Error removing push actions after event persistence failure",
)
- @defer.inlineCallbacks
def _find_stream_orderings_for_times(self):
- yield self.runInteraction(
+ return run_as_background_process(
+ "event_push_action_stream_orderings",
+ self.runInteraction,
"_find_stream_orderings_for_times",
- self._find_stream_orderings_for_times_txn
+ self._find_stream_orderings_for_times_txn,
)
def _find_stream_orderings_for_times_txn(self, txn):
@@ -601,7 +606,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
self._doing_notif_rotation = False
self._rotate_notif_loop = self._clock.looping_call(
- self._rotate_notifs, 30 * 60 * 1000
+ self._start_rotate_notifs, 30 * 60 * 1000,
)
def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
@@ -784,6 +789,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
""", (room_id, user_id, stream_ordering))
+ def _start_rotate_notifs(self):
+ return run_as_background_process("rotate_notifs", self._rotate_notifs)
+
@defer.inlineCallbacks
def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
@@ -800,7 +808,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
if caught_up:
break
- yield sleep(5)
+ yield self.hs.get_clock().sleep(self._rotate_delay)
finally:
self._doing_notif_rotation = False
@@ -821,8 +829,8 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute("""
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ?
- ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
- """, (old_rotate_stream_ordering,))
+ ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
+ """, (old_rotate_stream_ordering, self._rotate_count))
stream_row = txn.fetchone()
if stream_row:
offset_stream_ordering, = stream_row
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index cb1082e864..8bf87f38f7 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -14,36 +14,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import OrderedDict, deque, namedtuple
-from functools import wraps
import itertools
import logging
+from collections import OrderedDict, deque, namedtuple
+from functools import wraps
+
+from six import iteritems, text_type
+from six.moves import range
+
+from canonicaljson import json
+from prometheus_client import Counter
-import simplejson as json
from twisted.internet import defer
+import synapse.metrics
+from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
+from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
-from synapse.util.async import ObservableDeferred
+from synapse.types import RoomStreamToken, get_domain_from_id
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.logcontext import (
- PreserveLoggingContext, make_deferred_yieldable,
-)
+from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
-from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.types import get_domain_from_id, RoomStreamToken
-import synapse.metrics
-
-# these are only included to make the type annotations work
-from synapse.events import EventBase # noqa: F401
-from synapse.events.snapshot import EventContext # noqa: F401
-
-from six.moves import range
-from six import itervalues, iteritems
-
-from prometheus_client import Counter
logger = logging.getLogger(__name__)
@@ -67,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
def encode_json(json_object):
- return frozendict_json_encoder.encode(json_object)
+ """
+ Encode a Python object as JSON and return it in a Unicode string.
+ """
+ out = frozendict_json_encoder.encode(json_object)
+ if isinstance(out, bytes):
+ out = out.decode('utf8')
+ return out
class _EventPeristenceQueue(object):
@@ -144,25 +150,22 @@ class _EventPeristenceQueue(object):
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
- # handle_queue_loop runs in the sentinel logcontext, so
- # there is no need to preserve_fn when running the
- # callbacks on the deferred.
try:
ret = yield per_item_callback(item)
- item.deferred.callback(ret)
except Exception:
- item.deferred.errback()
+ with PreserveLoggingContext():
+ item.deferred.errback()
+ else:
+ with PreserveLoggingContext():
+ item.deferred.callback(ret)
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id)
- # set handle_queue_loop off on the background. We don't want to
- # attribute work done in it to the current request, so we drop the
- # logcontext altogether.
- with PreserveLoggingContext():
- handle_queue_loop()
+ # set handle_queue_loop off in the background
+ run_as_background_process("persist_events", handle_queue_loop)
def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -198,7 +201,9 @@ def _retry_on_integrity_error(func):
return f
-class EventsStore(EventsWorkerStore):
+# inherits from EventFederationStore so that we can call _update_backward_extremities
+# and _handle_mult_prev_events (though arguably those could both be moved in here)
+class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
@@ -236,12 +241,18 @@ class EventsStore(EventsWorkerStore):
self._state_resolution_handler = hs.get_state_resolution_handler()
+ @defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False):
"""
Write events to the database
Args:
events_and_contexts: list of tuples of (event, context)
- backfilled: ?
+ backfilled (bool): Whether the results are retrieved from federation
+ via backfill or not. Used to determine if they're "new" events
+ which might update the current state etc.
+
+ Returns:
+ Deferred[int]: the stream ordering of the latest persisted event
"""
partitioned = {}
for event, ctx in events_and_contexts:
@@ -258,10 +269,14 @@ class EventsStore(EventsWorkerStore):
for room_id in partitioned:
self._maybe_start_persisting(room_id)
- return make_deferred_yieldable(
+ yield make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
+ max_persisted_id = yield self._stream_id_gen.get_current_token()
+
+ defer.returnValue(max_persisted_id)
+
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, backfilled=False):
@@ -348,11 +363,14 @@ class EventsStore(EventsWorkerStore):
new_forward_extremeties = {}
# map room_id->(type,state_key)->event_id tracking the full
- # state in each room after adding these events
+ # state in each room after adding these events.
+ # This is simply used to prefill the get_current_state_ids
+ # cache
current_state_for_room = {}
- # map room_id->(to_delete, to_insert) where each entry is
- # a map (type,key)->event_id giving the state delta in each
+ # map room_id->(to_delete, to_insert) where to_delete is a list
+ # of type/state keys to remove from current state, and to_insert
+ # is a map (type,key)->event_id giving the state delta in each
# room
state_delta_for_room = {}
@@ -422,19 +440,40 @@ class EventsStore(EventsWorkerStore):
logger.info(
"Calculating state delta for room %s", room_id,
)
- current_state = yield self._get_new_state_after_events(
- room_id,
- ev_ctx_rm,
- latest_event_ids,
- new_latest_event_ids,
- )
+ with Measure(
+ self._clock,
+ "persist_events.get_new_state_after_events",
+ ):
+ res = yield self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ current_state, delta_ids = res
+
+ # If either is not None then there has been a change,
+ # and we need to work out the delta (or use the
+ # one given)
+ if delta_ids is not None:
+ # If there is a delta we know that we've
+ # only added or replaced state, never
+ # removed keys entirely.
+ state_delta_for_room[room_id] = ([], delta_ids)
+ elif current_state is not None:
+ with Measure(
+ self._clock,
+ "persist_events.calculate_state_delta",
+ ):
+ delta = yield self._calculate_state_delta(
+ room_id, current_state,
+ )
+ state_delta_for_room[room_id] = delta
+
+ # If we have the current_state then let's prefill
+ # the cache with it.
if current_state is not None:
current_state_for_room[room_id] = current_state
- delta = yield self._calculate_state_delta(
- room_id, current_state,
- )
- if delta is not None:
- state_delta_for_room[room_id] = delta
yield self.runInteraction(
"persist_events",
@@ -446,9 +485,14 @@ class EventsStore(EventsWorkerStore):
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(chunk))
- synapse.metrics.event_persisted_position.set(
- chunk[-1][0].internal_metadata.stream_ordering,
- )
+
+ if not backfilled:
+ # backfilled events have negative stream orderings, so we don't
+ # want to set the event_persisted_position to that.
+ synapse.metrics.event_persisted_position.set(
+ chunk[-1][0].internal_metadata.stream_ordering,
+ )
+
for event, context in chunk:
if context.app_service:
origin_type = "local"
@@ -501,7 +545,6 @@ class EventsStore(EventsWorkerStore):
iterable=list(new_latest_event_ids),
retcols=["prev_event_id"],
keyvalues={
- "room_id": room_id,
"is_state": False,
},
desc="_calculate_new_extremeties",
@@ -533,9 +576,15 @@ class EventsStore(EventsWorkerStore):
the new forward extremities for the room.
Returns:
- Deferred[dict[(str,str), str]|None]:
- None if there are no changes to the room state, or
- a dict of (type, state_key) -> event_id].
+ Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
+ Returns a tuple of two state maps, the first being the full new current
+ state and the second being the delta to the existing current state.
+ If both are None then there has been no change.
+
+ If there has been a change then we only return the delta if it's
+ already been calculated. Conversely if we do know the delta then
+ the new current state is only returned if we've already calculated
+ it.
"""
if not new_latest_event_ids:
@@ -543,18 +592,32 @@ class EventsStore(EventsWorkerStore):
# map from state_group to ((type, key) -> event_id) state map
state_groups_map = {}
+
+ # Map from (prev state group, new state group) -> delta state dict
+ state_group_deltas = {}
+
for ev, ctx in events_context:
if ctx.state_group is None:
- # I don't think this can happen, but let's double-check
- raise Exception(
- "Context for new extremity event %s has no state "
- "group" % (ev.event_id, ),
- )
+ # This should only happen for outlier events.
+ if not ev.internal_metadata.is_outlier():
+ raise Exception(
+ "Context for new event %s has no state "
+ "group" % (ev.event_id, ),
+ )
+ continue
if ctx.state_group in state_groups_map:
continue
- state_groups_map[ctx.state_group] = ctx.current_state_ids
+ # We're only interested in pulling out state that has already
+ # been cached in the context. We'll pull stuff out of the DB later
+ # if necessary.
+ current_state_ids = ctx.get_cached_current_state_ids()
+ if current_state_ids is not None:
+ state_groups_map[ctx.state_group] = current_state_ids
+
+ if ctx.prev_group:
+ state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
# We need to map the event_ids to their state groups. First, let's
# check if the event is one we're persisting, in which case we can
@@ -569,7 +632,7 @@ class EventsStore(EventsWorkerStore):
for event_id in new_latest_event_ids:
# First search in the list of new events we're adding.
for ev, ctx in events_context:
- if event_id == ev.event_id:
+ if event_id == ev.event_id and ctx.state_group is not None:
event_id_to_state_group[event_id] = ctx.state_group
break
else:
@@ -597,7 +660,26 @@ class EventsStore(EventsWorkerStore):
# If the old and new groups are the same then we don't need to do
# anything.
if old_state_groups == new_state_groups:
- return
+ defer.returnValue((None, None))
+
+ if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+ # If we're going from one state group to another, let's check if
+ # we have a delta for that transition. If we do then we can just
+ # return that.
+
+ new_state_group = next(iter(new_state_groups))
+ old_state_group = next(iter(old_state_groups))
+
+ delta_ids = state_group_deltas.get(
+ (old_state_group, new_state_group,), None
+ )
+ if delta_ids is not None:
+ # We have a delta from the existing to new current state,
+ # so let's just return that. If we happen to already have
+ # the current state in memory then let's also return that,
+ # but it doesn't matter if we don't.
+ new_state = state_groups_map.get(new_state_group)
+ defer.returnValue((new_state, delta_ids))
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
@@ -609,7 +691,7 @@ class EventsStore(EventsWorkerStore):
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- defer.returnValue(state_groups_map[new_state_groups.pop()])
+ defer.returnValue((state_groups_map[new_state_groups.pop()], None))
# Ok, we need to defer to the state handler to resolve our state sets.
@@ -623,12 +705,14 @@ class EventsStore(EventsWorkerStore):
}
events_map = {ev.event_id: ev for ev, _ in events_context}
+ room_version = yield self.get_room_version(room_id)
+
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
- room_id, state_groups, events_map, get_events
+ room_id, room_version, state_groups, events_map, get_events
)
- defer.returnValue(res.state)
+ defer.returnValue((res.state, None))
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, current_state):
@@ -637,28 +721,20 @@ class EventsStore(EventsWorkerStore):
Assumes that we are only persisting events for one room at a time.
Returns:
- 2-tuple (to_delete, to_insert) where both are state dicts,
- i.e. (type, state_key) -> event_id. `to_delete` are the entries to
- first be deleted from current_state_events, `to_insert` are entries
- to insert.
+ tuple[list, dict] (to_delete, to_insert): where to_delete are the
+ type/state_keys to remove from current_state_events and `to_insert`
+ are the updates to current_state_events.
"""
existing_state = yield self.get_current_state_ids(room_id)
- existing_events = set(itervalues(existing_state))
- new_events = set(ev_id for ev_id in itervalues(current_state))
- changed_events = existing_events ^ new_events
-
- if not changed_events:
- return
+ to_delete = [
+ key for key in existing_state
+ if key not in current_state
+ ]
- to_delete = {
- key: ev_id for key, ev_id in iteritems(existing_state)
- if ev_id in changed_events
- }
- events_to_insert = (new_events - existing_events)
to_insert = {
key: ev_id for key, ev_id in iteritems(current_state)
- if ev_id in events_to_insert
+ if ev_id != existing_state.get(key)
}
defer.returnValue((to_delete, to_insert))
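
A toy illustration of the (to_delete, to_insert) shapes returned above, using made-up state maps (not Synapse code):

existing_state = {
    ("m.room.member", "@alice:hs"): "$event1",
    ("m.room.topic", ""): "$event2",
}
current_state = {
    ("m.room.member", "@alice:hs"): "$event3",  # replaced
    # ("m.room.topic", "") has been removed entirely
}

to_delete = [key for key in existing_state if key not in current_state]
to_insert = {
    key: ev_id for key, ev_id in current_state.items()
    if ev_id != existing_state.get(key)
}

assert to_delete == [("m.room.topic", "")]
assert to_insert == {("m.room.member", "@alice:hs"): "$event3"}
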
@@ -681,10 +757,10 @@ class EventsStore(EventsWorkerStore):
delete_existing (bool): True to purge existing table rows for the
events from the database. This is useful when retrying due to
IntegrityError.
- state_delta_for_room (dict[str, (list[str], list[str])]):
+ state_delta_for_room (dict[str, (list, dict)]):
The current-state delta for each room. For each room, a tuple
- (to_delete, to_insert), being a list of event ids to be removed
- from the current state, and a list of event ids to be added to
+ (to_delete, to_insert), being a list of type/state keys to be
+ removed from the current state, and a state set to be added to
the current state.
new_forward_extremeties (dict[str, list[str]]):
The new forward extremities for each room. For each room, a
@@ -762,9 +838,46 @@ class EventsStore(EventsWorkerStore):
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in iteritems(state_delta_by_room):
to_delete, to_insert = current_state_tuple
+
+ # First we add entries to the current_state_delta_stream. We
+ # do this before updating the current_state_events table so
+ # that we can use it to calculate the `prev_event_id`. (This
+ # allows us to not have to pull out the existing state
+ # unnecessarily).
+ sql = """
+ INSERT INTO current_state_delta_stream
+ (stream_id, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, ?, ?, ?, (
+ SELECT event_id FROM current_state_events
+ WHERE room_id = ? AND type = ? AND state_key = ?
+ )
+ """
+ txn.executemany(sql, (
+ (
+ max_stream_order, room_id, etype, state_key, None,
+ room_id, etype, state_key,
+ )
+ for etype, state_key in to_delete
+ # We sanity check that we're deleting rather than updating
+ if (etype, state_key) not in to_insert
+ ))
+ txn.executemany(sql, (
+ (
+ max_stream_order, room_id, etype, state_key, ev_id,
+ room_id, etype, state_key,
+ )
+ for (etype, state_key), ev_id in iteritems(to_insert)
+ ))
+
+ # Now we actually update the current_state_events table
+
txn.executemany(
- "DELETE FROM current_state_events WHERE event_id = ?",
- [(ev_id,) for ev_id in itervalues(to_delete)],
+ "DELETE FROM current_state_events"
+ " WHERE room_id = ? AND type = ? AND state_key = ?",
+ (
+ (room_id, etype, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
)
self._simple_insert_many_txn(
@@ -781,26 +894,8 @@ class EventsStore(EventsWorkerStore):
],
)
- state_deltas = {key: None for key in to_delete}
- state_deltas.update(to_insert)
-
- self._simple_insert_many_txn(
- txn,
- table="current_state_delta_stream",
- values=[
- {
- "stream_id": max_stream_order,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": ev_id,
- "prev_event_id": to_delete.get(key, None),
- }
- for key, ev_id in iteritems(state_deltas)
- ]
- )
-
- self._curr_state_delta_stream_cache.entity_has_changed(
+ txn.call_after(
+ self._curr_state_delta_stream_cache.entity_has_changed,
room_id, max_stream_order,
)
@@ -812,7 +907,8 @@ class EventsStore(EventsWorkerStore):
# and which we have added, then we invalidate the caches for all
# those users.
members_changed = set(
- state_key for ev_type, state_key in state_deltas
+ state_key
+ for ev_type, state_key in itertools.chain(to_delete, to_insert)
if ev_type == EventTypes.Member
)
@@ -985,7 +1081,7 @@ class EventsStore(EventsWorkerStore):
metadata_json = encode_json(
event.internal_metadata.get_dict()
- ).decode("UTF-8")
+ )
sql = (
"UPDATE event_json SET internal_metadata = ?"
@@ -1044,7 +1140,6 @@ class EventsStore(EventsWorkerStore):
"event_edge_hashes",
"event_edges",
"event_forward_extremities",
- "event_push_actions",
"event_reference_hashes",
"event_search",
"event_signatures",
@@ -1064,6 +1159,14 @@ class EventsStore(EventsWorkerStore):
[(ev.event_id,) for ev, _ in events_and_contexts]
)
+ for table in (
+ "event_push_actions",
+ ):
+ txn.executemany(
+ "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
+ [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts]
+ )
+
def _store_event_txn(self, txn, events_and_contexts):
"""Insert new events into the event and event_json tables
@@ -1092,8 +1195,8 @@ class EventsStore(EventsWorkerStore):
"room_id": event.room_id,
"internal_metadata": encode_json(
event.internal_metadata.get_dict()
- ).decode("UTF-8"),
- "json": encode_json(event_dict(event)).decode("UTF-8"),
+ ),
+ "json": encode_json(event_dict(event)),
}
for event, _ in events_and_contexts
],
@@ -1112,13 +1215,12 @@ class EventsStore(EventsWorkerStore):
"type": event.type,
"processed": True,
"outlier": event.internal_metadata.is_outlier(),
- "content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts),
"received_ts": self._clock.time_msec(),
"sender": event.sender,
"contains_url": (
"url" in event.content
- and isinstance(event.content["url"], basestring)
+ and isinstance(event.content["url"], text_type)
),
}
for event, _ in events_and_contexts
@@ -1336,88 +1438,6 @@ class EventsStore(EventsWorkerStore):
)
@defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
- """Given a list of event ids, check if we have already processed and
- stored them as non outliers.
- """
- rows = yield self._simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
- )
-
- defer.returnValue(set(r["event_id"] for r in rows))
-
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
- """Given a list of event ids, check if we have already processed them.
-
- Args:
- event_ids (iterable[str]):
-
- Returns:
- Deferred[set[str]]: The events we have already seen.
- """
- results = set()
-
- def have_seen_events_txn(txn, chunk):
- sql = (
- "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
- % (",".join("?" * len(chunk)), )
- )
- txn.execute(sql, chunk)
- for (event_id, ) in txn:
- results.add(event_id)
-
- # break the input up into chunks of 100
- input_iterator = iter(event_ids)
- for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
- []):
- yield self.runInteraction(
- "have_seen_events",
- have_seen_events_txn,
- chunk,
- )
- defer.returnValue(results)
-
- def get_seen_events_with_rejections(self, event_ids):
- """Given a list of event ids, check if we rejected them.
-
- Args:
- event_ids (list[str])
-
- Returns:
- Deferred[dict[str, str|None):
- Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps
- to None.
- """
- if not event_ids:
- return defer.succeed({})
-
- def f(txn):
- sql = (
- "SELECT e.event_id, reason FROM events as e "
- "LEFT JOIN rejections as r ON e.event_id = r.event_id "
- "WHERE e.event_id = ?"
- )
-
- res = {}
- for event_id in event_ids:
- txn.execute(sql, (event_id,))
- row = txn.fetchone()
- if row:
- _, rejected = row
- res[event_id] = rejected
-
- return res
-
- return self.runInteraction("get_rejection_reasons", f)
-
- @defer.inlineCallbacks
def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
@@ -1509,7 +1529,7 @@ class EventsStore(EventsWorkerStore):
contains_url = "url" in content
if contains_url:
- contains_url &= isinstance(content["url"], basestring)
+ contains_url &= isinstance(content["url"], text_type)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -1890,10 +1910,10 @@ class EventsStore(EventsWorkerStore):
(room_id,)
)
rows = txn.fetchall()
- max_depth = max(row[0] for row in rows)
+ max_depth = max(row[1] for row in rows)
- if max_depth <= token.topological:
- # We need to ensure we don't delete all the events from the datanase
+ if max_depth < token.topological:
+ # We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
raise SynapseError(
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 32d9d00ffb..a8326f5296 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -12,27 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from twisted.internet import defer, reactor
+import itertools
+import logging
+from collections import namedtuple
+
+from canonicaljson import json
+from twisted.internet import defer
+
+from synapse.api.errors import NotFoundError
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
-
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.logcontext import (
- PreserveLoggingContext, make_deferred_yieldable, run_in_background,
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+ run_in_background,
)
from synapse.util.metrics import Measure
-from synapse.api.errors import SynapseError
-from collections import namedtuple
-
-import logging
-import simplejson as json
-
-# these are only included to make the type annotations work
-from synapse.events import EventBase # noqa: F401
-from synapse.events.snapshot import EventContext # noqa: F401
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -75,7 +79,7 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
- allow_none=False):
+ allow_none=False, check_room_id=None):
"""Get an event from the database by event_id.
Args:
@@ -86,7 +90,9 @@ class EventsWorkerStore(SQLBaseStore):
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
- False throw an exception.
+ False throw a NotFoundError
+ check_room_id (str|None): if not None, check the room of the found event.
+ If there is a mismatch, behave as per allow_none.
Returns:
Deferred : A FrozenEvent.
@@ -98,10 +104,16 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected=allow_rejected,
)
- if not events and not allow_none:
- raise SynapseError(404, "Could not find event %s" % (event_id,))
+ event = events[0] if events else None
+
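+        # if the caller asked us to check the room, treat an event from the
+        # wrong room as if it had not been found at all (as per allow_none)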
+ if event is not None and check_room_id is not None:
+ if event.room_id != check_room_id:
+ event = None
- defer.returnValue(events[0] if events else None)
+ if event is None and not allow_none:
+ raise NotFoundError("Could not find event %s" % (event_id,))
+
+ defer.returnValue(event)
@defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True,
@@ -145,6 +157,9 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
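+            # note on the current log context how many events we are about to
+            # load from the database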
+ log_ctx = LoggingContext.current_context()
+ log_ctx.record_event_fetch(len(missing_events_ids))
+
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
@@ -218,32 +233,47 @@ class EventsWorkerStore(SQLBaseStore):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- event_list = []
i = 0
while True:
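+            # take the current batch of fetch requests off the shared list,
+            # leaving an empty list for new requests to accumulate in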
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
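+            # if there is nothing to fetch, either stop this fetcher (when the
+            # database is single-threaded or we have been idle for too long),
+            # returning the connection to the pool, or wait for more requests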
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ self._event_fetch_ongoing -= 1
+ return
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ self._fetch_event_list(conn, event_list)
+
+ def _fetch_event_list(self, conn, event_list):
+ """Handle a load of requests from the _event_fetch_list queue
+
+ Args:
+ conn (twisted.enterprise.adbapi.Connection): database connection
+
+ event_list (list[Tuple[list[str], Deferred]]):
+ The fetch requests. Each entry consists of a list of event
+ ids to be fetched, and a deferred to be completed once the
+ events have been fetched.
+
+ """
+ with Measure(self._clock, "_fetch_event_list"):
try:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- event_id_lists = zip(*event_list)[0]
+ event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
rows = self._new_transaction(
- conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
+ conn, "do_fetch", [], [],
+ self._fetch_event_rows, event_ids,
)
row_dict = {
@@ -265,20 +295,19 @@ class EventsWorkerStore(SQLBaseStore):
except Exception:
logger.exception("Failed to callback")
with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list, row_dict)
+ self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs):
+ def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
- d.errback(e)
+ d.errback(exc)
- if event_list:
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list)
+ with PreserveLoggingContext():
+ self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
@@ -304,10 +333,11 @@ class EventsWorkerStore(SQLBaseStore):
should_start = False
if should_start:
- with PreserveLoggingContext():
- self.runWithConnection(
- self._do_fetch
- )
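+            # run the fetch loop as a background process so that it gets its
+            # own logcontext and shows up in the background process metrics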
+ run_as_background_process(
+ "fetch_events",
+ self.runWithConnection,
+ self._do_fetch,
+ )
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
@@ -414,3 +444,85 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)
+
+ @defer.inlineCallbacks
+ def have_events_in_timeline(self, event_ids):
+ """Given a list of event ids, check if we have already processed and
+ stored them as non outliers.
+ """
+ rows = yield self._simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ )
+
+ defer.returnValue(set(r["event_id"] for r in rows))
+
+ @defer.inlineCallbacks
+ def have_seen_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Args:
+ event_ids (iterable[str]):
+
+ Returns:
+ Deferred[set[str]]: The events we have already seen.
+ """
+ results = set()
+
+ def have_seen_events_txn(txn, chunk):
+ sql = (
+ "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+ % (",".join("?" * len(chunk)), )
+ )
+ txn.execute(sql, chunk)
+ for (event_id, ) in txn:
+ results.add(event_id)
+
+ # break the input up into chunks of 100
+ input_iterator = iter(event_ids)
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+ []):
+ yield self.runInteraction(
+ "have_seen_events",
+ have_seen_events_txn,
+ chunk,
+ )
+ defer.returnValue(results)
+
+ def get_seen_events_with_rejections(self, event_ids):
+ """Given a list of event ids, check if we rejected them.
+
+ Args:
+ event_ids (list[str])
+
+ Returns:
+            Deferred[dict[str, str|None]]:
+ Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps
+ to None.
+ """
+ if not event_ids:
+ return defer.succeed({})
+
+ def f(txn):
+ sql = (
+ "SELECT e.event_id, reason FROM events as e "
+ "LEFT JOIN rejections as r ON e.event_id = r.event_id "
+ "WHERE e.event_id = ?"
+ )
+
+ res = {}
+ for event_id in event_ids:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ _, rejected = row
+ res[event_id] = rejected
+
+ return res
+
+ return self.runInteraction("get_rejection_reasons", f)
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 2e2763126d..6ddcc909bf 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from canonicaljson import encode_canonical_json
+
from twisted.internet import defer
-from ._base import SQLBaseStore
-from synapse.api.errors import SynapseError, Codes
+from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from canonicaljson import encode_canonical_json
-import simplejson as json
+from ._base import SQLBaseStore, db_to_json
class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
+ defer.returnValue(db_to_json(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index da05ccb027..592d1b4c2a 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -14,15 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.api.errors import SynapseError
from ._base import SQLBaseStore
-import simplejson as json
-
-
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
_DEFAULT_CATEGORY_ID = ""
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 0f13b61da8..f547977600 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,17 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+import hashlib
+import logging
-from twisted.internet import defer
import six
-import OpenSSL
from signedjson.key import decode_verify_key_bytes
-import hashlib
-import logging
+import OpenSSL
+from twisted.internet import defer
+
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
new file mode 100644
index 0000000000..c7899d7fd2
--- /dev/null
+++ b/synapse/storage/monthly_active_users.py
@@ -0,0 +1,221 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.util.caches.descriptors import cached
+
+from ._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+# Number of msec of granularity to store the monthly_active_user timestamp
+# This means it is not necessary to update the table on every request
+LAST_SEEN_GRANULARITY = 60 * 60 * 1000
+
+
+class MonthlyActiveUsersStore(SQLBaseStore):
+ def __init__(self, dbconn, hs):
+ super(MonthlyActiveUsersStore, self).__init__(None, hs)
+ self._clock = hs.get_clock()
+ self.hs = hs
+ self.reserved_users = ()
+
+ @defer.inlineCallbacks
+ def initialise_reserved_users(self, threepids):
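+        # mark each reserved threepid's user as monthly active, and remember
+        # them so that reap_monthly_active_users never removes them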
+ store = self.hs.get_datastore()
+ reserved_user_list = []
+
+ # Do not add more reserved users than the total allowable number
+ for tp in threepids[:self.hs.config.max_mau_value]:
+ user_id = yield store.get_user_id_by_threepid(
+ tp["medium"], tp["address"]
+ )
+ if user_id:
+ yield self.upsert_monthly_active_user(user_id)
+ reserved_user_list.append(user_id)
+ else:
+ logger.warning(
+ "mau limit reserved threepid %s not found in db" % tp
+ )
+ self.reserved_users = tuple(reserved_user_list)
+
+ @defer.inlineCallbacks
+ def reap_monthly_active_users(self):
+ """
+ Cleans out monthly active user table to ensure that no stale
+ entries exist.
+
+ Returns:
+ Deferred[]
+ """
+ def _reap_users(txn):
+ # Purge stale users
+
+ thirty_days_ago = (
+ int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ )
+ query_args = [thirty_days_ago]
+ base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
+
+ # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
+ # when len(reserved_users) == 0. Works fine on sqlite.
+ if len(self.reserved_users) > 0:
+ # questionmarks is a hack to overcome sqlite not supporting
+ # tuples in 'WHERE IN %s'
+ questionmarks = '?' * len(self.reserved_users)
+
+ query_args.extend(self.reserved_users)
+ sql = base_sql + """ AND user_id NOT IN ({})""".format(
+ ','.join(questionmarks)
+ )
+ else:
+ sql = base_sql
+
+ txn.execute(sql, query_args)
+
+ # If MAU user count still exceeds the MAU threshold, then delete on
+ # a least recently active basis.
+ # Note it is not possible to write this query using OFFSET due to
+ # incompatibilities in how sqlite and postgres support the feature.
+        # sqlite requires 'LIMIT -1 OFFSET ?' (the LIMIT must be present),
+        # while Postgres does not require 'LIMIT' but does not support
+        # negative LIMIT values either. So there is no way to write the
+        # query such that both engines support it.
+ safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
+ # Must be greater than zero for postgres
+ safe_guard = safe_guard if safe_guard > 0 else 0
+ query_args = [safe_guard]
+
+ base_sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ ORDER BY timestamp DESC
+ LIMIT ?
+ )
+ """
+ # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
+ # when len(reserved_users) == 0. Works fine on sqlite.
+ if len(self.reserved_users) > 0:
+ query_args.extend(self.reserved_users)
+ sql = base_sql + """ AND user_id NOT IN ({})""".format(
+ ','.join(questionmarks)
+ )
+ else:
+ sql = base_sql
+ txn.execute(sql, query_args)
+
+ yield self.runInteraction("reap_monthly_active_users", _reap_users)
+        # Invalidating the whole cache is rather crude: Postgres supports
+        # 'RETURNING', which would let us invalidate only the specific
+        # users, but sqlite has no equivalent, and doing a SELECT followed
+        # by a DELETE is racy without locking.
+        # We have resolved to invalidate the whole cache for now, and to do
+        # something about it if and when the perf becomes significant.
+ self.user_last_seen_monthly_active.invalidate_all()
+ self.get_monthly_active_count.invalidate_all()
+
+ @cached(num_args=0)
+ def get_monthly_active_count(self):
+ """Generates current count of monthly active users
+
+ Returns:
+            Deferred[int]: Number of current monthly active users
+ """
+
+ def _count_users(txn):
+ sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+
+ txn.execute(sql)
+ count, = txn.fetchone()
+ return count
+ return self.runInteraction("count_users", _count_users)
+
+ @defer.inlineCallbacks
+ def upsert_monthly_active_user(self, user_id):
+ """
+        Updates or inserts a monthly active user entry
+
+        Args:
+            user_id (str): user to add/update
+
+        Returns:
+            Deferred[bool]: True if a new entry was created, False if an
+                existing one was updated.
+ """
+ is_insert = yield self._simple_upsert(
+ desc="upsert_monthly_active_user",
+ table="monthly_active_users",
+ keyvalues={
+ "user_id": user_id,
+ },
+ values={
+ "timestamp": int(self._clock.time_msec()),
+ },
+ lock=False,
+ )
+ if is_insert:
+ self.user_last_seen_monthly_active.invalidate((user_id,))
+ self.get_monthly_active_count.invalidate(())
+
+ @cached(num_args=1)
+ def user_last_seen_monthly_active(self, user_id):
+ """
+        Checks if a given user is part of the monthly active user group
+
+        Args:
+            user_id (str): user to check
+
+        Returns:
+            Deferred[int|None]: timestamp of the last time the user was seen,
+                or None if never seen
+ """
+
+        return self._simple_select_one_onecol(
+ table="monthly_active_users",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcol="timestamp",
+ allow_none=True,
+ desc="user_last_seen_monthly_active",
+        )
+
+ @defer.inlineCallbacks
+ def populate_monthly_active_users(self, user_id):
+ """Checks on the state of monthly active user limits and optionally
+        adds the user to the monthly active users table
+
+ Args:
+ user_id(str): the user_id to query
+ """
+ if self.hs.config.limit_usage_by_mau:
+ is_trial = yield self.is_trial_user(user_id)
+ if is_trial:
+ # we don't track trial users in the MAU table.
+ return
+
+ last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
+ now = self.hs.get_clock().time_msec()
+
+        # We want to reduce the total number of db writes, and are happy
+        # to trade accuracy of timestamp in order to lighten load. This means
+        # we always insert new users (where the MAU threshold has not been
+        # reached), but only update existing users if we have not seen them
+        # for LAST_SEEN_GRANULARITY ms.
+ if last_seen_timestamp is None:
+ count = yield self.get_monthly_active_count()
+ if count < self.hs.config.max_mau_value:
+ yield self.upsert_monthly_active_user(user_id)
+ elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
+ yield self.upsert_monthly_active_user(user_id)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index cf2aae0468..b364719312 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -20,13 +20,12 @@ import logging
import os
import re
-
logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 50
+SCHEMA_VERSION = 51
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index f05d91cc58..a0c7a0dc87 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from collections import namedtuple
+
+from twisted.internet import defer
+
from synapse.api.constants import PresenceState
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util import batch_iter
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from collections import namedtuple
-from twisted.internet import defer
+from ._base import SQLBaseStore
class UserPresenceState(namedtuple("UserPresenceState",
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 8612bd5ecc..88b50f33b5 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -15,8 +15,8 @@
from twisted.internet import defer
-from synapse.storage.roommember import ProfileInfo
from synapse.api.errors import StoreError
+from synapse.storage.roommember import ProfileInfo
from ._base import SQLBaseStore
@@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
-
-class ProfileStore(ProfileWorkerStore):
def create_profile(self, user_localpart):
return self._simple_insert(
table="profiles",
@@ -96,6 +94,8 @@ class ProfileStore(ProfileWorkerStore):
desc="set_profile_avatar_url",
)
+
+class ProfileStore(ProfileWorkerStore):
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 04a0b59a39..6a5028961d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -14,20 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+import abc
+import logging
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.push.baserules import list_with_base_rules
from synapse.storage.appservice import ApplicationServiceWorkerStore
from synapse.storage.pusher import PusherWorkerStore
from synapse.storage.receipts import ReceiptsWorkerStore
from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.push.baserules import list_with_base_rules
-from synapse.api.constants import EventTypes
-from twisted.internet import defer
-import abc
-import logging
-import simplejson as json
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -183,6 +185,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
defer.returnValue(results)
+ @defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group:
@@ -192,9 +195,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- return self._bulk_get_push_rules_for_room(
- event.room_id, state_group, context.current_state_ids, event=event
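+        # the state is no longer eagerly attached to the context: fetching it
+        # is asynchronous and may hit the database, hence the yield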
+ current_state_ids = yield context.get_current_state_ids(self)
+ result = yield self._bulk_get_push_rules_for_room(
+ event.room_id, state_group, current_state_ids, event=event
)
+ defer.returnValue(result)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
@@ -244,18 +249,6 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
if uid in local_users_in_room:
user_ids.add(uid)
- forgotten = yield self.who_forgot_in_room(
- event.room_id, on_invalidate=cache_context.invalidate,
- )
-
- for row in forgotten:
- user_id = row["user_id"]
- event_id = row["event_id"]
-
- mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
- if event_id == mem_id:
- user_ids.discard(user_id)
-
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 307660b99a..c7987bfcdd 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -14,19 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from twisted.internet import defer
+import logging
-from canonicaljson import encode_canonical_json
+import six
+
+from canonicaljson import encode_canonical_json, json
+
+from twisted.internet import defer
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-import logging
-import simplejson as json
-import types
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
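+# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
+# despite being deprecated and removed in favor of memoryview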
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
+
class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
@@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
dataJson = r['data']
r['data'] = None
try:
- if isinstance(dataJson, types.BufferType):
+ if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8")
r['data'] = json.loads(dataJson)
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
- r['id'], dataJson, e.message,
+ r['id'], dataJson, e.args[0],
)
pass
- if isinstance(r['pushkey'], types.BufferType):
+ if isinstance(r['pushkey'], db_binary_type):
r['pushkey'] = str(r['pushkey']).decode("UTF8")
return rows
@@ -233,7 +239,7 @@ class PusherStore(PusherWorkerStore):
)
if newly_inserted:
- self.runInteraction(
+ yield self.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher, (user_id,)
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c93c228f6e..0ac665e967 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -14,17 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from .util.id_generators import StreamIdGenerator
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+import abc
+import logging
+
+from canonicaljson import json
from twisted.internet import defer
-import abc
-import logging
-import simplejson as json
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
logger = logging.getLogger(__name__)
@@ -139,7 +140,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
room_ids = set(room_ids)
- if from_key:
+ if from_key is not None:
+ # Only ask the database about rooms where there have been new
+ # receipts added since `from_key`
room_ids = yield self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
@@ -150,7 +153,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
- @cachedInlineCallbacks(num_args=3, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
@@ -161,7 +163,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
from the start.
Returns:
- list: A list of receipts.
+ Deferred[list]: A list of receipts.
+ """
+ if from_key is not None:
+ # Check the cache first to see if any new receipts have been added
+            # since `from_key`. If not, we can no-op.
+            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+                return defer.succeed([])
+
+ return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+
+ @cachedInlineCallbacks(num_args=3, tree=True)
+ def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ """See get_linearized_receipts_for_room
"""
def f(txn):
if from_key:
@@ -210,7 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"content": content,
}])
- @cachedList(cached_method_name="get_linearized_receipts_for_room",
+ @cachedList(cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
@@ -372,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+ txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
txn.call_after(
self._receipts_stream_cache.entity_has_changed,
@@ -492,7 +506,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
- txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+ txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
self._simple_delete_txn(
txn,
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index c241167fbe..26b429e307 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -15,17 +15,22 @@
import re
+from six.moves import range
+
from twisted.internet import defer
-from synapse.api.errors import StoreError, Codes
+from synapse.api.errors import Codes, StoreError
from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from six.moves import range
-
class RegistrationWorkerStore(SQLBaseStore):
+ def __init__(self, db_conn, hs):
+ super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+
+ self.config = hs.config
+
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
@@ -36,12 +41,33 @@ class RegistrationWorkerStore(SQLBaseStore):
retcols=[
"name", "password_hash", "is_guest",
"consent_version", "consent_server_notice_sent",
- "appservice_id",
+ "appservice_id", "creation_ts",
],
allow_none=True,
desc="get_user_by_id",
)
+ @defer.inlineCallbacks
+ def is_trial_user(self, user_id):
+ """Checks if user is in the "trial" period, i.e. within the first
+ N days of registration defined by `mau_trial_days` config
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[bool]
+ """
+
+ info = yield self.get_user_by_id(user_id)
+ if not info:
+ defer.returnValue(False)
+
+ now = self.clock.time_msec()
+ trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
+ is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
+ defer.returnValue(is_trial)
+
@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
@@ -460,15 +486,6 @@ class RegistrationStore(RegistrationWorkerStore,
defer.returnValue(ret['user_id'])
defer.returnValue(None)
- def user_delete_threepids(self, user_id):
- return self._simple_delete(
- "user_threepids",
- keyvalues={
- "user_id": user_id,
- },
- desc="user_delete_threepids",
- )
-
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
@@ -632,7 +649,9 @@ class RegistrationStore(RegistrationWorkerStore,
Removes the given user to the table of users who need to be parted from all the
rooms they're in, effectively marking that user as fully deactivated.
"""
- return self._simple_delete_one(
+ # XXX: This should be simple_delete_one but we failed to put a unique index on
+ # the table, so somehow duplicate entries have ended up in it.
+ return self._simple_delete(
"users_pending_deactivation",
keyvalues={
"user_id": user_id,
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
index 40acb5c4ed..880f047adb 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/rejections.py
@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-
import logging
+from ._base import SQLBaseStore
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index ea6a189185..61013b8919 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -13,6 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
+import logging
+import re
+
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.api.errors import StoreError
@@ -20,11 +26,6 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-import collections
-import logging
-import simplejson as json
-import re
-
logger = logging.getLogger(__name__)
@@ -40,6 +41,22 @@ RatelimitOverride = collections.namedtuple(
class RoomWorkerStore(SQLBaseStore):
+ def get_room(self, room_id):
+ """Retrieve a room.
+
+ Args:
+ room_id (str): The ID of the room to retrieve.
+ Returns:
+            A dict containing the room information, or None if the room is unknown.
+ """
+ return self._simple_select_one(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("room_id", "is_public", "creator"),
+ desc="get_room",
+ allow_none=True,
+ )
+
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
@@ -169,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ @cachedInlineCallbacks(max_entries=10000)
+ def get_ratelimit_for_user(self, user_id):
+ """Check if there are any overrides for ratelimiting for the given
+ user
+
+ Args:
+ user_id (str)
+
+ Returns:
+ RatelimitOverride if there is an override, else None. If the contents
+            of RatelimitOverride are None or 0 then ratelimiting has been
+ disabled for that user entirely.
+ """
+ row = yield self._simple_select_one(
+ table="ratelimit_override",
+ keyvalues={"user_id": user_id},
+ retcols=("messages_per_second", "burst_count"),
+ allow_none=True,
+ desc="get_ratelimit_for_user",
+ )
+
+ if row:
+ defer.returnValue(RatelimitOverride(
+ messages_per_second=row["messages_per_second"],
+ burst_count=row["burst_count"],
+ ))
+ else:
+ defer.returnValue(None)
+
class RoomStore(RoomWorkerStore, SearchStore):
@@ -214,22 +260,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- def get_room(self, room_id):
- """Retrieve a room.
-
- Args:
- room_id (str): The ID of the room to retrieve.
- Returns:
- A namedtuple containing the room information, or an empty list.
- """
- return self._simple_select_one(
- table="rooms",
- keyvalues={"room_id": room_id},
- retcols=("room_id", "is_public", "creator"),
- desc="get_room",
- allow_none=True,
- )
-
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
@@ -468,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
"get_all_new_public_rooms", get_all_new_public_rooms
)
- @cachedInlineCallbacks(max_entries=10000)
- def get_ratelimit_for_user(self, user_id):
- """Check if there are any overrides for ratelimiting for the given
- user
-
- Args:
- user_id (str)
-
- Returns:
- RatelimitOverride if there is an override, else None. If the contents
- of RatelimitOverride are None or 0 then ratelimitng has been
- disabled for that user entirely.
- """
- row = yield self._simple_select_one(
- table="ratelimit_override",
- keyvalues={"user_id": user_id},
- retcols=("messages_per_second", "burst_count"),
- allow_none=True,
- desc="get_ratelimit_for_user",
- )
-
- if row:
- defer.returnValue(RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- ))
- else:
- defer.returnValue(None)
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
yield self._simple_insert(
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 48a88f755e..9b4e6d6aa8 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -14,23 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
+import logging
from collections import namedtuple
-from synapse.storage.events import EventsWorkerStore
-from synapse.util.async import Linearizer
-from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.util.stringutils import to_ascii
+from six import iteritems, itervalues
-from synapse.api.constants import Membership, EventTypes
-from synapse.types import get_domain_from_id
+from canonicaljson import json
-import logging
-import simplejson as json
+from twisted.internet import defer
-from six import itervalues, iteritems
+from synapse.api.constants import EventTypes, Membership
+from synapse.storage.events_worker import EventsWorkerStore
+from synapse.types import get_domain_from_id
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@@ -233,6 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
defer.returnValue(user_who_share_room)
+ @defer.inlineCallbacks
def get_joined_users_from_context(self, event, context):
state_group = context.state_group
if not state_group:
@@ -242,11 +242,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- return self._get_joined_users_from_context(
- event.room_id, state_group, context.current_state_ids,
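+        # fetching the current state from the context is asynchronous and may
+        # hit the database, hence the yield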
+ current_state_ids = yield context.get_current_state_ids(self)
+ result = yield self._get_joined_users_from_context(
+ event.room_id, state_group, current_state_ids,
event=event,
context=context,
)
+ defer.returnValue(result)
def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
@@ -455,21 +457,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
defer.returnValue(joined_hosts)
- @cached(max_entries=10000, iterable=True)
+ @cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id):
return _JoinedHostsCache(self, room_id)
- @cached()
- def who_forgot_in_room(self, room_id):
- return self._simple_select_list(
- table="room_memberships",
- retcols=("user_id", "event_id"),
- keyvalues={
- "room_id": room_id,
- "forgotten": 1,
- },
- desc="who_forgot"
- )
+ @cachedInlineCallbacks(num_args=2)
+ def did_forget(self, user_id, room_id):
+ """Returns whether user_id has elected to discard history for room_id.
+
+ Returns False if they have since re-joined."""
+ def f(txn):
+ sql = (
+ "SELECT"
+ " COUNT(*)"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " forgotten = 0"
+ )
+ txn.execute(sql, (user_id, room_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+ count = yield self.runInteraction("did_forget_membership", f)
+ defer.returnValue(count == 0)
class RoomMemberStore(RoomMemberWorkerStore):
@@ -578,36 +592,11 @@ class RoomMemberStore(RoomMemberWorkerStore):
)
txn.execute(sql, (user_id, room_id))
- txn.call_after(self.did_forget.invalidate, (user_id, room_id))
self._invalidate_cache_and_stream(
- txn, self.who_forgot_in_room, (room_id,)
+ txn, self.did_forget, (user_id, room_id,),
)
return self.runInteraction("forget_membership", f)
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
- """Returns whether user_id has elected to discard history for room_id.
-
- Returns False if they have since re-joined."""
- def f(txn):
- sql = (
- "SELECT"
- " COUNT(*)"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " forgotten = 0"
- )
- txn.execute(sql, (user_id, room_id))
- rows = txn.fetchall()
- return rows[0][0]
- count = yield self.runInteraction("did_forget_membership", f)
- defer.returnValue(count == 0)
-
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py
index e7351c3ae6..4b2ffd35fd 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -14,11 +14,11 @@
import logging
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-
import simplejson
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.prepare_database import get_statements
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py
index 6df57b5206..414f9f5aa0 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/schema/delta/27/ts.py
@@ -14,10 +14,10 @@
import logging
-from synapse.storage.prepare_database import get_statements
-
import simplejson
+from synapse.storage.prepare_database import get_statements
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index 85bd1a2006..ef7ec34346 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from synapse.config.appservice import load_appservices
from six.moves import range
+from synapse.config.appservice import load_appservices
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py
index fe6b7d196d..7d8ca5f93f 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/schema/delta/31/search_update.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.engines import PostgresEngine
-from synapse.storage.prepare_database import get_statements
-
import logging
+
import simplejson
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
index 1e002f9db2..bff1256a7b 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import get_statements
-
import logging
+
import simplejson
+from synapse.storage.prepare_database import get_statements
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py
index 55ae43f395..9754d3ccfb 100644
--- a/synapse/storage/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/schema/delta/33/remote_media_ts.py
@@ -14,7 +14,6 @@
import time
-
ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT"
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py
index 3b63a1562d..cf09e43e2b 100644
--- a/synapse/storage/schema/delta/34/cache_stream.py
+++ b/synapse/storage/schema/delta/34/cache_stream.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.engines import PostgresEngine
-
import logging
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/schema/delta/34/received_txn_purge.py
index 033144341c..67d505e68b 100644
--- a/synapse/storage/schema/delta/34/received_txn_purge.py
+++ b/synapse/storage/schema/delta/34/received_txn_purge.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.engines import PostgresEngine
-
import logging
+from synapse.storage.engines import PostgresEngine
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py
index 81948e3431..0ffab10b6f 100644
--- a/synapse/storage/schema/delta/34/sent_txn_purge.py
+++ b/synapse/storage/schema/delta/34/sent_txn_purge.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.engines import PostgresEngine
-
import logging
+from synapse.storage.engines import PostgresEngine
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py
index 20ad8bd5a6..a377884169 100644
--- a/synapse/storage/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/schema/delta/37/remove_auth_idx.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.engines import PostgresEngine
-
import logging
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
logger = logging.getLogger(__name__)
DROP_INDICES = """
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py
index ea6a18196d..506f326f4d 100644
--- a/synapse/storage/schema/delta/42/user_dir.py
+++ b/synapse/storage/schema/delta/42/user_dir.py
@@ -14,8 +14,8 @@
import logging
-from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/schema/delta/50/erasure_store.sql
new file mode 100644
index 0000000000..5d8641a9ab
--- /dev/null
+++ b/synapse/storage/schema/delta/50/erasure_store.sql
@@ -0,0 +1,21 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a table of users who have requested that their details be erased
+CREATE TABLE erased_users (
+ user_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX erased_users_user ON erased_users(user_id);
diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py
new file mode 100644
index 0000000000..6dd467b6c5
--- /dev/null
+++ b/synapse/storage/schema/delta/50/make_event_content_nullable.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+We want to stop populating 'event.content', so we need to make it nullable.
+
+If this has to be rolled back, then the following should populate the missing data:
+
+Postgres:
+
+ UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej
+ WHERE ej.event_id = events.event_id AND
+ stream_ordering < (
+ SELECT stream_ordering FROM events WHERE content IS NOT NULL
+ ORDER BY stream_ordering LIMIT 1
+ );
+
+ UPDATE events SET content=(ej.json::json)->'content' FROM event_json ej
+ WHERE ej.event_id = events.event_id AND
+ stream_ordering > (
+ SELECT stream_ordering FROM events WHERE content IS NOT NULL
+ ORDER BY stream_ordering DESC LIMIT 1
+ );
+
+SQLite:
+
+ UPDATE events SET content=(
+ SELECT json_extract(json,'$.content') FROM event_json ej
+ WHERE ej.event_id = events.event_id
+ )
+ WHERE
+ stream_ordering < (
+ SELECT stream_ordering FROM events WHERE content IS NOT NULL
+ ORDER BY stream_ordering LIMIT 1
+ )
+ OR stream_ordering > (
+ SELECT stream_ordering FROM events WHERE content IS NOT NULL
+ ORDER BY stream_ordering DESC LIMIT 1
+ );
+
+"""
+
+import logging
+
+from synapse.storage.engines import PostgresEngine
+
+logger = logging.getLogger(__name__)
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ pass
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute("""
+ ALTER TABLE events ALTER COLUMN content DROP NOT NULL;
+ """)
+ return
+
+ # sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html
+
+ cur.execute("SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'")
+ (oldsql,) = cur.fetchone()
+
+ sql = oldsql.replace("content TEXT NOT NULL", "content TEXT")
+ if sql == oldsql:
+ raise Exception("Couldn't find null constraint to drop in %s" % oldsql)
+
+ logger.info("Replacing definition of 'events' with: %s", sql)
+
+ cur.execute("PRAGMA schema_version")
+ (oldver,) = cur.fetchone()
+ cur.execute("PRAGMA writable_schema=ON")
+ cur.execute(
+ "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'",
+ (sql, ),
+ )
+ cur.execute("PRAGMA schema_version=%i" % (oldver + 1,))
+ cur.execute("PRAGMA writable_schema=OFF")
diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/schema/delta/51/monthly_active_users.sql
new file mode 100644
index 0000000000..c9d537d5a3
--- /dev/null
+++ b/synapse/storage/schema/delta/51/monthly_active_users.sql
@@ -0,0 +1,27 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- a table of monthly active users, for use where blocking based on mau limits
+CREATE TABLE monthly_active_users (
+ user_id TEXT NOT NULL,
+ -- Last time we saw the user. Not guaranteed to be accurate due to rate limiting
+    -- on updates. Granularity of updates is governed by
+    -- synapse.storage.monthly_active_users.LAST_SEEN_GRANULARITY.
+ -- Measured in ms since epoch.
+ timestamp BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX monthly_active_users_users ON monthly_active_users(user_id);
+CREATE INDEX monthly_active_users_time_stamp ON monthly_active_users(timestamp);
diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/schema/full_schemas/16/event_edges.sql
index 52eec88357..6b5a5a88fa 100644
--- a/synapse/storage/schema/full_schemas/16/event_edges.sql
+++ b/synapse/storage/schema/full_schemas/16/event_edges.sql
@@ -37,7 +37,8 @@ CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
- is_state BOOL NOT NULL,
+ is_state BOOL NOT NULL, -- true if this is a prev_state edge rather than a regular
+ -- event dag edge.
UNIQUE (event_id, prev_event_id, room_id, is_state)
);
diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/schema/full_schemas/16/im.sql
index ba5346806e..5f5cb8d01d 100644
--- a/synapse/storage/schema/full_schemas/16/im.sql
+++ b/synapse/storage/schema/full_schemas/16/im.sql
@@ -19,7 +19,12 @@ CREATE TABLE IF NOT EXISTS events(
event_id TEXT NOT NULL,
type TEXT NOT NULL,
room_id TEXT NOT NULL,
- content TEXT NOT NULL,
+
+    -- 'content' used to be created NOT NULL, but delta 50 drops that constraint.
+    -- the hack we use to drop the constraint doesn't work for an in-memory sqlite
+    -- database, which breaks the sytests. Hence, we now create the column as
+    -- nullable here from the start.
+ content TEXT,
+
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index f0fa5d7631..d5b5df93e6 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -13,19 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import namedtuple
import logging
import re
-import simplejson as json
+from collections import namedtuple
from six import string_types
+from canonicaljson import json
+
from twisted.internet import defer
-from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from .background_updates import BackgroundUpdateStore
+
logger = logging.getLogger(__name__)
SearchEntry = namedtuple('SearchEntry', [
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 25922e5a9c..5623391f6e 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
import six
-from ._base import SQLBaseStore
-
from unpaddedbase64 import encode_base64
+
+from twisted.internet import defer
+
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.util.caches.descriptors import cached, cachedList
+from ._base import SQLBaseStore
+
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
@@ -72,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
txn (cursor):
event_id (str): Id for the Event.
Returns:
- A dict of algorithm -> hash.
+ A dict[unicode, bytes] of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 85b8ec2b8f..4b971efdba 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,21 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import namedtuple
import logging
+from collections import namedtuple
from six import iteritems, itervalues
from six.moves import range
from twisted.internet import defer
+from synapse.api.constants import EventTypes
+from synapse.api.errors import NotFoundError
+from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine
-from synapse.util.caches import intern_string, get_cache_factor_for
+from synapse.storage.events_worker import EventsWorkerStore
+from synapse.util.caches import get_cache_factor_for, intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
-from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -45,7 +48,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
-class StateGroupWorkerStore(SQLBaseStore):
+# this inherits from EventsWorkerStore because it calls self.get_events
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
@@ -56,9 +60,68 @@ class StateGroupWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+ # Originally the state store used a single DictionaryCache to cache the
+ # event IDs for the state types in a given state group to avoid hammering
+ # on the state_group* tables.
+ #
+ # The point of using a DictionaryCache is that it can cache a subset
+ # of the state events for a given state group (i.e. a subset of the keys for a
+ # given dict which is an entry in the cache for a given state group ID).
+ #
+ # However, this poses problems when performing complicated queries
+ # on the store - for instance: "give me all the state for this group, but
+ # limit members to this subset of users", as DictionaryCache's API isn't
+ # rich enough to say "please cache any of these fields, apart from this subset".
+ # This is problematic when lazy loading members, which requires this behaviour,
+ # as without it the cache has no choice but to speculatively load all
+ # state events for the group, which negates the efficiency being sought.
+ #
+ # Rather than overcomplicating DictionaryCache's API, we instead split the
+ # state_group_cache into two halves - one for tracking non-member events,
+ # and the other for tracking member_events. This means that lazy loading
+ # queries can be made in a cache-friendly manner by querying both caches
+ # separately and then merging the result. So for the example above, you
+ # would query the members cache for a specific subset of state keys
+ # (which DictionaryCache will handle efficiently and fine) and the non-members
+ # cache for all state (which DictionaryCache will similarly handle fine)
+ # and then just merge the results together.
+ #
+ # We size the non-members cache to be smaller than the members cache as the
+ # vast majority of state in Matrix (today) is member events.
+
self._state_group_cache = DictionaryCache(
- "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
+ "*stateGroupCache*",
+ # TODO: this hasn't been tuned yet
+ 50000 * get_cache_factor_for("stateGroupCache")
)
+ self._state_group_members_cache = DictionaryCache(
+ "*stateGroupMembersCache*",
+ 500000 * get_cache_factor_for("stateGroupMembersCache")
+ )
+
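The comment above boils down to answering member and non-member lookups from separate caches and then merging the two partial results. A dependency-free sketch of that merge step, using plain dicts in place of DictionaryCache entries (the event IDs are made up):

# Illustration only: the real code goes via the two DictionaryCaches and the DB.
non_member_state = {
    ("m.room.create", ""): "$create_event_id",
    ("m.room.name", ""): "$name_event_id",
}
member_state = {
    ("m.room.member", "@alice:example.com"): "$alice_join_event_id",
}

state = dict(non_member_state)
state.update(member_state)
assert len(state) == 3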
+ @defer.inlineCallbacks
+ def get_room_version(self, room_id):
+ """Get the room_version of a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[str]
+
+ Raises:
+ NotFoundError if the room is unknown
+ """
+ # for now we do this by looking at the create event. We may want to cache this
+ # more intelligently in future.
+ state_ids = yield self.get_current_state_ids(room_id)
+ create_id = state_ids.get((EventTypes.Create, ""))
+
+ if not create_id:
+ raise NotFoundError("Unknown room")
+
+ create_event = yield self.get_event(create_id)
+ defer.returnValue(create_event.content.get("room_version", "1"))
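A minimal standalone sketch of the fallback above, where a create event whose content omits "room_version" is treated as a version-1 room (the helper name and sample dicts are invented for illustration):

def room_version_from_create_content(content):
    # mirrors create_event.content.get("room_version", "1") above
    return content.get("room_version", "1")

assert room_version_from_create_content({}) == "1"
assert room_version_from_create_content({"room_version": "3"}) == "3"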
@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
@@ -88,6 +151,69 @@ class StateGroupWorkerStore(SQLBaseStore):
_get_current_state_ids_txn,
)
+ # FIXME: how should this be cached?
+ def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
+ """Get the current state event of a given type for a room based on the
+ current_state_events table. This may not be as up-to-date as the result
+ of doing a fresh state resolution as per state_handler.get_current_state
+ Args:
+ room_id (str)
+ types (list[(str, str|None)]): List of (type, state_key) tuples
+ which are used to filter the state fetched. `state_key` may be
+ None, which matches any `state_key`.
+ filtered_types (list[str]|None): List of types to apply the above filter to.
+ Returns:
+ deferred: dict of (type, state_key) -> event
+ """
+
+ include_other_types = filtered_types is not None
+
+ def _get_filtered_current_state_ids_txn(txn):
+ results = {}
+ sql = """SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ? %s"""
+ # Turns out that postgres doesn't like doing a list of OR's and
+ # is about 1000x slower, so we just issue a query for each specific
+ # type separately.
+ if types:
+ clause_to_args = [
+ (
+ "AND type = ? AND state_key = ?",
+ (etype, state_key)
+ ) if state_key is not None else (
+ "AND type = ?",
+ (etype,)
+ )
+ for etype, state_key in types
+ ]
+
+ if include_other_types:
+ unique_types = set(filtered_types)
+ clause_to_args.append(
+ (
+ "AND type <> ? " * len(unique_types),
+ list(unique_types)
+ )
+ )
+ else:
+ # If types is None we fetch all the state, and so just use an
+ # empty where clause with no extra args.
+ clause_to_args = [("", [])]
+ for where_clause, where_args in clause_to_args:
+ args = [room_id]
+ args.extend(where_args)
+ txn.execute(sql % (where_clause,), args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (intern_string(typ), intern_string(state_key))
+ results[key] = event_id
+ return results
+
+ return self.runInteraction(
+ "get_filtered_current_state_ids",
+ _get_filtered_current_state_ids_txn,
+ )
+
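As a rough, standalone illustration of how the per-type WHERE clauses above are assembled (the helper mimics the txn logic but talks to no database; the inputs are made up):

def build_clauses(types, filtered_types=None):
    include_other_types = filtered_types is not None
    if not types:
        # no type filter: a single empty clause, as in the code above
        return [("", [])]
    clauses = [
        ("AND type = ? AND state_key = ?", [etype, state_key])
        if state_key is not None
        else ("AND type = ?", [etype])
        for etype, state_key in types
    ]
    if include_other_types:
        unique_types = set(filtered_types)
        clauses.append(
            ("AND type <> ? " * len(unique_types), list(unique_types))
        )
    return clauses

# lazy-loading style query: one member's state plus all non-member state
print(build_clauses(
    [("m.room.member", "@alice:example.com")],
    filtered_types=["m.room.member"],
))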
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@@ -184,8 +310,21 @@ class StateGroupWorkerStore(SQLBaseStore):
})
@defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, types):
- """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
+ def _get_state_groups_from_groups(self, groups, types, members=None):
+ """Returns the state groups for a given set of groups, filtering on
+ types of state events.
+
+ Args:
+ groups(list[int]): list of state group IDs to query
+ types (Iterable[(str, str|None)]|None): list of 2-tuples of the form
+ (`type`, `state_key`), where a `state_key` of `None` matches all
+ state_keys for the `type`. If None, all types are returned.
+ members (bool|None): If not None, then, in addition to any filtering
+ implied by types, the results are also filtered to only include
+ member events (if True), or to exclude member events (if False)
+
+ Returns:
+ dictionary state_group -> (dict of (type, state_key) -> event id)
"""
results = {}
@@ -193,14 +332,17 @@ class StateGroupWorkerStore(SQLBaseStore):
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn, chunk, types,
+ self._get_state_groups_from_groups_txn, chunk, types, members,
)
results.update(res)
defer.returnValue(results)
- def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
+ def _get_state_groups_from_groups_txn(
+ self, txn, groups, types=None, members=None,
+ ):
results = {group: {} for group in groups}
+
if types is not None:
types = list(set(types)) # deduplicate types list
@@ -235,10 +377,15 @@ class StateGroupWorkerStore(SQLBaseStore):
%s
""")
+ if members is True:
+ sql += " AND type = '%s'" % (EventTypes.Member,)
+ elif members is False:
+ sql += " AND type <> '%s'" % (EventTypes.Member,)
+
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type separately.
- if types:
+ if types is not None:
clause_to_args = [
(
"AND type = ? AND state_key = ?",
@@ -277,10 +424,16 @@ class StateGroupWorkerStore(SQLBaseStore):
else:
where_clauses.append("(type = ? AND state_key = ?)")
where_args.extend([typ[0], typ[1]])
+
where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else:
where_clause = ""
+ if members is True:
+ where_clause += " AND type = '%s'" % EventTypes.Member
+ elif members is False:
+ where_clause += " AND type <> '%s'" % EventTypes.Member
+
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
@@ -331,27 +484,30 @@ class StateGroupWorkerStore(SQLBaseStore):
return results
@defer.inlineCallbacks
- def get_state_for_events(self, event_ids, types):
+ def get_state_for_events(self, event_ids, types, filtered_types=None):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event. The state dicts will only have the type/state_keys
that are in the `types` list.
Args:
- event_ids (list)
- types (list): List of (type, state_key) tuples which are used to
- filter the state fetched. `state_key` may be None, which matches
- any `state_key`
+ event_ids (list[str])
+ types (list[(str, str|None)]|None): List of (type, state_key) tuples
+ which are used to filter the state fetched. If `state_key` is None,
+ all events of the given type are returned. `types` itself may be
+ None, in which case no filtering is applied.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
Returns:
- deferred: A list of dicts corresponding to the event_ids given.
- The dicts are mappings from (type, state_key) -> state_events
+ deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, types)
+ group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
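To make the `types`/`filtered_types` semantics concrete, here is a pure-Python model of the filtering (illustration only; the state map and event IDs are invented):

def matches(key, types, filtered_types):
    typ, state_key = key
    if types is None:
        return True  # no filtering at all
    if filtered_types is not None and typ not in filtered_types:
        return True  # this type is exempt from the `types` filter
    return (typ, state_key) in types or (typ, None) in types

state = {
    ("m.room.create", ""): "$create",
    ("m.room.member", "@alice:example.com"): "$alice",
    ("m.room.member", "@bob:example.com"): "$bob",
}
wanted = {
    k: v for k, v in state.items()
    if matches(
        k,
        types=[("m.room.member", "@alice:example.com")],
        filtered_types=["m.room.member"],
    )
}
# @bob's membership is filtered out; everything else is kept
assert wanted == {
    ("m.room.create", ""): "$create",
    ("m.room.member", "@alice:example.com"): "$alice",
}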
@@ -370,25 +526,30 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
- def get_state_ids_for_events(self, event_ids, types=None):
+ def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
"""
- Get the state dicts corresponding to a list of events
+ Get the state dicts corresponding to a list of events, containing the event_ids
+ of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
- types(list[(str, str)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. May be None, which
- matches any key
+ types(list[(str, str|None)]|None): List of (type, state_key) tuples
+ which are used to filter the state fetched. If `state_key` is None,
+ all events of the given type are returned. `types` itself may be
+ None, in which case no filtering is applied.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
Returns:
- A deferred dict from event_id -> (type, state_key) -> state_event
+ A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, types)
+ group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
event_to_state = {
event_id: group_to_state[group]
@@ -398,37 +559,45 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
- def get_state_for_event(self, event_id, types=None):
+ def get_state_for_event(self, event_id, types=None, filtered_types=None):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
- types(list[(str, str)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. May be None, which
- matches any key
+ types(list[(str, str|None)]|None): List of (type, state_key) tuples
+ which are used to filter the state fetched. If `state_key` is None,
+ all events of the given type are returned. `types` itself may be
+ None, in which case no filtering is applied.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
- state_map = yield self.get_state_for_events([event_id], types)
+ state_map = yield self.get_state_for_events([event_id], types, filtered_types)
defer.returnValue(state_map[event_id])
@defer.inlineCallbacks
- def get_state_ids_for_event(self, event_id, types=None):
+ def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
- types(list[(str, str)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. May be None, which
- matches any key
+ types(list[(str, str|None)]|None): List of (type, state_key) tuples
+ which are used to filter the state fetched. If `state_key` is None,
+ all events of the given type are returned. `types` itself may be
+ None, in which case no filtering is applied.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
- state_map = yield self.get_state_ids_for_events([event_id], types)
+ state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
defer.returnValue(state_map[event_id])
@cached(max_entries=50000)
@@ -459,58 +628,76 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
- def _get_some_state_from_cache(self, group, types):
+ def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
"""Checks if group is in cache. See `_get_state_for_groups`
- Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
- `missing_types` is the list of types that aren't in the cache for that
- group. `got_all` is a bool indicating if we successfully retrieved all
+ Args:
+ cache(DictionaryCache): the state group cache to use
+ group(int): The state group to lookup
+ types(list[(str, str|None)]): List of 2-tuples of the form
+ (`type`, `state_key`), where a `state_key` of `None` matches all
+ state_keys for the `type`.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
+
+ Returns 2-tuple (`state_dict`, `got_all`).
+ `got_all` is a bool indicating if we successfully retrieved all
requested state from the cache, if False we need to query the DB for the
missing state.
-
- Args:
- group: The state group to lookup
- types (list): List of 2-tuples of the form (`type`, `state_key`),
- where a `state_key` of `None` matches all state_keys for the
- `type`.
"""
- is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
+ is_all, known_absent, state_dict_ids = cache.get(group)
type_to_key = {}
- missing_types = set()
+
+ # tracks whether any of our requested types are missing from the cache
+ missing_types = False
for typ, state_key in types:
key = (typ, state_key)
- if state_key is None:
+
+ if (
+ state_key is None or
+ (filtered_types is not None and typ not in filtered_types)
+ ):
type_to_key[typ] = None
- missing_types.add(key)
+ # we mark the type as missing from the cache because
+ # when the cache was populated it might have been done with a
+ # restricted set of state_keys, so the wildcard will not work
+ # and the cache may be incomplete.
+ missing_types = True
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
if key not in state_dict_ids and key not in known_absent:
- missing_types.add(key)
+ missing_types = True
sentinel = object()
def include(typ, state_key):
valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel:
- return False
+ return filtered_types is not None and typ not in filtered_types
if valid_state_keys is None:
return True
if state_key in valid_state_keys:
return True
return False
- got_all = is_all or not missing_types
+ got_all = is_all
+ if not got_all:
+ # the cache is incomplete. We may still have got all the results we need, if
+ # we don't have any wildcards in the match list.
+ if not missing_types and filtered_types is None:
+ got_all = True
return {
k: v for k, v in iteritems(state_dict_ids)
if include(k[0], k[1])
- }, missing_types, got_all
+ }, got_all
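The `got_all` decision above can be restated as a small dependency-free predicate (an illustrative model only; the real DictionaryCache bookkeeping is richer):

def got_all_from_cache(is_all, missing_types, filtered_types):
    # sufficient if the cache holds *all* state for the group...
    if is_all:
        return True
    # ...or if every requested (type, state_key) was an exact hit and there
    # were no wildcard / filtered_types lookups
    return not missing_types and filtered_types is None

assert got_all_from_cache(True, missing_types=True, filtered_types=None)
assert got_all_from_cache(False, missing_types=False, filtered_types=None)
assert not got_all_from_cache(False, False, ["m.room.member"])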
- def _get_all_state_from_cache(self, group):
+ def _get_all_state_from_cache(self, cache, group):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
@@ -518,18 +705,91 @@ class StateGroupWorkerStore(SQLBaseStore):
cache, if False we need to query the DB for the missing state.
Args:
+ cache(DictionaryCache): the state group cache to use
group: The state group to lookup
"""
- is_all, _, state_dict_ids = self._state_group_cache.get(group)
+ is_all, _, state_dict_ids = cache.get(group)
return state_dict_ids, is_all
@defer.inlineCallbacks
- def _get_state_for_groups(self, groups, types=None):
- """Given list of groups returns dict of group -> list of state events
- with matching types. `types` is a list of `(type, state_key)`, where
- a `state_key` of None matches all state_keys. If `types` is None then
- all events are returned.
+ def _get_state_for_groups(self, groups, types=None, filtered_types=None):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ types (None|iterable[(str, None|str)]):
+ indicates the state type/keys required. If None, the whole
+ state is fetched and returned.
+
+ Otherwise, each entry should be a `(type, state_key)` tuple to
+ include in the response. A `state_key` of None is a wildcard
+ meaning that we require all state with that type.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
+
+ Returns:
+ Deferred[dict[int, dict[(type, state_key), EventBase]]]
+ a dictionary mapping from state group to state dictionary.
+ """
+ if types is not None:
+ non_member_types = [t for t in types if t[0] != EventTypes.Member]
+
+ if filtered_types is not None and EventTypes.Member not in filtered_types:
+ # we want all of the membership events
+ member_types = None
+ else:
+ member_types = [t for t in types if t[0] == EventTypes.Member]
+
+ else:
+ non_member_types = None
+ member_types = None
+
+ non_member_state = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, non_member_types, filtered_types,
+ )
+ # XXX: we could skip this entirely if member_types is []
+ member_state = yield self._get_state_for_groups_using_cache(
+ # we set filtered_types=None as member_state only ever contain members.
+ groups, self._state_group_members_cache, member_types, None,
+ )
+
+ state = non_member_state
+ for group in groups:
+ state[group].update(member_state[group])
+
+ defer.returnValue(state)
+
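The split of `types` between the two cache queries can be sketched in isolation as follows ("m.room.member" stands in for EventTypes.Member; the inputs are made up):

MEMBER = "m.room.member"

def split_types(types, filtered_types=None):
    if types is None:
        return None, None  # fetch everything from both caches
    non_member_types = [t for t in types if t[0] != MEMBER]
    if filtered_types is not None and MEMBER not in filtered_types:
        member_types = None  # membership is unfiltered: want all member events
    else:
        member_types = [t for t in types if t[0] == MEMBER]
    return non_member_types, member_types

print(split_types(
    [("m.room.create", ""), (MEMBER, "@alice:example.com")],
    filtered_types=[MEMBER],
))
# -> ([('m.room.create', '')], [('m.room.member', '@alice:example.com')])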
+ @defer.inlineCallbacks
+ def _get_state_for_groups_using_cache(
+ self, groups, cache, types=None, filtered_types=None
+ ):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key, querying from a specific cache.
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ cache (DictionaryCache): the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the specific
+ members state cache.
+ types (None|iterable[(str, None|str)]):
+ indicates the state type/keys required. If None, the whole
+ state is fetched and returned.
+
+ Otherwise, each entry should be a `(type, state_key)` tuple to
+ include in the response. A `state_key` of None is a wildcard
+ meaning that we require all state with that type.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
+
+ Returns:
+ Deferred[dict[int, dict[(type, state_key), EventBase]]]
+ a dictionary mapping from state group to state dictionary.
"""
if types:
types = frozenset(types)
@@ -537,8 +797,8 @@ class StateGroupWorkerStore(SQLBaseStore):
missing_groups = []
if types is not None:
for group in set(groups):
- state_dict_ids, _, got_all = self._get_some_state_from_cache(
- group, types
+ state_dict_ids, got_all = self._get_some_state_from_cache(
+ cache, group, types, filtered_types
)
results[group] = state_dict_ids
@@ -547,7 +807,7 @@ class StateGroupWorkerStore(SQLBaseStore):
else:
for group in set(groups):
state_dict_ids, got_all = self._get_all_state_from_cache(
- group
+ cache, group
)
results[group] = state_dict_ids
@@ -556,29 +816,46 @@ class StateGroupWorkerStore(SQLBaseStore):
missing_groups.append(group)
if missing_groups:
- # Okay, so we have some missing_types, lets fetch them.
- cache_seq_num = self._state_group_cache.sequence
+ # Okay, so we have some missing_types, let's fetch them.
+ cache_seq_num = cache.sequence
+
+ # the DictionaryCache knows if it has *all* the state, but
+ # does not know if it has all of the keys of a particular type,
+ # which makes wildcard lookups expensive unless we have a complete
+ # cache. Hence, if we are doing a wildcard lookup, populate the
+ # cache fully so that we can do an efficient lookup next time.
+
+ if filtered_types or (types and any(k is None for (t, k) in types)):
+ types_to_fetch = None
+ else:
+ types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups(
- missing_groups, types
+ missing_groups, types_to_fetch, cache == self._state_group_members_cache,
)
- # Now we want to update the cache with all the things we fetched
- # from the database.
for group, group_state_dict in iteritems(group_to_state_dict):
state_dict = results[group]
- state_dict.update(
- ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
- for k, v in iteritems(group_state_dict)
- )
-
- self._state_group_cache.update(
+ # update the result, filtering by `types`.
+ if types:
+ for k, v in iteritems(group_state_dict):
+ (typ, _) = k
+ if (
+ (k in types or (typ, None) in types) or
+ (filtered_types and typ not in filtered_types)
+ ):
+ state_dict[k] = v
+ else:
+ state_dict.update(group_state_dict)
+
+ # update the cache with all the things we fetched from the
+ # database.
+ cache.update(
cache_seq_num,
key=group,
- value=state_dict,
- full=(types is None),
- known_absent=types,
+ value=group_state_dict,
+ fetched_keys=types_to_fetch,
)
defer.returnValue(results)
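The "fetch everything on a wildcard" rule above reduces to a small predicate, sketched here standalone for clarity (made-up inputs):

def types_to_fetch(types, filtered_types):
    # any wildcard (state_key is None) or filtered_types lookup means the
    # cache cannot know it has every matching key, so fetch the full state
    if filtered_types or (types and any(k is None for (_t, k) in types)):
        return None
    return types

assert types_to_fetch([("m.room.name", "")], None) == [("m.room.name", "")]
assert types_to_fetch([("m.room.member", None)], None) is None
assert types_to_fetch([("m.room.name", "")], ["m.room.member"]) is None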
@@ -676,16 +953,33 @@ class StateGroupWorkerStore(SQLBaseStore):
],
)
- # Prefill the state group cache with this group.
+ # Prefill the state group caches with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
+
+ current_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] == EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_members_cache.update,
+ self._state_group_members_cache.sequence,
+ key=state_group,
+ value=dict(current_member_state_ids),
+ )
+
+ current_non_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] != EventTypes.Member
+ }
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
- value=dict(current_state_ids),
- full=True,
+ value=dict(current_non_member_state_ids),
)
return state_group
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index fb463c525a..4c296d72c0 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -33,22 +33,20 @@ what sort order was used:
and stream ordering columns respectively.
"""
+import abc
+import logging
+from collections import namedtuple
+
+from six.moves import range
+
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
-from synapse.storage.events import EventsWorkerStore
-
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
-from synapse.storage.engines import PostgresEngine
-
-import abc
-import logging
-
-from six.moves import range
-from collections import namedtuple
-
logger = logging.getLogger(__name__)
@@ -350,7 +348,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token (str): The stream token representing now.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
+ Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
events and a token pointing to the start of the returned
events.
The events returned are in ascending order.
@@ -381,7 +379,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token (str): The stream token representing now.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
+ Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
_EventDictReturn and a token pointing to the start of the returned
events.
The events returned are in ascending order.
@@ -529,7 +527,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
@defer.inlineCallbacks
- def get_events_around(self, room_id, event_id, before_limit, after_limit):
+ def get_events_around(
+ self, room_id, event_id, before_limit, after_limit, event_filter=None,
+ ):
"""Retrieve events and pagination tokens around a given event in a
room.
@@ -538,6 +538,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id (str)
before_limit (int)
after_limit (int)
+ event_filter (Filter|None)
Returns:
dict
@@ -545,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = yield self.runInteraction(
"get_events_around", self._get_events_around_txn,
- room_id, event_id, before_limit, after_limit
+ room_id, event_id, before_limit, after_limit, event_filter,
)
events_before = yield self._get_events(
@@ -565,7 +566,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"end": results["after"]["token"],
})
- def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit):
+ def _get_events_around_txn(
+ self, txn, room_id, event_id, before_limit, after_limit, event_filter,
+ ):
"""Retrieves event_ids and pagination tokens around a given event in a
room.
@@ -574,6 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id (str)
before_limit (int)
after_limit (int)
+ event_filter (Filter|None)
Returns:
dict
@@ -603,11 +607,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows, start_token = self._paginate_room_events_txn(
txn, room_id, before_token, direction='b', limit=before_limit,
+ event_filter=event_filter,
)
events_before = [r.event_id for r in rows]
rows, end_token = self._paginate_room_events_txn(
txn, room_id, after_token, direction='f', limit=after_limit,
+ event_filter=event_filter,
)
events_after = [r.event_id for r in rows]
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 6671d3cfca..0f657b2bd3 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -14,16 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.account_data import AccountDataWorkerStore
-
-from synapse.util.caches.descriptors import cached
-from twisted.internet import defer
-
-import simplejson as json
import logging
from six.moves import range
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.util.caches.descriptors import cached
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index e485d19b84..0c42bd3322 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -13,18 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+import logging
+from collections import namedtuple
-from twisted.internet import defer
import six
from canonicaljson import encode_canonical_json
-from collections import namedtuple
+from twisted.internet import defer
-import logging
-import simplejson as json
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.caches.descriptors import cached
+
+from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
@@ -57,7 +58,7 @@ class TransactionStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(TransactionStore, self).__init__(db_conn, hs)
- self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
+ self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have
@@ -94,7 +95,8 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
- return result["response_code"], json.loads(str(result["response_json"]))
+ return result["response_code"], db_to_json(result["response_json"])
+
else:
return None
@@ -271,6 +273,11 @@ class TransactionStore(SQLBaseStore):
txn.execute(query, (self._clock.time_msec(),))
return self.cursor_to_dict(txn)
+ def _start_cleanup_transactions(self):
+ return run_as_background_process(
+ "cleanup_transactions", self._cleanup_transactions,
+ )
+
def _cleanup_transactions(self):
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 275c299998..a8781b0e5d 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -13,19 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
+import re
-from ._base import SQLBaseStore
+from six import iteritems
+
+from twisted.internet import defer
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from six import iteritems
-
-import re
-import logging
+from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -265,7 +265,7 @@ class UserDirectoryStore(SQLBaseStore):
self.get_user_in_public_room.invalidate((user_id,))
def get_users_in_public_due_to_room(self, room_id):
- """Get all user_ids that are in the room directory becuase they're
+ """Get all user_ids that are in the room directory because they're
in the given room_id
"""
return self._simple_select_onecol(
@@ -277,7 +277,7 @@ class UserDirectoryStore(SQLBaseStore):
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
- """Get all user_ids that are in the room directory becuase they're
+ """Get all user_ids that are in the room directory because they're
in the given room_id
"""
user_ids_dir = yield self._simple_select_onecol(
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py
new file mode 100644
index 0000000000..be013f4427
--- /dev/null
+++ b/synapse/storage/user_erasure_store.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import operator
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached, cachedList
+
+
+class UserErasureWorkerStore(SQLBaseStore):
+ @cached()
+ def is_user_erased(self, user_id):
+ """
+ Check if the given user id has requested erasure
+
+ Args:
+ user_id (str): full user id to check
+
+ Returns:
+ Deferred[bool]: True if the user has requested erasure
+ """
+ return self._simple_select_onecol(
+ table="erased_users",
+ keyvalues={"user_id": user_id},
+ retcol="1",
+ desc="is_user_erased",
+ ).addCallback(operator.truth)
+
+ @cachedList(
+ cached_method_name="is_user_erased",
+ list_name="user_ids",
+ inlineCallbacks=True,
+ )
+ def are_users_erased(self, user_ids):
+ """
+ Checks which users in a list have requested erasure
+
+ Args:
+ user_ids (iterable[str]): full user ids to check
+
+ Returns:
+ Deferred[dict[str, bool]]:
+ for each user, whether the user has requested erasure.
+ """
+ # this serves the dual purpose of (a) making sure we can do len and
+ # iterate it multiple times, and (b) avoiding duplicates.
+ user_ids = tuple(set(user_ids))
+
+ def _get_erased_users(txn):
+ txn.execute(
+ "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % (
+ ",".join("?" * len(user_ids))
+ ),
+ user_ids,
+ )
+ return set(r[0] for r in txn)
+
+ erased_users = yield self.runInteraction(
+ "are_users_erased", _get_erased_users,
+ )
+ res = dict((u, u in erased_users) for u in user_ids)
+ defer.returnValue(res)
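The placeholder construction used in the query above expands as follows (standalone illustration with made-up user IDs):

user_ids = ("@a:example.com", "@b:example.com", "@c:example.com")
sql = "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % (
    ",".join("?" * len(user_ids))
)
print(sql)
# SELECT user_id FROM erased_users WHERE user_id IN (?,?,?)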
+
+
+class UserErasureStore(UserErasureWorkerStore):
+ def mark_user_erased(self, user_id):
+ """Indicate that user_id wishes their message history to be erased.
+
+ Args:
+ user_id (str): full user_id to be erased
+ """
+ def f(txn):
+ # first check if they are already in the list
+ txn.execute(
+ "SELECT 1 FROM erased_users WHERE user_id = ?",
+ (user_id, )
+ )
+ if txn.fetchone():
+ return
+
+ # they are not already there: do the insert.
+ txn.execute(
+ "INSERT INTO erased_users (user_id) VALUES (?)",
+ (user_id, )
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.is_user_erased, (user_id,)
+ )
+ return self.runInteraction("mark_user_erased", f)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 95031dc9ec..d6160d5e4d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import deque
import contextlib
import threading
+from collections import deque
class IdGenerator(object):
|