diff options
Diffstat (limited to 'synapse')
40 files changed, 1321 insertions, 224 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index ff251ce597..7628e7c505 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.19.2" +__version__ = "0.19.3" diff --git a/synapse/api/constants.py b/synapse/api/constants.py index a8123cddcb..ca23c9c460 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -43,9 +43,6 @@ class JoinRules(object): class LoginType(object): PASSWORD = u"m.login.password" - OAUTH = u"m.login.oauth2" - EMAIL_CODE = u"m.login.email.code" - EMAIL_URL = u"m.login.email.url" EMAIL_IDENTITY = u"m.login.email.identity" RECAPTCHA = u"m.login.recaptcha" DUMMY = u"m.login.dummy" diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index b3fb408cfd..3f29595256 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -87,6 +87,10 @@ class SynchrotronSlavedStore( RoomMemberStore.__dict__["who_forgot_in_room"] ) + did_forget = ( + RoomMemberStore.__dict__["did_forget"] + ) + # XXX: This is a bit broken because we don't persist the accepted list in a # way that can be replicated. This means that we don't have a way to # invalidate the cache correctly. diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 3c58d2de17..e081840a83 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -95,7 +95,7 @@ class TlsConfig(Config): # make HTTPS requests to this server will check that the TLS # certificates returned by this server match one of the fingerprints. # - # Synapse automatically adds its the fingerprint of its own certificate + # Synapse automatically adds the fingerprint of its own certificate # to the list. So if federation traffic is handle directly by synapse # then no modification to the list is required. # diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b5bcfd705a..5dcd4eecce 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -206,8 +206,7 @@ class FederationClient(FederationBase): Args: destinations (list): Which home servers to query - pdu_origin (str): The home server that originally sent the pdu. - event_id (str) + event_id (str): event to fetch outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if it's from an arbitary point in the context as opposed to part of the current block of PDUs. Defaults to `False` diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index bb3d9258a6..90235ff098 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -303,18 +303,10 @@ class TransactionQueue(object): try: self.pending_transactions[destination] = 1 + # XXX: what's this for? yield run_on_reactor() while True: - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_presence = self.pending_presence_by_dest.pop(destination, {}) - pending_failures = self.pending_failures_by_dest.pop(destination, []) - - pending_edus.extend( - self.pending_edus_keyed_by_dest.pop(destination, {}).values() - ) - limiter = yield get_retry_limiter( destination, self.clock, @@ -326,6 +318,24 @@ class TransactionQueue(object): yield self._get_new_device_messages(destination) ) + # BEGIN CRITICAL SECTION + # + # In order to avoid a race condition, we need to make sure that + # the following code (from popping the queues up to the point + # where we decide if we actually have any pending messages) is + # atomic - otherwise new PDUs or EDUs might arrive in the + # meantime, but not get sent because we hold the + # pending_transactions flag. + + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_presence = self.pending_presence_by_dest.pop(destination, {}) + pending_failures = self.pending_failures_by_dest.pop(destination, []) + + pending_edus.extend( + self.pending_edus_keyed_by_dest.pop(destination, {}).values() + ) + pending_edus.extend(device_message_edus) if pending_presence: pending_edus.append( @@ -355,6 +365,8 @@ class TransactionQueue(object): ) return + # END CRITICAL SECTION + success = yield self._send_new_transaction( destination, pending_pdus, pending_edus, pending_failures, limiter=limiter, diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 084e33ca6a..f36b358b45 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -19,7 +19,6 @@ from ._base import BaseHandler import logging - logger = logging.getLogger(__name__) @@ -54,3 +53,46 @@ class AdminHandler(BaseHandler): } defer.returnValue(ret) + + @defer.inlineCallbacks + def get_users(self): + """Function to reterive a list of users in users table. + + Args: + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + ret = yield self.store.get_users() + + defer.returnValue(ret) + + @defer.inlineCallbacks + def get_users_paginate(self, order, start, limit): + """Function to reterive a paginated list of users from + users list. This will return a json object, which contains + list of users and the total number of users in users table. + + Args: + order (str): column name to order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + ret = yield self.store.get_users_paginate(order, start, limit) + + defer.returnValue(ret) + + @defer.inlineCallbacks + def search_users(self, term): + """Function to search users list for one or more users with + the matched term. + + Args: + term (str): search term + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + ret = yield self.store.search_users(term) + + defer.returnValue(ret) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 8cb47ac417..e859b3165f 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from synapse.api import errors from synapse.api.constants import EventTypes from synapse.util import stringutils from synapse.util.async import Linearizer +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.types import get_domain_from_id, RoomStreamToken from twisted.internet import defer @@ -35,10 +35,11 @@ class DeviceHandler(BaseHandler): self.state = hs.get_state_handler() self.federation_sender = hs.get_federation_sender() self.federation = hs.get_replication_layer() - self._remote_edue_linearizer = Linearizer(name="remote_device_list") + + self._edu_updater = DeviceListEduUpdater(hs, self) self.federation.register_edu_handler( - "m.device_list_update", self._incoming_device_list_update, + "m.device_list_update", self._edu_updater.incoming_device_list_update, ) self.federation.register_query_handler( "user_devices", self.on_federation_query_user_devices, @@ -246,30 +247,51 @@ class DeviceHandler(BaseHandler): # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) + stream_ordering = RoomStreamToken.parse_stream_token( + from_token.room_key).stream + possibly_changed = set(changed) for room_id in rooms_changed: - # Fetch the current state at the time. - stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key) - + # Fetch the current state at the time. try: event_ids = yield self.store.get_forward_extremeties_for_room( room_id, stream_ordering=stream_ordering ) - prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) - except: - prev_state_ids = {} + except errors.StoreError: + # we have purged the stream_ordering index since the stream + # ordering: treat it the same as a new room + event_ids = [] current_state_ids = yield self.state.get_current_state_ids(room_id) + # special-case for an empty prev state: include all members + # in the changed list + if not event_ids: + for key, event_id in current_state_ids.iteritems(): + etype, state_key = key + if etype != EventTypes.Member: + continue + possibly_changed.add(state_key) + continue + + # mapping from event_id -> state_dict + prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + # If there has been any change in membership, include them in the # possibly changed list. We'll check if they are joined below, # and we're not toooo worried about spuriously adding users. for key, event_id in current_state_ids.iteritems(): etype, state_key = key - if etype == EventTypes.Member: - prev_event_id = prev_state_ids.get(key, None) + if etype != EventTypes.Member: + continue + + # check if this member has changed since any of the extremities + # at the stream_ordering, and add them to the list if so. + for state_dict in prev_state_ids.values(): + prev_event_id = state_dict.get(key, None) if not prev_event_id or prev_event_id != event_id: possibly_changed.add(state_key) + break users_who_share_room = yield self.store.get_users_who_share_room_with_user( user_id @@ -279,13 +301,69 @@ class DeviceHandler(BaseHandler): # and those that actually still share a room with the user defer.returnValue(users_who_share_room & possibly_changed) - @measure_func("_incoming_device_list_update") @defer.inlineCallbacks - def _incoming_device_list_update(self, origin, edu_content): - user_id = edu_content["user_id"] - device_id = edu_content["device_id"] - stream_id = edu_content["stream_id"] - prev_ids = edu_content.get("prev_id", []) + def on_federation_query_user_devices(self, user_id): + stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) + defer.returnValue({ + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + }) + + @defer.inlineCallbacks + def user_left_room(self, user, room_id): + user_id = user.to_string() + rooms = yield self.store.get_rooms_for_user(user_id) + if not rooms: + # We no longer share rooms with this user, so we'll no longer + # receive device updates. Mark this in DB. + yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + + +def _update_device_from_client_ips(device, client_ips): + ip = client_ips.get((device["user_id"], device["device_id"]), {}) + device.update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) + + +class DeviceListEduUpdater(object): + "Handles incoming device list updates from federation and updates the DB" + + def __init__(self, hs, device_handler): + self.store = hs.get_datastore() + self.federation = hs.get_replication_layer() + self.clock = hs.get_clock() + self.device_handler = device_handler + + self._remote_edu_linearizer = Linearizer(name="remote_device_list") + + # user_id -> list of updates waiting to be handled. + self._pending_updates = {} + + # Recently seen stream ids. We don't bother keeping these in the DB, + # but they're useful to have them about to reduce the number of spurious + # resyncs. + self._seen_updates = ExpiringCache( + cache_name="device_update_edu", + clock=self.clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + iterable=True, + ) + + @defer.inlineCallbacks + def incoming_device_list_update(self, origin, edu_content): + """Called on incoming device list update from federation. Responsible + for parsing the EDU and adding to pending updates list. + """ + + user_id = edu_content.pop("user_id") + device_id = edu_content.pop("device_id") + stream_id = str(edu_content.pop("stream_id")) # They may come as ints + prev_ids = edu_content.pop("prev_id", []) + prev_ids = [str(p) for p in prev_ids] # They may come as ints if get_domain_from_id(user_id) != origin: # TODO: Raise? @@ -298,20 +376,28 @@ class DeviceHandler(BaseHandler): # probably won't get any further updates. return - with (yield self._remote_edue_linearizer.queue(user_id)): - # If the prev id matches whats in our cache table, then we don't need - # to resync the users device list, otherwise we do. - resync = True - if len(prev_ids) == 1: - extremity = yield self.store.get_device_list_last_stream_id_for_remote( - user_id - ) - logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids) - if str(extremity) == str(prev_ids[0]): - resync = False + self._pending_updates.setdefault(user_id, []).append( + (device_id, stream_id, prev_ids, edu_content) + ) + + yield self._handle_device_updates(user_id) + + @measure_func("_incoming_device_list_update") + @defer.inlineCallbacks + def _handle_device_updates(self, user_id): + "Actually handle pending updates." + + with (yield self._remote_edu_linearizer.queue(user_id)): + pending_updates = self._pending_updates.pop(user_id, []) + if not pending_updates: + # This can happen since we batch updates + return + + resync = yield self._need_to_do_resync(user_id, pending_updates) if resync: # Fetch all devices for the user. + origin = get_domain_from_id(user_id) result = yield self.federation.query_user_devices(origin, user_id) stream_id = result["stream_id"] devices = result["devices"] @@ -319,40 +405,50 @@ class DeviceHandler(BaseHandler): user_id, devices, stream_id, ) device_ids = [device["device_id"] for device in devices] - yield self.notify_device_update(user_id, device_ids) + yield self.device_handler.notify_device_update(user_id, device_ids) else: # Simply update the single device, since we know that is the only # change (becuase of the single prev_id matching the current cache) - content = dict(edu_content) - for key in ("user_id", "device_id", "stream_id", "prev_ids"): - content.pop(key, None) - yield self.store.update_remote_device_list_cache_entry( - user_id, device_id, content, stream_id, + for device_id, stream_id, prev_ids, content in pending_updates: + yield self.store.update_remote_device_list_cache_entry( + user_id, device_id, content, stream_id, + ) + + yield self.device_handler.notify_device_update( + user_id, [device_id for device_id, _, _, _ in pending_updates] ) - yield self.notify_device_update(user_id, [device_id]) - @defer.inlineCallbacks - def on_federation_query_user_devices(self, user_id): - stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - defer.returnValue({ - "user_id": user_id, - "stream_id": stream_id, - "devices": devices, - }) + self._seen_updates.setdefault(user_id, set()).update( + stream_id for _, stream_id, _, _ in pending_updates + ) @defer.inlineCallbacks - def user_left_room(self, user, room_id): - user_id = user.to_string() - rooms = yield self.store.get_rooms_for_user(user_id) - if not rooms: - # We no longer share rooms with this user, so we'll no longer - # receive device updates. Mark this in DB. - yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + def _need_to_do_resync(self, user_id, updates): + """Given a list of updates for a user figure out if we need to do a full + resync, or whether we have enough data that we can just apply the delta. + """ + seen_updates = self._seen_updates.get(user_id, set()) + extremity = yield self.store.get_device_list_last_stream_id_for_remote( + user_id + ) -def _update_device_from_client_ips(device, client_ips): - ip = client_ips.get((device["user_id"], device["device_id"]), {}) - device.update({ - "last_seen_ts": ip.get("last_seen"), - "last_seen_ip": ip.get("ip"), - }) + stream_id_in_updates = set() # stream_ids in updates list + for _, stream_id, prev_ids, _ in updates: + if not prev_ids: + # We always do a resync if there are no previous IDs + defer.returnValue(True) + + for prev_id in prev_ids: + if prev_id == extremity: + continue + elif prev_id in seen_updates: + continue + elif prev_id in stream_id_in_updates: + continue + else: + defer.returnValue(True) + + stream_id_in_updates.add(stream_id) + + defer.returnValue(False) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 996bfd0e23..ed0fa51e7f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1096,7 +1096,7 @@ class FederationHandler(BaseHandler): if prev_id != event.event_id: results[(event.type, event.state_key)] = prev_id else: - del results[(event.type, event.state_key)] + results.pop((event.type, event.state_key), None) defer.returnValue(results.values()) else: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fdfce2a88c..da610e430f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -531,7 +531,7 @@ class PresenceHandler(object): # There are things not in our in memory cache. Lets pull them out of # the database. res = yield self.store.get_presence_for_users(missing) - states.update({state.user_id: state for state in res}) + states.update(res) missing = [user_id for user_id, state in states.items() if not state] if missing: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b2806555cf..2052d6d05f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -719,7 +719,9 @@ class RoomMemberHandler(BaseHandler): ) membership = member.membership if member else None - if membership is not None and membership != Membership.LEAVE: + if membership is not None and membership not in [ + Membership.LEAVE, Membership.BAN + ]: raise SynapseError(400, "User %s in room %s" % ( user_id, room_id )) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d7dcd1ce5b..5572cb883f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -609,14 +609,14 @@ class SyncHandler(object): deleted = yield self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) - logger.info("Deleted %d to-device messages up to %d", - deleted, since_stream_id) + logger.debug("Deleted %d to-device messages up to %d", + deleted, since_stream_id) messages, stream_id = yield self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key ) - logger.info( + logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", len(messages), since_stream_id, stream_id, now_token.to_device_key ) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 2eb325c7c7..c7afd11111 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -218,7 +218,8 @@ class EmailPusher(object): ) def seconds_until(self, ts_msec): - return (ts_msec - self.clock.time_msec()) / 1000 + secs = (ts_msec - self.clock.time_msec()) / 1000 + return max(secs, 0) def get_room_throttle_ms(self, room_id): if room_id in self.throttle_params: diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 18076e0f3b..ab133db872 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -54,7 +54,9 @@ class BaseSlavedStore(SQLBaseStore): try: getattr(self, cache_func).invalidate(tuple(keys)) except AttributeError: - logger.info("Got unexpected cache_func: %r", cache_func) + # We probably haven't pulled in the cache in this worker, + # which is fine. + pass self._cache_id_gen.advance(int(stream["position"])) return defer.succeed(None) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 735c03c7eb..77c64722c7 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -46,6 +46,12 @@ class SlavedAccountDataStore(BaseSlavedStore): ) get_tags_for_user = TagsStore.__dict__["get_tags_for_user"] + get_tags_for_room = ( + DataStore.get_tags_for_room.__func__ + ) + get_account_data_for_room = ( + DataStore.get_account_data_for_room.__func__ + ) get_updated_tags = DataStore.get_updated_tags.__func__ get_updated_account_data_for_user = ( diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index cc860f9f9b..f9102e0d89 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -17,6 +17,7 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.caches.expiringcache import ExpiringCache class SlavedDeviceInboxStore(BaseSlavedStore): @@ -34,6 +35,13 @@ class SlavedDeviceInboxStore(BaseSlavedStore): self._device_inbox_id_gen.get_current_token() ) + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__ diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d72ff6055c..622b2d8540 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -85,6 +85,12 @@ class SlavedEventStore(BaseSlavedStore): get_unread_event_push_actions_by_room_for_user = ( EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"] ) + _get_unread_counts_by_receipt_txn = ( + DataStore._get_unread_counts_by_receipt_txn.__func__ + ) + _get_unread_counts_by_pos_txn = ( + DataStore._get_unread_counts_by_pos_txn.__func__ + ) _get_state_group_for_events = ( StateStore.__dict__["_get_state_group_for_events"] ) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 703f4a49bf..40f6c9a386 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.storage import DataStore +from synapse.storage.presence import PresenceStore class SlavedPresenceStore(BaseSlavedStore): @@ -35,7 +36,8 @@ class SlavedPresenceStore(BaseSlavedStore): _get_active_presence = DataStore._get_active_presence.__func__ take_presence_startup_info = DataStore.take_presence_startup_info.__func__ - get_presence_for_users = DataStore.get_presence_for_users.__func__ + _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] + get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] def get_current_presence_token(self): return self._presence_id_gen.get_current_token() diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index efa77b8c51..fceca2edeb 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -87,9 +87,17 @@ class HttpTransactionCache(object): deferred = fn(*args, **kwargs) - # We don't add an errback to the raw deferred, so we ask ObservableDeferred - # to swallow the error. This is fine as the error will still be reported - # to the observers. + # if the request fails with a Twisted failure, remove it + # from the transaction map. This is done to ensure that we don't + # cache transient errors like rate-limiting errors, etc. + def remove_from_map(err): + self.transactions.pop(txn_key, None) + return err + deferred.addErrback(remove_from_map) + + # We don't add any other errbacks to the raw deferred, so we ask + # ObservableDeferred to swallow the error. This is fine as the error will + # still be reported to the observers. observable = ObservableDeferred(deferred, consumeErrors=True) self.transactions[txn_key] = (observable, self.clock.time_msec()) return observable.observe() diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index af21661d7c..29fcd72375 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID +from synapse.http.servlet import parse_json_object_from_request from .base import ClientV1RestServlet, client_path_patterns @@ -25,6 +26,34 @@ import logging logger = logging.getLogger(__name__) +class UsersRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/users/(?P<user_id>[^/]*)") + + def __init__(self, hs): + super(UsersRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + target_user = UserID.from_string(user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # To allow all users to get the users list + # if not is_admin and target_user != auth_user: + # raise AuthError(403, "You are not a server admin") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only users a local user") + + ret = yield self.handlers.admin_handler.get_users() + + defer.returnValue((200, ret)) + + class WhoisRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") @@ -128,8 +157,199 @@ class DeactivateAccountRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) +class ResetPasswordRestServlet(ClientV1RestServlet): + """Post request to allow an administrator reset password for a user. + This need a user have a administrator access in Synapse. + Example: + http://localhost:8008/_matrix/client/api/v1/admin/reset_password/ + @user:to_reset_password?access_token=admin_access_token + JsonBodyToSend: + { + "new_password": "secret" + } + Returns: + 200 OK with empty object if success otherwise an error. + """ + PATTERNS = client_path_patterns("/admin/reset_password/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(ResetPasswordRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + @defer.inlineCallbacks + def on_POST(self, request, target_user_id): + """Post request to allow an administrator reset password for a user. + This need a user have a administrator access in Synapse. + """ + UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + params = parse_json_object_from_request(request) + new_password = params['new_password'] + if not new_password: + raise SynapseError(400, "Missing 'new_password' arg") + + logger.info("new_password: %r", new_password) + + yield self.auth_handler.set_password( + target_user_id, new_password, requester + ) + defer.returnValue((200, {})) + + +class GetUsersPaginatedRestServlet(ClientV1RestServlet): + """Get request to get specific number of users from Synapse. + This need a user have a administrator access in Synapse. + Example: + http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ + @admin:user?access_token=admin_access_token&start=0&limit=10 + Returns: + 200 OK with json object {list[dict[str, Any]], count} or empty object. + """ + PATTERNS = client_path_patterns("/admin/users_paginate/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(GetUsersPaginatedRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_GET(self, request, target_user_id): + """Get request to get specific number of users from Synapse. + This need a user have a administrator access in Synapse. + """ + target_user = UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # To allow all users to get the users list + # if not is_admin and target_user != auth_user: + # raise AuthError(403, "You are not a server admin") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only users a local user") + + order = "name" # order by name in user table + start = request.args.get("start")[0] + limit = request.args.get("limit")[0] + if not limit: + raise SynapseError(400, "Missing 'limit' arg") + if not start: + raise SynapseError(400, "Missing 'start' arg") + logger.info("limit: %s, start: %s", limit, start) + + ret = yield self.handlers.admin_handler.get_users_paginate( + order, start, limit + ) + defer.returnValue((200, ret)) + + @defer.inlineCallbacks + def on_POST(self, request, target_user_id): + """Post request to get specific number of users from Synapse.. + This need a user have a administrator access in Synapse. + Example: + http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/ + @admin:user?access_token=admin_access_token + JsonBodyToSend: + { + "start": "0", + "limit": "10 + } + Returns: + 200 OK with json object {list[dict[str, Any]], count} or empty object. + """ + UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + order = "name" # order by name in user table + params = parse_json_object_from_request(request) + limit = params['limit'] + start = params['start'] + if not limit: + raise SynapseError(400, "Missing 'limit' arg") + if not start: + raise SynapseError(400, "Missing 'start' arg") + logger.info("limit: %s, start: %s", limit, start) + + ret = yield self.handlers.admin_handler.get_users_paginate( + order, start, limit + ) + defer.returnValue((200, ret)) + + +class SearchUsersRestServlet(ClientV1RestServlet): + """Get request to search user table for specific users according to + search term. + This need a user have a administrator access in Synapse. + Example: + http://localhost:8008/_matrix/client/api/v1/admin/search_users/ + @admin:user?access_token=admin_access_token&term=alice + Returns: + 200 OK with json object {list[dict[str, Any]], count} or empty object. + """ + PATTERNS = client_path_patterns("/admin/search_users/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(SearchUsersRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_GET(self, request, target_user_id): + """Get request to search user table for specific users according to + search term. + This need a user have a administrator access in Synapse. + """ + target_user = UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # To allow all users to get the users list + # if not is_admin and target_user != auth_user: + # raise AuthError(403, "You are not a server admin") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only users a local user") + + term = request.args.get("term")[0] + if not term: + raise SynapseError(400, "Missing 'term' arg") + + logger.info("term: %s ", term) + + ret = yield self.handlers.admin_handler.search_users( + term + ) + defer.returnValue((200, ret)) + + def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) + UsersRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) + GetUsersPaginatedRestServlet(hs).register(http_server) + SearchUsersRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 355e82474b..1a5045c9ec 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -46,6 +46,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): def on_PUT(self, request, user_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) + is_admin = yield self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) @@ -55,7 +56,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_displayname( - user, requester, new_name) + user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -88,6 +89,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): def on_PUT(self, request, user_id): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) + is_admin = yield self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) try: @@ -96,7 +98,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_avatar_url( - user, requester, new_name) + user, requester, new_name, is_admin) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6554f57df1..90242a6bac 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -608,6 +608,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet): raise SynapseError(400, "Missing user_id key.") target = UserID.from_string(content["user_id"]) + event_content = None + if 'reason' in content and membership_action in ['kick', 'ban']: + event_content = {'reason': content['reason']} + yield self.handlers.room_member_handler.update_membership( requester=requester, target=target, @@ -615,6 +619,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): action=membership_action, txn_id=txn_id, third_party_signed=content.get("third_party_signed", None), + content=event_content, ) defer.returnValue((200, {})) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 3cbeca503c..481ffee200 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -240,6 +240,9 @@ class MediaRepository(object): if t_method == "crop": t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) elif t_method == "scale": + t_width, t_height = thumbnailer.aspect(t_width, t_height) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) else: t_len = None diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index b9968debe5..d604e7668f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -297,6 +297,82 @@ class DataStore(RoomMemberStore, RoomStore, desc="get_user_ip_and_agents", ) + def get_users(self): + """Function to reterive a list of users in users table. + + Args: + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_select_list( + table="users", + keyvalues={}, + retcols=[ + "name", + "password_hash", + "is_guest", + "admin" + ], + desc="get_users", + ) + + def get_users_paginate(self, order, start, limit): + """Function to reterive a paginated list of users from + users list. This will return a json object, which contains + list of users and the total number of users in users table. + + Args: + order (str): column name to order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + is_guest = 0 + i_start = (int)(start) + i_limit = (int)(limit) + return self.get_user_list_paginate( + table="users", + keyvalues={ + "is_guest": is_guest + }, + pagevalues=[ + order, + i_limit, + i_start + ], + retcols=[ + "name", + "password_hash", + "is_guest", + "admin" + ], + desc="get_users_paginate", + ) + + def search_users(self, term): + """Function to search users list for one or more users with + the matched term. + + Args: + term (str): search term + col (str): column to query term should be matched to + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_search_list( + table="users", + term=term, + col="name", + retcols=[ + "name", + "password_hash", + "is_guest", + "admin" + ], + desc="search_users", + ) + def are_all_users_on_domain(txn, database_engine, domain): sql = database_engine.convert_param_style( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 05374682fd..a7a8ec9b7b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -18,7 +18,6 @@ from synapse.api.errors import StoreError from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache -from synapse.util.caches import intern_dict from synapse.storage.engines import PostgresEngine import synapse.metrics @@ -80,7 +79,13 @@ class LoggingTransaction(object): def executemany(self, sql, *args): self._do_execute(self.txn.executemany, sql, *args) + def _make_sql_one_line(self, sql): + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + # TODO(paul): Maybe use 'info' and 'debug' for values? sql_logger.debug("[SQL] {%s} %s", self.name, sql) @@ -350,9 +355,9 @@ class SQLBaseStore(object): Returns: A list of dicts where the key is the column header. """ - col_headers = list(column[0] for column in cursor.description) + col_headers = list(intern(column[0]) for column in cursor.description) results = list( - intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall() + dict(zip(col_headers, row)) for row in cursor.fetchall() ) return results @@ -483,10 +488,6 @@ class SQLBaseStore(object): " AND ".join("%s = ?" % (k,) for k in keyvalues) ) sqlargs = values.values() + keyvalues.values() - logger.debug( - "[SQL] %s Args=%s", - sql, sqlargs, - ) txn.execute(sql, sqlargs) if txn.rowcount == 0: @@ -501,10 +502,6 @@ class SQLBaseStore(object): ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues) ) - logger.debug( - "[SQL] %s Args=%s", - sql, keyvalues.values(), - ) txn.execute(sql, allvalues.values()) return True @@ -934,6 +931,165 @@ class SQLBaseStore(object): else: return 0 + def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols, + desc="_simple_select_list_paginate"): + """Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + order (str): order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, + self._simple_select_list_paginate_txn, + table, keyvalues, pagevalues, retcols + ) + + @classmethod + def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols): + """Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + pagevalues ([]): + order (str): order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + + """ + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + " ? ASC LIMIT ? OFFSET ?" + ) + txn.execute(sql, keyvalues.values() + pagevalues) + else: + sql = "SELECT %s FROM %s ORDER BY %s" % ( + ", ".join(retcols), + table, + " ? ASC LIMIT ? OFFSET ?" + ) + txn.execute(sql, pagevalues) + + return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols, + desc="get_user_list_paginate"): + """Get a list of users from start row to a limit number of rows. This will + return a json object with users and total number of users in users list. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + pagevalues ([]): + order (str): order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + users = yield self.runInteraction( + desc, + self._simple_select_list_paginate_txn, + table, keyvalues, pagevalues, retcols + ) + count = yield self.runInteraction( + desc, + self.get_user_count_txn + ) + retval = { + "users": users, + "total": count + } + defer.returnValue(retval) + + def get_user_count_txn(self, txn): + """Get a total number of registerd users in the users list. + + Args: + txn : Transaction object + Returns: + defer.Deferred: resolves to int + """ + sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" + txn.execute(sql_count) + count = txn.fetchone()[0] + defer.returnValue(count) + + def _simple_search_list(self, table, term, col, retcols, + desc="_simple_search_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + + return self.runInteraction( + desc, + self._simple_search_list_txn, + table, term, col, retcols + ) + + @classmethod + def _simple_search_list_txn(cls, txn, table, term, col, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + if term: + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % ( + ", ".join(retcols), + table, + col + ) + termvalues = ["%%" + term + "%%"] + txn.execute(sql, termvalues) + else: + return 0 + + return cls.cursor_to_dict(txn) + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index bde3b5cbbc..5c7db5e5f6 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -20,6 +20,8 @@ from twisted.internet import defer from .background_updates import BackgroundUpdateStore +from synapse.util.caches.expiringcache import ExpiringCache + logger = logging.getLogger(__name__) @@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore): self._background_drop_index_device_inbox, ) + # Map of (user_id, device_id) to the last stream_id that has been + # deleted up to. This is so that we can no op deletions. + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + @defer.inlineCallbacks def add_messages_to_device_inbox(self, local_messages_by_user_then_device, remote_messages_by_destination): @@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore): "get_new_messages_for_device", get_new_messages_for_device_txn, ) + @defer.inlineCallbacks def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): """ Args: @@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore): Returns: A deferred that resolves to the number of messages deleted. """ + # If we have cached the last stream id we've deleted up to, we can + # check if there is likely to be anything that needs deleting + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), None + ) + if last_deleted_stream_id: + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_deleted_stream_id + ) + if not has_changed: + defer.returnValue(0) + def delete_messages_for_device_txn(txn): sql = ( "DELETE FROM device_inbox" @@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - return self.runInteraction( + count = yield self.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) + # Update the cache, ensuring that we only ever increase the value + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), 0 + ) + self._last_device_delete_cache[(user_id, device_id)] = max( + last_deleted_stream_id, up_to_stream_id + ) + + defer.returnValue(count) + def get_all_new_device_messages(self, last_pos, current_pos, limit): """ Args: diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 8e17800364..bd56ba2515 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -19,6 +19,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks + logger = logging.getLogger(__name__) @@ -31,6 +33,13 @@ class DeviceStore(SQLBaseStore): self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) + self.register_background_index_update( + "device_lists_stream_idx", + index_name="device_lists_stream_user_id", + table="device_lists_stream", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def store_device(self, user_id, device_id, initial_device_display_name): @@ -144,6 +153,7 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) + @cached(max_entries=10000) def get_device_list_last_stream_id_for_remote(self, user_id): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. @@ -156,16 +166,36 @@ class DeviceStore(SQLBaseStore): allow_none=True, ) + @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", inlineCallbacks=True) + def get_device_list_last_stream_id_for_remotes(self, user_ids): + rows = yield self._simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id",), + desc="get_user_devices_from_cache", + ) + + results = {user_id: None for user_id in user_ids} + results.update({ + row["user_id"]: row["stream_id"] for row in rows + }) + + defer.returnValue(results) + + @defer.inlineCallbacks def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - return self._simple_delete( + yield self._simple_delete( table="device_lists_remote_extremeties", keyvalues={ "user_id": user_id, }, desc="mark_remote_user_device_list_as_unsubscribed", ) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) def update_remote_device_list_cache_entry(self, user_id, device_id, content, stream_id): @@ -191,6 +221,12 @@ class DeviceStore(SQLBaseStore): } ) + txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", @@ -234,6 +270,12 @@ class DeviceStore(SQLBaseStore): ] ) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", @@ -320,6 +362,7 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, results) + @defer.inlineCallbacks def get_user_devices_from_cache(self, query_list): """Get the devices (and keys if any) for remote users from the cache. @@ -332,27 +375,11 @@ class DeviceStore(SQLBaseStore): a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info """ - return self.runInteraction( - "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, - query_list, + user_ids = set(user_id for user_id, _ in query_list) + user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_ids_in_cache = set( + user_id for user_id, stream_id in user_map.items() if stream_id ) - - def _get_user_devices_from_cache_txn(self, txn, query_list): - user_ids = {user_id for user_id, _ in query_list} - - user_ids_in_cache = set() - for user_id in user_ids: - stream_ids = self._simple_select_onecol_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={ - "user_id": user_id, - }, - retcol="stream_id", - ) - if stream_ids: - user_ids_in_cache.add(user_id) - user_ids_not_in_cache = user_ids - user_ids_in_cache results = {} @@ -361,32 +388,40 @@ class DeviceStore(SQLBaseStore): continue if device_id: - content = self._simple_select_one_onecol_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - retcol="content", - ) - results.setdefault(user_id, {})[device_id] = json.loads(content) + device = yield self._get_cached_user_device(user_id, device_id) + results.setdefault(user_id, {})[device_id] = device else: - devices = self._simple_select_list_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - }, - retcols=("device_id", "content"), - ) - results[user_id] = { - device["device_id"]: json.loads(device["content"]) - for device in devices - } - user_ids_in_cache.discard(user_id) + results[user_id] = yield self._get_cached_devices_for_user(user_id) + + defer.returnValue((user_ids_not_in_cache, results)) + + @cachedInlineCallbacks(num_args=2, tree=True) + def _get_cached_user_device(self, user_id, device_id): + content = yield self._simple_select_one_onecol( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + desc="_get_cached_user_device", + ) + defer.returnValue(json.loads(content)) - return user_ids_not_in_cache, results + @cachedInlineCallbacks() + def _get_cached_devices_for_user(self, user_id): + devices = yield self._simple_select_list( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + desc="_get_cached_devices_for_user", + ) + defer.returnValue({ + device["device_id"]: json.loads(device["content"]) + for device in devices + }) def get_devices_with_keys_by_user(self, user_id): """Get all devices (with any device keys) for a user @@ -473,7 +508,7 @@ class DeviceStore(SQLBaseStore): defer.returnValue(set(changed)) sql = """ - SELECT user_id FROM device_lists_stream WHERE stream_id > ? + SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? """ rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) defer.returnValue(set(row[0] for row in rows)) @@ -489,7 +524,7 @@ class DeviceStore(SQLBaseStore): WHERE stream_id > ? """ return self._execute( - "get_users_and_hosts_device_list", None, + "get_all_device_list_changes_for_remotes", None, sql, from_key, ) @@ -518,6 +553,16 @@ class DeviceStore(SQLBaseStore): host, stream_id, ) + # Delete older entries in the table, as we really only care about + # when the latest change happened. + txn.executemany( + """ + DELETE FROM device_lists_stream + WHERE user_id = ? AND device_id = ? AND stream_id < ? + """, + [(user_id, device_id, stream_id) for device_id in device_ids] + ) + self._simple_insert_many_txn( txn, table="device_lists_stream", diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2040e022fa..b9f1365f92 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -93,7 +93,7 @@ class EndToEndKeyStore(SQLBaseStore): query_clause = "user_id = ?" query_params.append(user_id) - if device_id: + if device_id is not None: query_clause += " AND device_id = ?" query_params.append(device_id) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index ee88c61954..256e50dc20 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -281,15 +281,30 @@ class EventFederationStore(SQLBaseStore): ) def get_forward_extremeties_for_room(self, room_id, stream_ordering): + """For a given room_id and stream_ordering, return the forward + extremeties of the room at that point in "time". + + Throws a StoreError if we have since purged the index for + stream_orderings from that point. + + Args: + room_id (str): + stream_ordering (int): + + Returns: + deferred, which resolves to a list of event_ids + """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # We don't always have a full stream_to_exterm_id table, e.g. after # the upgrade that introduced it, so we make sure we never ask for a - # try and pin to a stream_ordering from before a restart + # stream_ordering from before a restart last_change = max(self._stream_order_on_start, last_change) + # provided the last_change is recent enough, we now clamp the requested + # stream_ordering to it. if last_change > self.stream_ordering_month_ago: stream_ordering = min(last_change, stream_ordering) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 7de3e8c58c..14543b4269 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -15,6 +15,7 @@ from ._base import SQLBaseStore from twisted.internet import defer +from synapse.util.async import sleep from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.types import RoomStreamToken from .stream import lower_bound @@ -25,11 +26,46 @@ import ujson as json logger = logging.getLogger(__name__) +DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] +DEFAULT_HIGHLIGHT_ACTION = [ + "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"} +] + + +def _serialize_action(actions, is_highlight): + """Custom serializer for actions. This allows us to "compress" common actions. + + We use the fact that most users have the same actions for notifs (and for + highlights). + We store these default actions as the empty string rather than the full JSON. + Since the empty string isn't valid JSON there is no risk of this clashing with + any real JSON actions + """ + if is_highlight: + if actions == DEFAULT_HIGHLIGHT_ACTION: + return "" # We use empty string as the column is non-NULL + else: + if actions == DEFAULT_NOTIF_ACTION: + return "" + return json.dumps(actions) + + +def _deserialize_action(actions, is_highlight): + """Custom deserializer for actions. This allows us to "compress" common actions + """ + if actions: + return json.loads(actions) + + if is_highlight: + return DEFAULT_HIGHLIGHT_ACTION + else: + return DEFAULT_NOTIF_ACTION + + class EventPushActionsStore(SQLBaseStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" def __init__(self, hs): - self.stream_ordering_month_ago = None super(EventPushActionsStore, self).__init__(hs) self.register_background_index_update( @@ -47,6 +83,11 @@ class EventPushActionsStore(SQLBaseStore): where_clause="highlight=1" ) + self._doing_notif_rotation = False + self._rotate_notif_loop = self._clock.looping_call( + self._rotate_notifs, 30 * 60 * 1000 + ) + def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ Args: @@ -55,15 +96,17 @@ class EventPushActionsStore(SQLBaseStore): """ values = [] for uid, actions in tuples: + is_highlight = 1 if _action_has_highlight(actions) else 0 + values.append({ 'room_id': event.room_id, 'event_id': event.event_id, 'user_id': uid, - 'actions': json.dumps(actions), + 'actions': _serialize_action(actions, is_highlight), 'stream_ordering': event.internal_metadata.stream_ordering, 'topological_ordering': event.depth, 'notif': 1, - 'highlight': 1 if _action_has_highlight(actions) else 0, + 'highlight': is_highlight, }) for uid, __ in tuples: @@ -77,66 +120,83 @@ class EventPushActionsStore(SQLBaseStore): def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - def _get_unread_event_push_actions_by_room(txn): - sql = ( - "SELECT stream_ordering, topological_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" - ) - txn.execute( - sql, (room_id, last_read_event_id) - ) - results = txn.fetchall() - if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} - - stream_ordering = results[0][0] - topological_ordering = results[0][1] - token = RoomStreamToken( - topological_ordering, stream_ordering - ) - - # First get number of notifications. - # We don't need to put a notif=1 clause as all rows always have - # notif=1 - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " user_id = ?" - " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) - - txn.execute(sql, (user_id, room_id)) - row = txn.fetchone() - notify_count = row[0] if row else 0 + ret = yield self.runInteraction( + "get_unread_event_push_actions_by_room", + self._get_unread_counts_by_receipt_txn, + room_id, user_id, last_read_event_id + ) + defer.returnValue(ret) - # Now get the number of highlights - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " highlight = 1" - " AND user_id = ?" - " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) + def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id, + last_read_event_id): + sql = ( + "SELECT stream_ordering, topological_ordering" + " FROM events" + " WHERE room_id = ? AND event_id = ?" + ) + txn.execute( + sql, (room_id, last_read_event_id) + ) + results = txn.fetchall() + if len(results) == 0: + return {"notify_count": 0, "highlight_count": 0} - txn.execute(sql, (user_id, room_id)) - row = txn.fetchone() - highlight_count = row[0] if row else 0 + stream_ordering = results[0][0] + topological_ordering = results[0][1] - return { - "notify_count": notify_count, - "highlight_count": highlight_count, - } + return self._get_unread_counts_by_pos_txn( + txn, room_id, user_id, topological_ordering, stream_ordering + ) - ret = yield self.runInteraction( - "get_unread_event_push_actions_by_room", - _get_unread_event_push_actions_by_room + def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering, + stream_ordering): + token = RoomStreamToken( + topological_ordering, stream_ordering ) - defer.returnValue(ret) + + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + notify_count = row[0] if row else 0 + + txn.execute(""" + SELECT notif_count FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + """, (room_id, user_id, stream_ordering,)) + rows = txn.fetchall() + if rows: + notify_count += rows[0][0] + + # Now get the number of highlights + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " highlight = 1" + " AND user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + highlight_count = row[0] if row else 0 + + return { + "notify_count": notify_count, + "highlight_count": highlight_count, + } @defer.inlineCallbacks def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): @@ -176,7 +236,8 @@ class EventPushActionsStore(SQLBaseStore): # find rooms that have a read receipt in them and return the next # push actions sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions" + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight " " FROM (" " SELECT room_id," " MAX(topological_ordering) as topological_ordering," @@ -217,7 +278,7 @@ class EventPushActionsStore(SQLBaseStore): def get_no_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight " " FROM event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE" @@ -246,7 +307,7 @@ class EventPushActionsStore(SQLBaseStore): "event_id": row[0], "room_id": row[1], "stream_ordering": row[2], - "actions": json.loads(row[3]), + "actions": _deserialize_action(row[3], row[4]), } for row in after_read_receipt + no_read_receipt ] @@ -285,7 +346,7 @@ class EventPushActionsStore(SQLBaseStore): def get_after_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight, e.received_ts" " FROM (" " SELECT room_id," " MAX(topological_ordering) as topological_ordering," @@ -327,7 +388,7 @@ class EventPushActionsStore(SQLBaseStore): def get_no_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight, e.received_ts" " FROM event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE" @@ -357,8 +418,8 @@ class EventPushActionsStore(SQLBaseStore): "event_id": row[0], "room_id": row[1], "stream_ordering": row[2], - "actions": json.loads(row[3]), - "received_ts": row[4], + "actions": _deserialize_action(row[3], row[4]), + "received_ts": row[5], } for row in after_read_receipt + no_read_receipt ] @@ -392,7 +453,7 @@ class EventPushActionsStore(SQLBaseStore): sql = ( "SELECT epa.event_id, epa.room_id," " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.profile_tag, e.received_ts" + " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" " FROM event_push_actions epa, events e" " WHERE epa.event_id = e.event_id" " AND epa.user_id = ? %s" @@ -407,7 +468,7 @@ class EventPushActionsStore(SQLBaseStore): "get_push_actions_for_user", f ) for pa in push_actions: - pa["actions"] = json.loads(pa["actions"]) + pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) defer.returnValue(push_actions) @defer.inlineCallbacks @@ -448,10 +509,14 @@ class EventPushActionsStore(SQLBaseStore): ) def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, - topological_ordering): + topological_ordering, stream_ordering): """ - Purges old, stale push actions for a user and room before a given - topological_ordering + Purges old push actions for a user and room before a given + topological_ordering. + + We however keep a months worth of highlighted notifications, so that + users can still get a list of recent highlights. + Args: txn: The transcation room_id: Room ID to delete from @@ -475,10 +540,16 @@ class EventPushActionsStore(SQLBaseStore): txn.execute( "DELETE FROM event_push_actions " " WHERE user_id = ? AND room_id = ? AND " - " topological_ordering < ? AND stream_ordering < ?", + " topological_ordering <= ?" + " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", (user_id, room_id, topological_ordering, self.stream_ordering_month_ago) ) + txn.execute(""" + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, (room_id, user_id, stream_ordering)) + @defer.inlineCallbacks def _find_stream_orderings_for_times(self): yield self.runInteraction( @@ -495,6 +566,14 @@ class EventPushActionsStore(SQLBaseStore): "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago ) + logger.info("Searching for stream ordering 1 day ago") + self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 day ago: it's %d", + self.stream_ordering_day_ago + ) def _find_first_stream_ordering_after_ts_txn(self, txn, ts): """ @@ -534,6 +613,131 @@ class EventPushActionsStore(SQLBaseStore): return range_end + @defer.inlineCallbacks + def _rotate_notifs(self): + if self._doing_notif_rotation or self.stream_ordering_day_ago is None: + return + self._doing_notif_rotation = True + + try: + while True: + logger.info("Rotating notifications") + + caught_up = yield self.runInteraction( + "_rotate_notifs", + self._rotate_notifs_txn + ) + if caught_up: + break + yield sleep(5) + finally: + self._doing_notif_rotation = False + + def _rotate_notifs_txn(self, txn): + """Archives older notifications into event_push_summary. Returns whether + the archiving process has caught up or not. + """ + + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # We don't to try and rotate millions of rows at once, so we cap the + # maximum stream ordering we'll rotate before. + txn.execute(""" + SELECT stream_ordering FROM event_push_actions + WHERE stream_ordering > ? + ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000 + """, (old_rotate_stream_ordering,)) + stream_row = txn.fetchone() + if stream_row: + offset_stream_ordering, = stream_row + rotate_to_stream_ordering = min( + self.stream_ordering_day_ago, offset_stream_ordering + ) + caught_up = offset_stream_ordering >= self.stream_ordering_day_ago + else: + rotate_to_stream_ordering = self.stream_ordering_day_ago + caught_up = True + + logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) + + self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) + + # We have caught up iff we were limited by `stream_ordering_day_ago` + return caught_up + + def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # Calculate the new counts that should be upserted into event_push_summary + sql = """ + SELECT user_id, room_id, + coalesce(old.notif_count, 0) + upd.notif_count, + upd.stream_ordering, + old.user_id + FROM ( + SELECT user_id, room_id, count(*) as notif_count, + max(stream_ordering) as stream_ordering + FROM event_push_actions + WHERE ? <= stream_ordering AND stream_ordering < ? + AND highlight = 0 + GROUP BY user_id, room_id + ) AS upd + LEFT JOIN event_push_summary AS old USING (user_id, room_id) + """ + + txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,)) + rows = txn.fetchall() + + logger.info("Rotating notifications, handling %d rows", len(rows)) + + # If the `old.user_id` above is NULL then we know there isn't already an + # entry in the table, so we simply insert it. Otherwise we update the + # existing table. + self._simple_insert_many_txn( + txn, + table="event_push_summary", + values=[ + { + "user_id": row[0], + "room_id": row[1], + "notif_count": row[2], + "stream_ordering": row[3], + } + for row in rows if row[4] is None + ] + ) + + txn.executemany( + """ + UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + WHERE user_id = ? AND room_id = ? + """, + ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None) + ) + + txn.execute( + "DELETE FROM event_push_actions" + " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", + (old_rotate_stream_ordering, rotate_to_stream_ordering,) + ) + + logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) + + txn.execute( + "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", + (rotate_to_stream_ordering,) + ) + def _action_has_highlight(actions): for action in actions: diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c88f689d3a..db01eb6d14 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -311,6 +311,21 @@ class EventsStore(SQLBaseStore): new_forward_extremeties[room_id] = new_latest_event_ids + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_events) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + state = yield self._calculate_state_delta( room_id, ev_ctx_rm, new_latest_event_ids ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b357f22be7..ed84db6b4b 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 40 +SCHEMA_VERSION = 41 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 7460f98a1f..4d1590d2b4 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -15,7 +15,7 @@ from ._base import SQLBaseStore from synapse.api.constants import PresenceState -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from collections import namedtuple from twisted.internet import defer @@ -85,6 +85,9 @@ class PresenceStore(SQLBaseStore): self.presence_stream_cache.entity_has_changed, state.user_id, stream_id, ) + self._invalidate_cache_and_stream( + txn, self._get_presence_for_user, (state.user_id,) + ) # Actually insert new rows self._simple_insert_many_txn( @@ -143,7 +146,12 @@ class PresenceStore(SQLBaseStore): "get_all_presence_updates", get_all_presence_updates_txn ) - @defer.inlineCallbacks + @cached() + def _get_presence_for_user(self, user_id): + raise NotImplementedError() + + @cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids", + num_args=1, inlineCallbacks=True) def get_presence_for_users(self, user_ids): rows = yield self._simple_select_many_batch( table="presence_stream", @@ -165,7 +173,7 @@ class PresenceStore(SQLBaseStore): for row in rows: row["currently_active"] = bool(row["currently_active"]) - defer.returnValue([UserPresenceState(**row) for row in rows]) + defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows}) def get_current_presence_token(self): return self._presence_id_gen.get_current_token() diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index f72d15f5ed..5cf41501ea 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -351,6 +351,7 @@ class ReceiptsStore(SQLBaseStore): room_id=room_id, user_id=user_id, topological_ordering=topological_ordering, + stream_ordering=stream_ordering, ) return True diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/schema/delta/40/event_push_summary.sql new file mode 100644 index 0000000000..3918f0b794 --- /dev/null +++ b/synapse/storage/schema/delta/40/event_push_summary.sql @@ -0,0 +1,37 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Aggregate of old notification counts that have been deleted out of the +-- main event_push_actions table. This count does not include those that were +-- highlights, as they remain in the event_push_actions table. +CREATE TABLE event_push_summary ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL +); + +CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); + + +-- The stream ordering up to which we have aggregated the event_push_actions +-- table into event_push_summary +CREATE TABLE event_push_summary_stream_ordering ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_ordering BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/schema/delta/40/pushers.sql new file mode 100644 index 0000000000..054a223f14 --- /dev/null +++ b/synapse/storage/schema/delta/40/pushers.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag TEXT NOT NULL, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang TEXT, + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) +); + +INSERT INTO pushers2 SELECT * FROM PUSHERS; + +DROP TABLE PUSHERS; + +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/schema/delta/41/device_list_stream_idx.sql new file mode 100644 index 0000000000..b7bee8b692 --- /dev/null +++ b/synapse/storage/schema/delta/41/device_list_stream_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_stream_idx', '{}'); diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/schema/delta/41/device_outbound_index.sql new file mode 100644 index 0000000000..62f0b9892b --- /dev/null +++ b/synapse/storage/schema/delta/41/device_outbound_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 1b3800eb6a..84482d8285 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -413,7 +413,19 @@ class StateStore(SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, types): + def get_state_ids_for_events(self, event_ids, types=None): + """ + Get the state dicts corresponding to a list of events + + Args: + event_ids(list(str)): events whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from event_id -> (type, state_key) -> state_event + """ event_to_groups = yield self._get_state_group_for_events( event_ids, ) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 2987c38a2d..cbdde34a57 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -100,6 +100,13 @@ class ExpiringCache(object): except KeyError: return default + def setdefault(self, key, value): + try: + return self[key] + except KeyError: + self[key] = value + return value + def _prune_cache(self): if not self._expiry_ms: # zero expiry time means don't expire. This should never get called |