diff --git a/synapse/__init__.py b/synapse/__init__.py
index ff251ce597..7628e7c505 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.19.2"
+__version__ = "0.19.3"
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index a8123cddcb..ca23c9c460 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -43,9 +43,6 @@ class JoinRules(object):
class LoginType(object):
PASSWORD = u"m.login.password"
- OAUTH = u"m.login.oauth2"
- EMAIL_CODE = u"m.login.email.code"
- EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha"
DUMMY = u"m.login.dummy"
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index b3fb408cfd..3f29595256 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -87,6 +87,10 @@ class SynchrotronSlavedStore(
RoomMemberStore.__dict__["who_forgot_in_room"]
)
+ did_forget = (
+ RoomMemberStore.__dict__["did_forget"]
+ )
+
# XXX: This is a bit broken because we don't persist the accepted list in a
# way that can be replicated. This means that we don't have a way to
# invalidate the cache correctly.
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 3c58d2de17..e081840a83 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -95,7 +95,7 @@ class TlsConfig(Config):
# make HTTPS requests to this server will check that the TLS
# certificates returned by this server match one of the fingerprints.
#
- # Synapse automatically adds its the fingerprint of its own certificate
+ # Synapse automatically adds the fingerprint of its own certificate
# to the list. So if federation traffic is handle directly by synapse
# then no modification to the list is required.
#
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b5bcfd705a..5dcd4eecce 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -206,8 +206,7 @@ class FederationClient(FederationBase):
Args:
destinations (list): Which home servers to query
- pdu_origin (str): The home server that originally sent the pdu.
- event_id (str)
+ event_id (str): event to fetch
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index bb3d9258a6..90235ff098 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -303,18 +303,10 @@ class TransactionQueue(object):
try:
self.pending_transactions[destination] = 1
+ # XXX: what's this for?
yield run_on_reactor()
while True:
- pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
- pending_edus = self.pending_edus_by_dest.pop(destination, [])
- pending_presence = self.pending_presence_by_dest.pop(destination, {})
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
-
- pending_edus.extend(
- self.pending_edus_keyed_by_dest.pop(destination, {}).values()
- )
-
limiter = yield get_retry_limiter(
destination,
self.clock,
@@ -326,6 +318,24 @@ class TransactionQueue(object):
yield self._get_new_device_messages(destination)
)
+ # BEGIN CRITICAL SECTION
+ #
+ # In order to avoid a race condition, we need to make sure that
+ # the following code (from popping the queues up to the point
+ # where we decide if we actually have any pending messages) is
+ # atomic - otherwise new PDUs or EDUs might arrive in the
+ # meantime, but not get sent because we hold the
+ # pending_transactions flag.
+
+ pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+ pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_presence = self.pending_presence_by_dest.pop(destination, {})
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
+
+ pending_edus.extend(
+ self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+ )
+
pending_edus.extend(device_message_edus)
if pending_presence:
pending_edus.append(
@@ -355,6 +365,8 @@ class TransactionQueue(object):
)
return
+ # END CRITICAL SECTION
+
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
limiter=limiter,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 084e33ca6a..f36b358b45 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -19,7 +19,6 @@ from ._base import BaseHandler
import logging
-
logger = logging.getLogger(__name__)
@@ -54,3 +53,46 @@ class AdminHandler(BaseHandler):
}
defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def get_users(self):
+ """Function to reterive a list of users in users table.
+
+ Args:
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ ret = yield self.store.get_users()
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def get_users_paginate(self, order, start, limit):
+ """Function to reterive a paginated list of users from
+ users list. This will return a json object, which contains
+ list of users and the total number of users in users table.
+
+ Args:
+ order (str): column name to order the select by this column
+ start (int): start number to begin the query from
+ limit (int): number of rows to reterive
+ Returns:
+ defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+ """
+ ret = yield self.store.get_users_paginate(order, start, limit)
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def search_users(self, term):
+ """Function to search users list for one or more users with
+ the matched term.
+
+ Args:
+ term (str): search term
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ ret = yield self.store.search_users(term)
+
+ defer.returnValue(ret)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 8cb47ac417..e859b3165f 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from synapse.api import errors
from synapse.api.constants import EventTypes
from synapse.util import stringutils
from synapse.util.async import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id, RoomStreamToken
from twisted.internet import defer
@@ -35,10 +35,11 @@ class DeviceHandler(BaseHandler):
self.state = hs.get_state_handler()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
- self._remote_edue_linearizer = Linearizer(name="remote_device_list")
+
+ self._edu_updater = DeviceListEduUpdater(hs, self)
self.federation.register_edu_handler(
- "m.device_list_update", self._incoming_device_list_update,
+ "m.device_list_update", self._edu_updater.incoming_device_list_update,
)
self.federation.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
@@ -246,30 +247,51 @@ class DeviceHandler(BaseHandler):
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+ stream_ordering = RoomStreamToken.parse_stream_token(
+ from_token.room_key).stream
+
possibly_changed = set(changed)
for room_id in rooms_changed:
- # Fetch the current state at the time.
- stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key)
-
+ # Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
- prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
- except:
- prev_state_ids = {}
+ except errors.StoreError:
+ # we have purged the stream_ordering index since the stream
+ # ordering: treat it the same as a new room
+ event_ids = []
current_state_ids = yield self.state.get_current_state_ids(room_id)
+ # special-case for an empty prev state: include all members
+ # in the changed list
+ if not event_ids:
+ for key, event_id in current_state_ids.iteritems():
+ etype, state_key = key
+ if etype != EventTypes.Member:
+ continue
+ possibly_changed.add(state_key)
+ continue
+
+ # mapping from event_id -> state_dict
+ prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
for key, event_id in current_state_ids.iteritems():
etype, state_key = key
- if etype == EventTypes.Member:
- prev_event_id = prev_state_ids.get(key, None)
+ if etype != EventTypes.Member:
+ continue
+
+ # check if this member has changed since any of the extremities
+ # at the stream_ordering, and add them to the list if so.
+ for state_dict in prev_state_ids.values():
+ prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
possibly_changed.add(state_key)
+ break
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
@@ -279,13 +301,69 @@ class DeviceHandler(BaseHandler):
# and those that actually still share a room with the user
defer.returnValue(users_who_share_room & possibly_changed)
- @measure_func("_incoming_device_list_update")
@defer.inlineCallbacks
- def _incoming_device_list_update(self, origin, edu_content):
- user_id = edu_content["user_id"]
- device_id = edu_content["device_id"]
- stream_id = edu_content["stream_id"]
- prev_ids = edu_content.get("prev_id", [])
+ def on_federation_query_user_devices(self, user_id):
+ stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+ defer.returnValue({
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ })
+
+ @defer.inlineCallbacks
+ def user_left_room(self, user, room_id):
+ user_id = user.to_string()
+ rooms = yield self.store.get_rooms_for_user(user_id)
+ if not rooms:
+ # We no longer share rooms with this user, so we'll no longer
+ # receive device updates. Mark this in DB.
+ yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+
+def _update_device_from_client_ips(device, client_ips):
+ ip = client_ips.get((device["user_id"], device["device_id"]), {})
+ device.update({
+ "last_seen_ts": ip.get("last_seen"),
+ "last_seen_ip": ip.get("ip"),
+ })
+
+
+class DeviceListEduUpdater(object):
+ "Handles incoming device list updates from federation and updates the DB"
+
+ def __init__(self, hs, device_handler):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_replication_layer()
+ self.clock = hs.get_clock()
+ self.device_handler = device_handler
+
+ self._remote_edu_linearizer = Linearizer(name="remote_device_list")
+
+ # user_id -> list of updates waiting to be handled.
+ self._pending_updates = {}
+
+ # Recently seen stream ids. We don't bother keeping these in the DB,
+ # but they're useful to have them about to reduce the number of spurious
+ # resyncs.
+ self._seen_updates = ExpiringCache(
+ cache_name="device_update_edu",
+ clock=self.clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ iterable=True,
+ )
+
+ @defer.inlineCallbacks
+ def incoming_device_list_update(self, origin, edu_content):
+ """Called on incoming device list update from federation. Responsible
+ for parsing the EDU and adding to pending updates list.
+ """
+
+ user_id = edu_content.pop("user_id")
+ device_id = edu_content.pop("device_id")
+ stream_id = str(edu_content.pop("stream_id")) # They may come as ints
+ prev_ids = edu_content.pop("prev_id", [])
+ prev_ids = [str(p) for p in prev_ids] # They may come as ints
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
@@ -298,20 +376,28 @@ class DeviceHandler(BaseHandler):
# probably won't get any further updates.
return
- with (yield self._remote_edue_linearizer.queue(user_id)):
- # If the prev id matches whats in our cache table, then we don't need
- # to resync the users device list, otherwise we do.
- resync = True
- if len(prev_ids) == 1:
- extremity = yield self.store.get_device_list_last_stream_id_for_remote(
- user_id
- )
- logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
- if str(extremity) == str(prev_ids[0]):
- resync = False
+ self._pending_updates.setdefault(user_id, []).append(
+ (device_id, stream_id, prev_ids, edu_content)
+ )
+
+ yield self._handle_device_updates(user_id)
+
+ @measure_func("_incoming_device_list_update")
+ @defer.inlineCallbacks
+ def _handle_device_updates(self, user_id):
+ "Actually handle pending updates."
+
+ with (yield self._remote_edu_linearizer.queue(user_id)):
+ pending_updates = self._pending_updates.pop(user_id, [])
+ if not pending_updates:
+ # This can happen since we batch updates
+ return
+
+ resync = yield self._need_to_do_resync(user_id, pending_updates)
if resync:
# Fetch all devices for the user.
+ origin = get_domain_from_id(user_id)
result = yield self.federation.query_user_devices(origin, user_id)
stream_id = result["stream_id"]
devices = result["devices"]
@@ -319,40 +405,50 @@ class DeviceHandler(BaseHandler):
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
- yield self.notify_device_update(user_id, device_ids)
+ yield self.device_handler.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
# change (becuase of the single prev_id matching the current cache)
- content = dict(edu_content)
- for key in ("user_id", "device_id", "stream_id", "prev_ids"):
- content.pop(key, None)
- yield self.store.update_remote_device_list_cache_entry(
- user_id, device_id, content, stream_id,
+ for device_id, stream_id, prev_ids, content in pending_updates:
+ yield self.store.update_remote_device_list_cache_entry(
+ user_id, device_id, content, stream_id,
+ )
+
+ yield self.device_handler.notify_device_update(
+ user_id, [device_id for device_id, _, _, _ in pending_updates]
)
- yield self.notify_device_update(user_id, [device_id])
- @defer.inlineCallbacks
- def on_federation_query_user_devices(self, user_id):
- stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- defer.returnValue({
- "user_id": user_id,
- "stream_id": stream_id,
- "devices": devices,
- })
+ self._seen_updates.setdefault(user_id, set()).update(
+ stream_id for _, stream_id, _, _ in pending_updates
+ )
@defer.inlineCallbacks
- def user_left_room(self, user, room_id):
- user_id = user.to_string()
- rooms = yield self.store.get_rooms_for_user(user_id)
- if not rooms:
- # We no longer share rooms with this user, so we'll no longer
- # receive device updates. Mark this in DB.
- yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+ def _need_to_do_resync(self, user_id, updates):
+ """Given a list of updates for a user figure out if we need to do a full
+ resync, or whether we have enough data that we can just apply the delta.
+ """
+ seen_updates = self._seen_updates.get(user_id, set())
+ extremity = yield self.store.get_device_list_last_stream_id_for_remote(
+ user_id
+ )
-def _update_device_from_client_ips(device, client_ips):
- ip = client_ips.get((device["user_id"], device["device_id"]), {})
- device.update({
- "last_seen_ts": ip.get("last_seen"),
- "last_seen_ip": ip.get("ip"),
- })
+ stream_id_in_updates = set() # stream_ids in updates list
+ for _, stream_id, prev_ids, _ in updates:
+ if not prev_ids:
+ # We always do a resync if there are no previous IDs
+ defer.returnValue(True)
+
+ for prev_id in prev_ids:
+ if prev_id == extremity:
+ continue
+ elif prev_id in seen_updates:
+ continue
+ elif prev_id in stream_id_in_updates:
+ continue
+ else:
+ defer.returnValue(True)
+
+ stream_id_in_updates.add(stream_id)
+
+ defer.returnValue(False)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 996bfd0e23..ed0fa51e7f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1096,7 +1096,7 @@ class FederationHandler(BaseHandler):
if prev_id != event.event_id:
results[(event.type, event.state_key)] = prev_id
else:
- del results[(event.type, event.state_key)]
+ results.pop((event.type, event.state_key), None)
defer.returnValue(results.values())
else:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index fdfce2a88c..da610e430f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -531,7 +531,7 @@ class PresenceHandler(object):
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = yield self.store.get_presence_for_users(missing)
- states.update({state.user_id: state for state in res})
+ states.update(res)
missing = [user_id for user_id, state in states.items() if not state]
if missing:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index b2806555cf..2052d6d05f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -719,7 +719,9 @@ class RoomMemberHandler(BaseHandler):
)
membership = member.membership if member else None
- if membership is not None and membership != Membership.LEAVE:
+ if membership is not None and membership not in [
+ Membership.LEAVE, Membership.BAN
+ ]:
raise SynapseError(400, "User %s in room %s" % (
user_id, room_id
))
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d7dcd1ce5b..5572cb883f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -609,14 +609,14 @@ class SyncHandler(object):
deleted = yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
- logger.info("Deleted %d to-device messages up to %d",
- deleted, since_stream_id)
+ logger.debug("Deleted %d to-device messages up to %d",
+ deleted, since_stream_id)
messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
)
- logger.info(
+ logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)",
len(messages), since_stream_id, stream_id, now_token.to_device_key
)
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 2eb325c7c7..c7afd11111 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -218,7 +218,8 @@ class EmailPusher(object):
)
def seconds_until(self, ts_msec):
- return (ts_msec - self.clock.time_msec()) / 1000
+ secs = (ts_msec - self.clock.time_msec()) / 1000
+ return max(secs, 0)
def get_room_throttle_ms(self, room_id):
if room_id in self.throttle_params:
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 18076e0f3b..ab133db872 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -54,7 +54,9 @@ class BaseSlavedStore(SQLBaseStore):
try:
getattr(self, cache_func).invalidate(tuple(keys))
except AttributeError:
- logger.info("Got unexpected cache_func: %r", cache_func)
+ # We probably haven't pulled in the cache in this worker,
+ # which is fine.
+ pass
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 735c03c7eb..77c64722c7 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -46,6 +46,12 @@ class SlavedAccountDataStore(BaseSlavedStore):
)
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
+ get_tags_for_room = (
+ DataStore.get_tags_for_room.__func__
+ )
+ get_account_data_for_room = (
+ DataStore.get_account_data_for_room.__func__
+ )
get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = (
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index cc860f9f9b..f9102e0d89 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -17,6 +17,7 @@ from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.expiringcache import ExpiringCache
class SlavedDeviceInboxStore(BaseSlavedStore):
@@ -34,6 +35,13 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
self._device_inbox_id_gen.get_current_token()
)
+ self._last_device_delete_cache = ExpiringCache(
+ cache_name="last_device_delete_cache",
+ clock=self._clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ )
+
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d72ff6055c..622b2d8540 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -85,6 +85,12 @@ class SlavedEventStore(BaseSlavedStore):
get_unread_event_push_actions_by_room_for_user = (
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
)
+ _get_unread_counts_by_receipt_txn = (
+ DataStore._get_unread_counts_by_receipt_txn.__func__
+ )
+ _get_unread_counts_by_pos_txn = (
+ DataStore._get_unread_counts_by_pos_txn.__func__
+ )
_get_state_group_for_events = (
StateStore.__dict__["_get_state_group_for_events"]
)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 703f4a49bf..40f6c9a386 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage import DataStore
+from synapse.storage.presence import PresenceStore
class SlavedPresenceStore(BaseSlavedStore):
@@ -35,7 +36,8 @@ class SlavedPresenceStore(BaseSlavedStore):
_get_active_presence = DataStore._get_active_presence.__func__
take_presence_startup_info = DataStore.take_presence_startup_info.__func__
- get_presence_for_users = DataStore.get_presence_for_users.__func__
+ _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
+ get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index efa77b8c51..fceca2edeb 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -87,9 +87,17 @@ class HttpTransactionCache(object):
deferred = fn(*args, **kwargs)
- # We don't add an errback to the raw deferred, so we ask ObservableDeferred
- # to swallow the error. This is fine as the error will still be reported
- # to the observers.
+ # if the request fails with a Twisted failure, remove it
+ # from the transaction map. This is done to ensure that we don't
+ # cache transient errors like rate-limiting errors, etc.
+ def remove_from_map(err):
+ self.transactions.pop(txn_key, None)
+ return err
+ deferred.addErrback(remove_from_map)
+
+ # We don't add any other errbacks to the raw deferred, so we ask
+ # ObservableDeferred to swallow the error. This is fine as the error will
+ # still be reported to the observers.
observable = ObservableDeferred(deferred, consumeErrors=True)
self.transactions[txn_key] = (observable, self.clock.time_msec())
return observable.observe()
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index af21661d7c..29fcd72375 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID
+from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
@@ -25,6 +26,34 @@ import logging
logger = logging.getLogger(__name__)
+class UsersRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns("/admin/users/(?P<user_id>[^/]*)")
+
+ def __init__(self, hs):
+ super(UsersRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ # To allow all users to get the users list
+ # if not is_admin and target_user != auth_user:
+ # raise AuthError(403, "You are not a server admin")
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ ret = yield self.handlers.admin_handler.get_users()
+
+ defer.returnValue((200, ret))
+
+
class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
@@ -128,8 +157,199 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
+class ResetPasswordRestServlet(ClientV1RestServlet):
+ """Post request to allow an administrator reset password for a user.
+ This need a user have a administrator access in Synapse.
+ Example:
+ http://localhost:8008/_matrix/client/api/v1/admin/reset_password/
+ @user:to_reset_password?access_token=admin_access_token
+ JsonBodyToSend:
+ {
+ "new_password": "secret"
+ }
+ Returns:
+ 200 OK with empty object if success otherwise an error.
+ """
+ PATTERNS = client_path_patterns("/admin/reset_password/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ super(ResetPasswordRestServlet, self).__init__(hs)
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, target_user_id):
+ """Post request to allow an administrator reset password for a user.
+ This need a user have a administrator access in Synapse.
+ """
+ UserID.from_string(target_user_id)
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ params = parse_json_object_from_request(request)
+ new_password = params['new_password']
+ if not new_password:
+ raise SynapseError(400, "Missing 'new_password' arg")
+
+ logger.info("new_password: %r", new_password)
+
+ yield self.auth_handler.set_password(
+ target_user_id, new_password, requester
+ )
+ defer.returnValue((200, {}))
+
+
+class GetUsersPaginatedRestServlet(ClientV1RestServlet):
+ """Get request to get specific number of users from Synapse.
+ This need a user have a administrator access in Synapse.
+ Example:
+ http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
+ @admin:user?access_token=admin_access_token&start=0&limit=10
+ Returns:
+ 200 OK with json object {list[dict[str, Any]], count} or empty object.
+ """
+ PATTERNS = client_path_patterns("/admin/users_paginate/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ super(GetUsersPaginatedRestServlet, self).__init__(hs)
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, target_user_id):
+ """Get request to get specific number of users from Synapse.
+ This need a user have a administrator access in Synapse.
+ """
+ target_user = UserID.from_string(target_user_id)
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ # To allow all users to get the users list
+ # if not is_admin and target_user != auth_user:
+ # raise AuthError(403, "You are not a server admin")
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ order = "name" # order by name in user table
+ start = request.args.get("start")[0]
+ limit = request.args.get("limit")[0]
+ if not limit:
+ raise SynapseError(400, "Missing 'limit' arg")
+ if not start:
+ raise SynapseError(400, "Missing 'start' arg")
+ logger.info("limit: %s, start: %s", limit, start)
+
+ ret = yield self.handlers.admin_handler.get_users_paginate(
+ order, start, limit
+ )
+ defer.returnValue((200, ret))
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, target_user_id):
+ """Post request to get specific number of users from Synapse..
+ This need a user have a administrator access in Synapse.
+ Example:
+ http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
+ @admin:user?access_token=admin_access_token
+ JsonBodyToSend:
+ {
+ "start": "0",
+ "limit": "10
+ }
+ Returns:
+ 200 OK with json object {list[dict[str, Any]], count} or empty object.
+ """
+ UserID.from_string(target_user_id)
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ order = "name" # order by name in user table
+ params = parse_json_object_from_request(request)
+ limit = params['limit']
+ start = params['start']
+ if not limit:
+ raise SynapseError(400, "Missing 'limit' arg")
+ if not start:
+ raise SynapseError(400, "Missing 'start' arg")
+ logger.info("limit: %s, start: %s", limit, start)
+
+ ret = yield self.handlers.admin_handler.get_users_paginate(
+ order, start, limit
+ )
+ defer.returnValue((200, ret))
+
+
+class SearchUsersRestServlet(ClientV1RestServlet):
+ """Get request to search user table for specific users according to
+ search term.
+ This need a user have a administrator access in Synapse.
+ Example:
+ http://localhost:8008/_matrix/client/api/v1/admin/search_users/
+ @admin:user?access_token=admin_access_token&term=alice
+ Returns:
+ 200 OK with json object {list[dict[str, Any]], count} or empty object.
+ """
+ PATTERNS = client_path_patterns("/admin/search_users/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ super(SearchUsersRestServlet, self).__init__(hs)
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, target_user_id):
+ """Get request to search user table for specific users according to
+ search term.
+ This need a user have a administrator access in Synapse.
+ """
+ target_user = UserID.from_string(target_user_id)
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ # To allow all users to get the users list
+ # if not is_admin and target_user != auth_user:
+ # raise AuthError(403, "You are not a server admin")
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ term = request.args.get("term")[0]
+ if not term:
+ raise SynapseError(400, "Missing 'term' arg")
+
+ logger.info("term: %s ", term)
+
+ ret = yield self.handlers.admin_handler.search_users(
+ term
+ )
+ defer.returnValue((200, ret))
+
+
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
+ UsersRestServlet(hs).register(http_server)
+ ResetPasswordRestServlet(hs).register(http_server)
+ GetUsersPaginatedRestServlet(hs).register(http_server)
+ SearchUsersRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 355e82474b..1a5045c9ec 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -46,6 +46,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def on_PUT(self, request, user_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
+ is_admin = yield self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request)
@@ -55,7 +56,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname(
- user, requester, new_name)
+ user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -88,6 +89,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def on_PUT(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
+ is_admin = yield self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request)
try:
@@ -96,7 +98,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url(
- user, requester, new_name)
+ user, requester, new_name, is_admin)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 6554f57df1..90242a6bac 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -608,6 +608,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Missing user_id key.")
target = UserID.from_string(content["user_id"])
+ event_content = None
+ if 'reason' in content and membership_action in ['kick', 'ban']:
+ event_content = {'reason': content['reason']}
+
yield self.handlers.room_member_handler.update_membership(
requester=requester,
target=target,
@@ -615,6 +619,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
action=membership_action,
txn_id=txn_id,
third_party_signed=content.get("third_party_signed", None),
+ content=event_content,
)
defer.returnValue((200, {}))
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 3cbeca503c..481ffee200 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -240,6 +240,9 @@ class MediaRepository(object):
if t_method == "crop":
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
elif t_method == "scale":
+ t_width, t_height = thumbnailer.aspect(t_width, t_height)
+ t_width = min(m_width, t_width)
+ t_height = min(m_height, t_height)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
else:
t_len = None
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index b9968debe5..d604e7668f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -297,6 +297,82 @@ class DataStore(RoomMemberStore, RoomStore,
desc="get_user_ip_and_agents",
)
+ def get_users(self):
+ """Function to reterive a list of users in users table.
+
+ Args:
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self._simple_select_list(
+ table="users",
+ keyvalues={},
+ retcols=[
+ "name",
+ "password_hash",
+ "is_guest",
+ "admin"
+ ],
+ desc="get_users",
+ )
+
+ def get_users_paginate(self, order, start, limit):
+ """Function to reterive a paginated list of users from
+ users list. This will return a json object, which contains
+ list of users and the total number of users in users table.
+
+ Args:
+ order (str): column name to order the select by this column
+ start (int): start number to begin the query from
+ limit (int): number of rows to reterive
+ Returns:
+ defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+ """
+ is_guest = 0
+ i_start = (int)(start)
+ i_limit = (int)(limit)
+ return self.get_user_list_paginate(
+ table="users",
+ keyvalues={
+ "is_guest": is_guest
+ },
+ pagevalues=[
+ order,
+ i_limit,
+ i_start
+ ],
+ retcols=[
+ "name",
+ "password_hash",
+ "is_guest",
+ "admin"
+ ],
+ desc="get_users_paginate",
+ )
+
+ def search_users(self, term):
+ """Function to search users list for one or more users with
+ the matched term.
+
+ Args:
+ term (str): search term
+ col (str): column to query term should be matched to
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self._simple_search_list(
+ table="users",
+ term=term,
+ col="name",
+ retcols=[
+ "name",
+ "password_hash",
+ "is_guest",
+ "admin"
+ ],
+ desc="search_users",
+ )
+
def are_all_users_on_domain(txn, database_engine, domain):
sql = database_engine.convert_param_style(
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 05374682fd..a7a8ec9b7b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,7 +18,6 @@ from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
-from synapse.util.caches import intern_dict
from synapse.storage.engines import PostgresEngine
import synapse.metrics
@@ -80,7 +79,13 @@ class LoggingTransaction(object):
def executemany(self, sql, *args):
self._do_execute(self.txn.executemany, sql, *args)
+ def _make_sql_one_line(self, sql):
+ "Strip newlines out of SQL so that the loggers in the DB are on one line"
+ return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+
def _do_execute(self, func, sql, *args):
+ sql = self._make_sql_one_line(sql)
+
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
@@ -350,9 +355,9 @@ class SQLBaseStore(object):
Returns:
A list of dicts where the key is the column header.
"""
- col_headers = list(column[0] for column in cursor.description)
+ col_headers = list(intern(column[0]) for column in cursor.description)
results = list(
- intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall()
+ dict(zip(col_headers, row)) for row in cursor.fetchall()
)
return results
@@ -483,10 +488,6 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
sqlargs = values.values() + keyvalues.values()
- logger.debug(
- "[SQL] %s Args=%s",
- sql, sqlargs,
- )
txn.execute(sql, sqlargs)
if txn.rowcount == 0:
@@ -501,10 +502,6 @@ class SQLBaseStore(object):
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
- logger.debug(
- "[SQL] %s Args=%s",
- sql, keyvalues.values(),
- )
txn.execute(sql, allvalues.values())
return True
@@ -934,6 +931,165 @@ class SQLBaseStore(object):
else:
return 0
+ def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
+ desc="_simple_select_list_paginate"):
+ """Executes a SELECT query on the named table with start and limit,
+ of row numbers, which may return zero or number of rows from start to limit,
+ returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+ keyvalues (dict[str, Any] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ order (str): order the select by this column
+ start (int): start number to begin the query from
+ limit (int): number of rows to reterive
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.runInteraction(
+ desc,
+ self._simple_select_list_paginate_txn,
+ table, keyvalues, pagevalues, retcols
+ )
+
+ @classmethod
+ def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
+ """Executes a SELECT query on the named table with start and limit,
+ of row numbers, which may return zero or number of rows from start to limit,
+ returning the result as a list of dicts.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ pagevalues ([]):
+ order (str): order the select by this column
+ start (int): start number to begin the query from
+ limit (int): number of rows to reterive
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+
+ """
+ if keyvalues:
+ sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ " ? ASC LIMIT ? OFFSET ?"
+ )
+ txn.execute(sql, keyvalues.values() + pagevalues)
+ else:
+ sql = "SELECT %s FROM %s ORDER BY %s" % (
+ ", ".join(retcols),
+ table,
+ " ? ASC LIMIT ? OFFSET ?"
+ )
+ txn.execute(sql, pagevalues)
+
+ return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
+ desc="get_user_list_paginate"):
+ """Get a list of users from start row to a limit number of rows. This will
+ return a json object with users and total number of users in users list.
+
+ Args:
+ table (str): the table name
+ keyvalues (dict[str, Any] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ pagevalues ([]):
+ order (str): order the select by this column
+ start (int): start number to begin the query from
+ limit (int): number of rows to reterive
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+ """
+ users = yield self.runInteraction(
+ desc,
+ self._simple_select_list_paginate_txn,
+ table, keyvalues, pagevalues, retcols
+ )
+ count = yield self.runInteraction(
+ desc,
+ self.get_user_count_txn
+ )
+ retval = {
+ "users": users,
+ "total": count
+ }
+ defer.returnValue(retval)
+
+ def get_user_count_txn(self, txn):
+ """Get a total number of registerd users in the users list.
+
+ Args:
+ txn : Transaction object
+ Returns:
+ defer.Deferred: resolves to int
+ """
+ sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
+ txn.execute(sql_count)
+ count = txn.fetchone()[0]
+ defer.returnValue(count)
+
+ def _simple_search_list(self, table, term, col, retcols,
+ desc="_simple_search_list"):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+ term (str | None):
+ term for searching the table matched to a column.
+ col (str): column to query term should be matched to
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]] or None
+ """
+
+ return self.runInteraction(
+ desc,
+ self._simple_search_list_txn,
+ table, term, col, retcols
+ )
+
+ @classmethod
+ def _simple_search_list_txn(cls, txn, table, term, col, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+ term (str | None):
+ term for searching the table matched to a column.
+ col (str): column to query term should be matched to
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]] or None
+ """
+ if term:
+ sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
+ ", ".join(retcols),
+ table,
+ col
+ )
+ termvalues = ["%%" + term + "%%"]
+ txn.execute(sql, termvalues)
+ else:
+ return 0
+
+ return cls.cursor_to_dict(txn)
+
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index bde3b5cbbc..5c7db5e5f6 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -20,6 +20,8 @@ from twisted.internet import defer
from .background_updates import BackgroundUpdateStore
+from synapse.util.caches.expiringcache import ExpiringCache
+
logger = logging.getLogger(__name__)
@@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore):
self._background_drop_index_device_inbox,
)
+ # Map of (user_id, device_id) to the last stream_id that has been
+ # deleted up to. This is so that we can no op deletions.
+ self._last_device_delete_cache = ExpiringCache(
+ cache_name="last_device_delete_cache",
+ clock=self._clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ )
+
@defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
remote_messages_by_destination):
@@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
+ @defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
@@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore):
Returns:
A deferred that resolves to the number of messages deleted.
"""
+ # If we have cached the last stream id we've deleted up to, we can
+ # check if there is likely to be anything that needs deleting
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), None
+ )
+ if last_deleted_stream_id:
+ has_changed = self._device_inbox_stream_cache.has_entity_changed(
+ user_id, last_deleted_stream_id
+ )
+ if not has_changed:
+ defer.returnValue(0)
+
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
@@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
- return self.runInteraction(
+ count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
+ # Update the cache, ensuring that we only ever increase the value
+ last_deleted_stream_id = self._last_device_delete_cache.get(
+ (user_id, device_id), 0
+ )
+ self._last_device_delete_cache[(user_id, device_id)] = max(
+ last_deleted_stream_id, up_to_stream_id
+ )
+
+ defer.returnValue(count)
+
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 8e17800364..bd56ba2515 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -19,6 +19,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+
logger = logging.getLogger(__name__)
@@ -31,6 +33,13 @@ class DeviceStore(SQLBaseStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
+ self.register_background_index_update(
+ "device_lists_stream_idx",
+ index_name="device_lists_stream_user_id",
+ table="device_lists_stream",
+ columns=["user_id", "device_id"],
+ )
+
@defer.inlineCallbacks
def store_device(self, user_id, device_id,
initial_device_display_name):
@@ -144,6 +153,7 @@ class DeviceStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices})
+ @cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
@@ -156,16 +166,36 @@ class DeviceStore(SQLBaseStore):
allow_none=True,
)
+ @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+ list_name="user_ids", inlineCallbacks=True)
+ def get_device_list_last_stream_id_for_remotes(self, user_ids):
+ rows = yield self._simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id",),
+ desc="get_user_devices_from_cache",
+ )
+
+ results = {user_id: None for user_id in user_ids}
+ results.update({
+ row["user_id"]: row["stream_id"] for row in rows
+ })
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
- return self._simple_delete(
+ yield self._simple_delete(
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
desc="mark_remote_user_device_list_as_unsubscribed",
)
+ self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
@@ -191,6 +221,12 @@ class DeviceStore(SQLBaseStore):
}
)
+ txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
+ txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(
+ self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+ )
+
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
@@ -234,6 +270,12 @@ class DeviceStore(SQLBaseStore):
]
)
+ txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+ txn.call_after(
+ self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+ )
+
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
@@ -320,6 +362,7 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, results)
+ @defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
@@ -332,27 +375,11 @@ class DeviceStore(SQLBaseStore):
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
- return self.runInteraction(
- "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
- query_list,
+ user_ids = set(user_id for user_id, _ in query_list)
+ user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ user_ids_in_cache = set(
+ user_id for user_id, stream_id in user_map.items() if stream_id
)
-
- def _get_user_devices_from_cache_txn(self, txn, query_list):
- user_ids = {user_id for user_id, _ in query_list}
-
- user_ids_in_cache = set()
- for user_id in user_ids:
- stream_ids = self._simple_select_onecol_txn(
- txn,
- table="device_lists_remote_extremeties",
- keyvalues={
- "user_id": user_id,
- },
- retcol="stream_id",
- )
- if stream_ids:
- user_ids_in_cache.add(user_id)
-
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
@@ -361,32 +388,40 @@ class DeviceStore(SQLBaseStore):
continue
if device_id:
- content = self._simple_select_one_onecol_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- retcol="content",
- )
- results.setdefault(user_id, {})[device_id] = json.loads(content)
+ device = yield self._get_cached_user_device(user_id, device_id)
+ results.setdefault(user_id, {})[device_id] = device
else:
- devices = self._simple_select_list_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- },
- retcols=("device_id", "content"),
- )
- results[user_id] = {
- device["device_id"]: json.loads(device["content"])
- for device in devices
- }
- user_ids_in_cache.discard(user_id)
+ results[user_id] = yield self._get_cached_devices_for_user(user_id)
+
+ defer.returnValue((user_ids_not_in_cache, results))
+
+ @cachedInlineCallbacks(num_args=2, tree=True)
+ def _get_cached_user_device(self, user_id, device_id):
+ content = yield self._simple_select_one_onecol(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="content",
+ desc="_get_cached_user_device",
+ )
+ defer.returnValue(json.loads(content))
- return user_ids_not_in_cache, results
+ @cachedInlineCallbacks()
+ def _get_cached_devices_for_user(self, user_id):
+ devices = yield self._simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("device_id", "content"),
+ desc="_get_cached_devices_for_user",
+ )
+ defer.returnValue({
+ device["device_id"]: json.loads(device["content"])
+ for device in devices
+ })
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
@@ -473,7 +508,7 @@ class DeviceStore(SQLBaseStore):
defer.returnValue(set(changed))
sql = """
- SELECT user_id FROM device_lists_stream WHERE stream_id > ?
+ SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
@@ -489,7 +524,7 @@ class DeviceStore(SQLBaseStore):
WHERE stream_id > ?
"""
return self._execute(
- "get_users_and_hosts_device_list", None,
+ "get_all_device_list_changes_for_remotes", None,
sql, from_key,
)
@@ -518,6 +553,16 @@ class DeviceStore(SQLBaseStore):
host, stream_id,
)
+ # Delete older entries in the table, as we really only care about
+ # when the latest change happened.
+ txn.executemany(
+ """
+ DELETE FROM device_lists_stream
+ WHERE user_id = ? AND device_id = ? AND stream_id < ?
+ """,
+ [(user_id, device_id, stream_id) for device_id in device_ids]
+ )
+
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2040e022fa..b9f1365f92 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -93,7 +93,7 @@ class EndToEndKeyStore(SQLBaseStore):
query_clause = "user_id = ?"
query_params.append(user_id)
- if device_id:
+ if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index ee88c61954..256e50dc20 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -281,15 +281,30 @@ class EventFederationStore(SQLBaseStore):
)
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
+ """For a given room_id and stream_ordering, return the forward
+ extremeties of the room at that point in "time".
+
+ Throws a StoreError if we have since purged the index for
+ stream_orderings from that point.
+
+ Args:
+ room_id (str):
+ stream_ordering (int):
+
+ Returns:
+ deferred, which resolves to a list of event_ids
+ """
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
- # try and pin to a stream_ordering from before a restart
+ # stream_ordering from before a restart
last_change = max(self._stream_order_on_start, last_change)
+ # provided the last_change is recent enough, we now clamp the requested
+ # stream_ordering to it.
if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering)
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 7de3e8c58c..14543b4269 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -15,6 +15,7 @@
from ._base import SQLBaseStore
from twisted.internet import defer
+from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.types import RoomStreamToken
from .stream import lower_bound
@@ -25,11 +26,46 @@ import ujson as json
logger = logging.getLogger(__name__)
+DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
+DEFAULT_HIGHLIGHT_ACTION = [
+ "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
+]
+
+
+def _serialize_action(actions, is_highlight):
+ """Custom serializer for actions. This allows us to "compress" common actions.
+
+ We use the fact that most users have the same actions for notifs (and for
+ highlights).
+ We store these default actions as the empty string rather than the full JSON.
+ Since the empty string isn't valid JSON there is no risk of this clashing with
+ any real JSON actions
+ """
+ if is_highlight:
+ if actions == DEFAULT_HIGHLIGHT_ACTION:
+ return "" # We use empty string as the column is non-NULL
+ else:
+ if actions == DEFAULT_NOTIF_ACTION:
+ return ""
+ return json.dumps(actions)
+
+
+def _deserialize_action(actions, is_highlight):
+ """Custom deserializer for actions. This allows us to "compress" common actions
+ """
+ if actions:
+ return json.loads(actions)
+
+ if is_highlight:
+ return DEFAULT_HIGHLIGHT_ACTION
+ else:
+ return DEFAULT_NOTIF_ACTION
+
+
class EventPushActionsStore(SQLBaseStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, hs):
- self.stream_ordering_month_ago = None
super(EventPushActionsStore, self).__init__(hs)
self.register_background_index_update(
@@ -47,6 +83,11 @@ class EventPushActionsStore(SQLBaseStore):
where_clause="highlight=1"
)
+ self._doing_notif_rotation = False
+ self._rotate_notif_loop = self._clock.looping_call(
+ self._rotate_notifs, 30 * 60 * 1000
+ )
+
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
Args:
@@ -55,15 +96,17 @@ class EventPushActionsStore(SQLBaseStore):
"""
values = []
for uid, actions in tuples:
+ is_highlight = 1 if _action_has_highlight(actions) else 0
+
values.append({
'room_id': event.room_id,
'event_id': event.event_id,
'user_id': uid,
- 'actions': json.dumps(actions),
+ 'actions': _serialize_action(actions, is_highlight),
'stream_ordering': event.internal_metadata.stream_ordering,
'topological_ordering': event.depth,
'notif': 1,
- 'highlight': 1 if _action_has_highlight(actions) else 0,
+ 'highlight': is_highlight,
})
for uid, __ in tuples:
@@ -77,66 +120,83 @@ class EventPushActionsStore(SQLBaseStore):
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
- def _get_unread_event_push_actions_by_room(txn):
- sql = (
- "SELECT stream_ordering, topological_ordering"
- " FROM events"
- " WHERE room_id = ? AND event_id = ?"
- )
- txn.execute(
- sql, (room_id, last_read_event_id)
- )
- results = txn.fetchall()
- if len(results) == 0:
- return {"notify_count": 0, "highlight_count": 0}
-
- stream_ordering = results[0][0]
- topological_ordering = results[0][1]
- token = RoomStreamToken(
- topological_ordering, stream_ordering
- )
-
- # First get number of notifications.
- # We don't need to put a notif=1 clause as all rows always have
- # notif=1
- sql = (
- "SELECT count(*)"
- " FROM event_push_actions ea"
- " WHERE"
- " user_id = ?"
- " AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
-
- txn.execute(sql, (user_id, room_id))
- row = txn.fetchone()
- notify_count = row[0] if row else 0
+ ret = yield self.runInteraction(
+ "get_unread_event_push_actions_by_room",
+ self._get_unread_counts_by_receipt_txn,
+ room_id, user_id, last_read_event_id
+ )
+ defer.returnValue(ret)
- # Now get the number of highlights
- sql = (
- "SELECT count(*)"
- " FROM event_push_actions ea"
- " WHERE"
- " highlight = 1"
- " AND user_id = ?"
- " AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
+ def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
+ last_read_event_id):
+ sql = (
+ "SELECT stream_ordering, topological_ordering"
+ " FROM events"
+ " WHERE room_id = ? AND event_id = ?"
+ )
+ txn.execute(
+ sql, (room_id, last_read_event_id)
+ )
+ results = txn.fetchall()
+ if len(results) == 0:
+ return {"notify_count": 0, "highlight_count": 0}
- txn.execute(sql, (user_id, room_id))
- row = txn.fetchone()
- highlight_count = row[0] if row else 0
+ stream_ordering = results[0][0]
+ topological_ordering = results[0][1]
- return {
- "notify_count": notify_count,
- "highlight_count": highlight_count,
- }
+ return self._get_unread_counts_by_pos_txn(
+ txn, room_id, user_id, topological_ordering, stream_ordering
+ )
- ret = yield self.runInteraction(
- "get_unread_event_push_actions_by_room",
- _get_unread_event_push_actions_by_room
+ def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering,
+ stream_ordering):
+ token = RoomStreamToken(
+ topological_ordering, stream_ordering
)
- defer.returnValue(ret)
+
+ # First get number of notifications.
+ # We don't need to put a notif=1 clause as all rows always have
+ # notif=1
+ sql = (
+ "SELECT count(*)"
+ " FROM event_push_actions ea"
+ " WHERE"
+ " user_id = ?"
+ " AND room_id = ?"
+ " AND %s"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ txn.execute(sql, (user_id, room_id))
+ row = txn.fetchone()
+ notify_count = row[0] if row else 0
+
+ txn.execute("""
+ SELECT notif_count FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+ """, (room_id, user_id, stream_ordering,))
+ rows = txn.fetchall()
+ if rows:
+ notify_count += rows[0][0]
+
+ # Now get the number of highlights
+ sql = (
+ "SELECT count(*)"
+ " FROM event_push_actions ea"
+ " WHERE"
+ " highlight = 1"
+ " AND user_id = ?"
+ " AND room_id = ?"
+ " AND %s"
+ ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+ txn.execute(sql, (user_id, room_id))
+ row = txn.fetchone()
+ highlight_count = row[0] if row else 0
+
+ return {
+ "notify_count": notify_count,
+ "highlight_count": highlight_count,
+ }
@defer.inlineCallbacks
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
@@ -176,7 +236,8 @@ class EventPushActionsStore(SQLBaseStore):
# find rooms that have a read receipt in them and return the next
# push actions
sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
+ "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+ " ep.highlight "
" FROM ("
" SELECT room_id,"
" MAX(topological_ordering) as topological_ordering,"
@@ -217,7 +278,7 @@ class EventPushActionsStore(SQLBaseStore):
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " e.received_ts"
+ " ep.highlight "
" FROM event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
@@ -246,7 +307,7 @@ class EventPushActionsStore(SQLBaseStore):
"event_id": row[0],
"room_id": row[1],
"stream_ordering": row[2],
- "actions": json.loads(row[3]),
+ "actions": _deserialize_action(row[3], row[4]),
} for row in after_read_receipt + no_read_receipt
]
@@ -285,7 +346,7 @@ class EventPushActionsStore(SQLBaseStore):
def get_after_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " e.received_ts"
+ " ep.highlight, e.received_ts"
" FROM ("
" SELECT room_id,"
" MAX(topological_ordering) as topological_ordering,"
@@ -327,7 +388,7 @@ class EventPushActionsStore(SQLBaseStore):
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " e.received_ts"
+ " ep.highlight, e.received_ts"
" FROM event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
@@ -357,8 +418,8 @@ class EventPushActionsStore(SQLBaseStore):
"event_id": row[0],
"room_id": row[1],
"stream_ordering": row[2],
- "actions": json.loads(row[3]),
- "received_ts": row[4],
+ "actions": _deserialize_action(row[3], row[4]),
+ "received_ts": row[5],
} for row in after_read_receipt + no_read_receipt
]
@@ -392,7 +453,7 @@ class EventPushActionsStore(SQLBaseStore):
sql = (
"SELECT epa.event_id, epa.room_id,"
" epa.stream_ordering, epa.topological_ordering,"
- " epa.actions, epa.profile_tag, e.received_ts"
+ " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
" FROM event_push_actions epa, events e"
" WHERE epa.event_id = e.event_id"
" AND epa.user_id = ? %s"
@@ -407,7 +468,7 @@ class EventPushActionsStore(SQLBaseStore):
"get_push_actions_for_user", f
)
for pa in push_actions:
- pa["actions"] = json.loads(pa["actions"])
+ pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
defer.returnValue(push_actions)
@defer.inlineCallbacks
@@ -448,10 +509,14 @@ class EventPushActionsStore(SQLBaseStore):
)
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
- topological_ordering):
+ topological_ordering, stream_ordering):
"""
- Purges old, stale push actions for a user and room before a given
- topological_ordering
+ Purges old push actions for a user and room before a given
+ topological_ordering.
+
+ We however keep a months worth of highlighted notifications, so that
+ users can still get a list of recent highlights.
+
Args:
txn: The transcation
room_id: Room ID to delete from
@@ -475,10 +540,16 @@ class EventPushActionsStore(SQLBaseStore):
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
- " topological_ordering < ? AND stream_ordering < ?",
+ " topological_ordering <= ?"
+ " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
(user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
)
+ txn.execute("""
+ DELETE FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+ """, (room_id, user_id, stream_ordering))
+
@defer.inlineCallbacks
def _find_stream_orderings_for_times(self):
yield self.runInteraction(
@@ -495,6 +566,14 @@ class EventPushActionsStore(SQLBaseStore):
"Found stream ordering 1 month ago: it's %d",
self.stream_ordering_month_ago
)
+ logger.info("Searching for stream ordering 1 day ago")
+ self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
+ txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
+ )
+ logger.info(
+ "Found stream ordering 1 day ago: it's %d",
+ self.stream_ordering_day_ago
+ )
def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
"""
@@ -534,6 +613,131 @@ class EventPushActionsStore(SQLBaseStore):
return range_end
+ @defer.inlineCallbacks
+ def _rotate_notifs(self):
+ if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
+ return
+ self._doing_notif_rotation = True
+
+ try:
+ while True:
+ logger.info("Rotating notifications")
+
+ caught_up = yield self.runInteraction(
+ "_rotate_notifs",
+ self._rotate_notifs_txn
+ )
+ if caught_up:
+ break
+ yield sleep(5)
+ finally:
+ self._doing_notif_rotation = False
+
+ def _rotate_notifs_txn(self, txn):
+ """Archives older notifications into event_push_summary. Returns whether
+ the archiving process has caught up or not.
+ """
+
+ old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
+ )
+
+ # We don't to try and rotate millions of rows at once, so we cap the
+ # maximum stream ordering we'll rotate before.
+ txn.execute("""
+ SELECT stream_ordering FROM event_push_actions
+ WHERE stream_ordering > ?
+ ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
+ """, (old_rotate_stream_ordering,))
+ stream_row = txn.fetchone()
+ if stream_row:
+ offset_stream_ordering, = stream_row
+ rotate_to_stream_ordering = min(
+ self.stream_ordering_day_ago, offset_stream_ordering
+ )
+ caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
+ else:
+ rotate_to_stream_ordering = self.stream_ordering_day_ago
+ caught_up = True
+
+ logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
+
+ self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+
+ # We have caught up iff we were limited by `stream_ordering_day_ago`
+ return caught_up
+
+ def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
+ old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
+ )
+
+ # Calculate the new counts that should be upserted into event_push_summary
+ sql = """
+ SELECT user_id, room_id,
+ coalesce(old.notif_count, 0) + upd.notif_count,
+ upd.stream_ordering,
+ old.user_id
+ FROM (
+ SELECT user_id, room_id, count(*) as notif_count,
+ max(stream_ordering) as stream_ordering
+ FROM event_push_actions
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND highlight = 0
+ GROUP BY user_id, room_id
+ ) AS upd
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+ """
+
+ txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,))
+ rows = txn.fetchall()
+
+ logger.info("Rotating notifications, handling %d rows", len(rows))
+
+ # If the `old.user_id` above is NULL then we know there isn't already an
+ # entry in the table, so we simply insert it. Otherwise we update the
+ # existing table.
+ self._simple_insert_many_txn(
+ txn,
+ table="event_push_summary",
+ values=[
+ {
+ "user_id": row[0],
+ "room_id": row[1],
+ "notif_count": row[2],
+ "stream_ordering": row[3],
+ }
+ for row in rows if row[4] is None
+ ]
+ )
+
+ txn.executemany(
+ """
+ UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+ WHERE user_id = ? AND room_id = ?
+ """,
+ ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None)
+ )
+
+ txn.execute(
+ "DELETE FROM event_push_actions"
+ " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
+ (old_rotate_stream_ordering, rotate_to_stream_ordering,)
+ )
+
+ logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
+
+ txn.execute(
+ "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
+ (rotate_to_stream_ordering,)
+ )
+
def _action_has_highlight(actions):
for action in actions:
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index c88f689d3a..db01eb6d14 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -311,6 +311,21 @@ class EventsStore(SQLBaseStore):
new_forward_extremeties[room_id] = new_latest_event_ids
+ len_1 = (
+ len(latest_event_ids) == 1
+ and len(new_latest_event_ids) == 1
+ )
+ if len_1:
+ all_single_prev_not_state = all(
+ len(event.prev_events) == 1
+ and not event.is_state()
+ for event, ctx in ev_ctx_rm
+ )
+ # Don't bother calculating state if they're just
+ # a long chain of single ancestor non-state events.
+ if all_single_prev_not_state:
+ continue
+
state = yield self._calculate_state_delta(
room_id, ev_ctx_rm, new_latest_event_ids
)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index b357f22be7..ed84db6b4b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 40
+SCHEMA_VERSION = 41
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 7460f98a1f..4d1590d2b4 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -15,7 +15,7 @@
from ._base import SQLBaseStore
from synapse.api.constants import PresenceState
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from collections import namedtuple
from twisted.internet import defer
@@ -85,6 +85,9 @@ class PresenceStore(SQLBaseStore):
self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id,
)
+ self._invalidate_cache_and_stream(
+ txn, self._get_presence_for_user, (state.user_id,)
+ )
# Actually insert new rows
self._simple_insert_many_txn(
@@ -143,7 +146,12 @@ class PresenceStore(SQLBaseStore):
"get_all_presence_updates", get_all_presence_updates_txn
)
- @defer.inlineCallbacks
+ @cached()
+ def _get_presence_for_user(self, user_id):
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids",
+ num_args=1, inlineCallbacks=True)
def get_presence_for_users(self, user_ids):
rows = yield self._simple_select_many_batch(
table="presence_stream",
@@ -165,7 +173,7 @@ class PresenceStore(SQLBaseStore):
for row in rows:
row["currently_active"] = bool(row["currently_active"])
- defer.returnValue([UserPresenceState(**row) for row in rows])
+ defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows})
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index f72d15f5ed..5cf41501ea 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -351,6 +351,7 @@ class ReceiptsStore(SQLBaseStore):
room_id=room_id,
user_id=user_id,
topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
)
return True
diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/schema/delta/40/event_push_summary.sql
new file mode 100644
index 0000000000..3918f0b794
--- /dev/null
+++ b/synapse/storage/schema/delta/40/event_push_summary.sql
@@ -0,0 +1,37 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Aggregate of old notification counts that have been deleted out of the
+-- main event_push_actions table. This count does not include those that were
+-- highlights, as they remain in the event_push_actions table.
+CREATE TABLE event_push_summary (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ notif_count BIGINT NOT NULL,
+ stream_ordering BIGINT NOT NULL
+);
+
+CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id);
+
+
+-- The stream ordering up to which we have aggregated the event_push_actions
+-- table into event_push_summary
+CREATE TABLE event_push_summary_stream_ordering (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_ordering BIGINT NOT NULL,
+ CHECK (Lock='X')
+);
+
+INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/schema/delta/40/pushers.sql
new file mode 100644
index 0000000000..054a223f14
--- /dev/null
+++ b/synapse/storage/schema/delta/40/pushers.sql
@@ -0,0 +1,39 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS pushers2 (
+ id BIGINT PRIMARY KEY,
+ user_name TEXT NOT NULL,
+ access_token BIGINT DEFAULT NULL,
+ profile_tag TEXT NOT NULL,
+ kind TEXT NOT NULL,
+ app_id TEXT NOT NULL,
+ app_display_name TEXT NOT NULL,
+ device_display_name TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ lang TEXT,
+ data TEXT,
+ last_stream_ordering INTEGER,
+ last_success BIGINT,
+ failing_since BIGINT,
+ UNIQUE (app_id, pushkey, user_name)
+);
+
+INSERT INTO pushers2 SELECT * FROM PUSHERS;
+
+DROP TABLE PUSHERS;
+
+ALTER TABLE pushers2 RENAME TO pushers;
diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/schema/delta/41/device_list_stream_idx.sql
new file mode 100644
index 0000000000..b7bee8b692
--- /dev/null
+++ b/synapse/storage/schema/delta/41/device_list_stream_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('device_lists_stream_idx', '{}');
diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/schema/delta/41/device_outbound_index.sql
new file mode 100644
index 0000000000..62f0b9892b
--- /dev/null
+++ b/synapse/storage/schema/delta/41/device_outbound_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 1b3800eb6a..84482d8285 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -413,7 +413,19 @@ class StateStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
- def get_state_ids_for_events(self, event_ids, types):
+ def get_state_ids_for_events(self, event_ids, types=None):
+ """
+ Get the state dicts corresponding to a list of events
+
+ Args:
+ event_ids(list(str)): events whose state should be returned
+ types(list[(str, str)]|None): List of (type, state_key) tuples
+ which are used to filter the state fetched. May be None, which
+ matches any key
+
+ Returns:
+ A deferred dict from event_id -> (type, state_key) -> state_event
+ """
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2987c38a2d..cbdde34a57 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -100,6 +100,13 @@ class ExpiringCache(object):
except KeyError:
return default
+ def setdefault(self, key, value):
+ try:
+ return self[key]
+ except KeyError:
+ self[key] = value
+ return value
+
def _prune_cache(self):
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
|