diff options
-rw-r--r-- | synapse/api/auth.py | 41 | ||||
-rw-r--r-- | synapse/api/constants.py | 5 | ||||
-rw-r--r-- | synapse/config/captcha.py | 4 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 6 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 227 | ||||
-rw-r--r-- | synapse/handlers/room.py | 96 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 13 | ||||
-rw-r--r-- | synapse/rest/media/v1/thumbnailer.py | 2 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 108 | ||||
-rw-r--r-- | synapse/storage/events.py | 405 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 44 | ||||
-rw-r--r-- | synapse/storage/signatures.py | 28 | ||||
-rw-r--r-- | synapse/storage/state.py | 36 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 31 |
14 files changed, 628 insertions, 418 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1a25bf1086..487be7ce9c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -187,6 +187,9 @@ class Auth(object): join_rule = JoinRules.INVITE user_level = self._get_user_power_level(event.user_id, auth_events) + target_level = self._get_user_power_level( + target_user_id, auth_events + ) # FIXME (erikj): What should we do here as the default? ban_level = self._get_named_level(auth_events, "ban", 50) @@ -258,12 +261,12 @@ class Auth(object): elif target_user_id != event.user_id: kick_level = self._get_named_level(auth_events, "kick", 50) - if user_level < kick_level: + if user_level < kick_level or user_level <= target_level: raise AuthError( 403, "You cannot kick user %s." % target_user_id ) elif Membership.BAN == membership: - if user_level < ban_level: + if user_level < ban_level or user_level <= target_level: raise AuthError(403, "You don't have permission to ban") else: raise AuthError(500, "Unknown membership %s" % membership) @@ -573,26 +576,26 @@ class Auth(object): # Check other levels: levels_to_check = [ - ("users_default", []), - ("events_default", []), - ("state_default", []), - ("ban", []), - ("redact", []), - ("kick", []), - ("invite", []), + ("users_default", None), + ("events_default", None), + ("state_default", None), + ("ban", None), + ("redact", None), + ("kick", None), + ("invite", None), ] old_list = current_state.content.get("users") for user in set(old_list.keys() + user_list.keys()): levels_to_check.append( - (user, ["users"]) + (user, "users") ) old_list = current_state.content.get("events") new_list = event.content.get("events") for ev_id in set(old_list.keys() + new_list.keys()): levels_to_check.append( - (ev_id, ["events"]) + (ev_id, "events") ) old_state = current_state.content @@ -600,12 +603,10 @@ class Auth(object): for level_to_check, dir in levels_to_check: old_loc = old_state - for d in dir: - old_loc = old_loc.get(d, {}) - new_loc = new_state - for d in dir: - new_loc = new_loc.get(d, {}) + if dir: + old_loc = old_loc.get(dir, {}) + new_loc = new_loc.get(dir, {}) if level_to_check in old_loc: old_level = int(old_loc[level_to_check]) @@ -621,6 +622,14 @@ class Auth(object): if new_level == old_level: continue + if dir == "users" and level_to_check != event.user_id: + if old_level == user_level: + raise AuthError( + 403, + "You don't have permission to remove ops level equal " + "to your own" + ) + if old_level > user_level or new_level > user_level: raise AuthError( 403, diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 3e15e8a9d7..7156ee4e7d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -87,3 +87,8 @@ class RejectedReason(object): AUTH_ERROR = "auth_error" REPLACED = "replaced" NOT_ANCESTOR = "not_ancestor" + + +class RoomCreationPreset(object): + PRIVATE_CHAT = "private_chat" + PUBLIC_CHAT = "public_chat" diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index cf72dc4340..15a132b4e3 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -29,10 +29,10 @@ class CaptchaConfig(Config): ## Captcha ## # This Home Server's ReCAPTCHA public key. - recaptcha_private_key: "YOUR_PUBLIC_KEY" + recaptcha_private_key: "YOUR_PRIVATE_KEY" # This Home Server's ReCAPTCHA private key. - recaptcha_public_key: "YOUR_PRIVATE_KEY" + recaptcha_public_key: "YOUR_PUBLIC_KEY" # Enables ReCaptcha checks when registering, preventing signup # unless a captcha is answered. Requires a valid ReCaptcha diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 63071653a3..1ecf7fef17 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -85,8 +85,10 @@ class AuthHandler(BaseHandler): # email auth link on there). It's probably too open to abuse # because it lets unauthenticated clients store arbitrary objects # on a home server. - # sess['clientdict'] = clientdict - # self._save_session(sess) + # Revisit: Assumimg the REST APIs do sensible validation, the data + # isn't arbintrary. + sess['clientdict'] = clientdict + self._save_session(sess) pass elif 'clientdict' in sess: clientdict = sess['clientdict'] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d7f197f247..f7155fd8d3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -140,26 +140,29 @@ class FederationHandler(BaseHandler): if state and auth_chain is not None: # If we have any state or auth_chain given to us by the replication # layer, then we should handle them (if we haven't before.) + + event_infos = [] + for e in itertools.chain(auth_chain, state): if e.event_id in seen_ids: continue - e.internal_metadata.outlier = True - try: - auth_ids = [e_id for e_id, _ in e.auth_events] - auth = { - (e.type, e.state_key): e for e in auth_chain - if e.event_id in auth_ids - } - yield self._handle_new_event( - origin, e, auth_events=auth - ) - seen_ids.add(e.event_id) - except: - logger.exception( - "Failed to handle state event %s", - e.event_id, - ) + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + event_infos.append({ + "event": e, + "auth_events": auth, + }) + seen_ids.add(e.event_id) + + yield self._handle_new_events( + origin, + event_infos, + outliers=True + ) try: _, event_stream_id, max_stream_id = yield self._handle_new_event( @@ -344,38 +347,29 @@ class FederationHandler(BaseHandler): ).addErrback(unwrapFirstError) auth_events.update({a.event_id: a for a in results}) - yield defer.gatherResults( - [ - self._handle_new_event( - dest, a, - auth_events={ - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id, _ in a.auth_events - }, - ) - for a in auth_events.values() - if a.event_id not in seen_events - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - - yield defer.gatherResults( - [ - self._handle_new_event( - dest, event_map[e_id], - state=events_to_state[e_id], - backfilled=True, - auth_events={ - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id, _ in event_map[e_id].auth_events - }, - ) - for e_id in events_to_state - ], - consumeErrors=True - ).addErrback(unwrapFirstError) + ev_infos = [] + for a in auth_events.values(): + if a.event_id in seen_events: + continue + ev_infos.append({ + "event": a, + "auth_events": { + (auth_events[a_id].type, auth_events[a_id].state_key): + auth_events[a_id] + for a_id, _ in a.auth_events + } + }) + + for e_id in events_to_state: + ev_infos.append({ + "event": event_map[e_id], + "state": events_to_state[e_id], + "auth_events": { + (auth_events[a_id].type, auth_events[a_id].state_key): + auth_events[a_id] + for a_id, _ in event_map[e_id].auth_events + } + }) events.sort(key=lambda e: e.depth) @@ -383,10 +377,14 @@ class FederationHandler(BaseHandler): if event in events_to_state: continue - yield self._handle_new_event( - dest, event, - backfilled=True, - ) + ev_infos.append({ + "event": event, + }) + + yield self._handle_new_events( + dest, ev_infos, + backfilled=True, + ) defer.returnValue(events) @@ -652,32 +650,22 @@ class FederationHandler(BaseHandler): # FIXME pass - yield self._handle_auth_events( - origin, [e for e in auth_chain if e.event_id != event.event_id] - ) - - @defer.inlineCallbacks - def handle_state(e): + ev_infos = [] + for e in itertools.chain(state, auth_chain): if e.event_id == event.event_id: - return + continue e.internal_metadata.outlier = True - try: - auth_ids = [e_id for e_id, _ in e.auth_events] - auth = { + auth_ids = [e_id for e_id, _ in e.auth_events] + ev_infos.append({ + "event": e, + "auth_events": { (e.type, e.state_key): e for e in auth_chain if e.event_id in auth_ids } - yield self._handle_new_event( - origin, e, auth_events=auth - ) - except: - logger.exception( - "Failed to handle state event %s", - e.event_id, - ) + }) - yield defer.DeferredList([handle_state(e) for e in state]) + yield self._handle_new_events(origin, ev_infos, outliers=True) auth_ids = [e_id for e_id, _ in event.auth_events] auth_events = { @@ -994,11 +982,54 @@ class FederationHandler(BaseHandler): def _handle_new_event(self, origin, event, state=None, backfilled=False, current_state=None, auth_events=None): - logger.debug( - "_handle_new_event: %s, sigs: %s", - event.event_id, event.signatures, + outlier = event.internal_metadata.is_outlier() + + context = yield self._prep_event( + origin, event, + state=state, + backfilled=backfilled, + current_state=current_state, + auth_events=auth_events, ) + event_stream_id, max_stream_id = yield self.store.persist_event( + event, + context=context, + backfilled=backfilled, + is_new_state=(not outlier and not backfilled), + current_state=current_state, + ) + + defer.returnValue((context, event_stream_id, max_stream_id)) + + @defer.inlineCallbacks + def _handle_new_events(self, origin, event_infos, backfilled=False, + outliers=False): + contexts = yield defer.gatherResults( + [ + self._prep_event( + origin, + ev_info["event"], + state=ev_info.get("state"), + backfilled=backfilled, + auth_events=ev_info.get("auth_events"), + ) + for ev_info in event_infos + ] + ) + + yield self.store.persist_events( + [ + (ev_info["event"], context) + for ev_info, context in itertools.izip(event_infos, contexts) + ], + backfilled=backfilled, + is_new_state=(not outliers and not backfilled), + ) + + @defer.inlineCallbacks + def _prep_event(self, origin, event, state=None, backfilled=False, + current_state=None, auth_events=None): outlier = event.internal_metadata.is_outlier() context = yield self.state_handler.compute_event_context( @@ -1008,13 +1039,6 @@ class FederationHandler(BaseHandler): if not auth_events: auth_events = context.current_state - logger.debug( - "_handle_new_event: %s, auth_events: %s", - event.event_id, auth_events, - ) - - is_new_state = not outlier - # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_events: @@ -1038,26 +1062,7 @@ class FederationHandler(BaseHandler): context.rejected = RejectedReason.AUTH_ERROR - # FIXME: Don't store as rejected with AUTH_ERROR if we haven't - # seen all the auth events. - yield self.store.persist_event( - event, - context=context, - backfilled=backfilled, - is_new_state=False, - current_state=current_state, - ) - raise - - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, - backfilled=backfilled, - is_new_state=(is_new_state and not backfilled), - current_state=current_state, - ) - - defer.returnValue((context, event_stream_id, max_stream_id)) + defer.returnValue(context) @defer.inlineCallbacks def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, @@ -1120,14 +1125,24 @@ class FederationHandler(BaseHandler): @log_function def do_auth(self, origin, event, context, auth_events): # Check if we have all the auth events. - have_events = yield self.store.have_events( - [e_id for e_id, _ in event.auth_events] - ) - + current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) + + if event_auth_events - current_state: + have_events = yield self.store.have_events( + event_auth_events - current_state + ) + else: + have_events = {} + + have_events.update({ + e.event_id: "" + for e in auth_events.values() + }) + seen_events = set(have_events.keys()) - missing_auth = event_auth_events - seen_events + missing_auth = event_auth_events - seen_events - current_state if missing_auth: logger.info("Missing auth: %s", missing_auth) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 891707df44..7511d294f3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -19,12 +19,15 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.types import UserID, RoomAlias, RoomID -from synapse.api.constants import EventTypes, Membership, JoinRules +from synapse.api.constants import ( + EventTypes, Membership, JoinRules, RoomCreationPreset, +) from synapse.api.errors import StoreError, SynapseError from synapse.util import stringutils, unwrapFirstError from synapse.util.async import run_on_reactor from synapse.events.utils import serialize_event +from collections import OrderedDict import logging import string @@ -33,6 +36,19 @@ logger = logging.getLogger(__name__) class RoomCreationHandler(BaseHandler): + PRESETS_DICT = { + RoomCreationPreset.PRIVATE_CHAT: { + "join_rules": JoinRules.INVITE, + "history_visibility": "invited", + "original_invitees_have_ops": False, + }, + RoomCreationPreset.PUBLIC_CHAT: { + "join_rules": JoinRules.PUBLIC, + "history_visibility": "shared", + "original_invitees_have_ops": False, + }, + } + @defer.inlineCallbacks def create_room(self, user_id, room_id, config): """ Creates a new room. @@ -121,9 +137,25 @@ class RoomCreationHandler(BaseHandler): servers=[self.hs.hostname], ) + preset_config = config.get( + "preset", + RoomCreationPreset.PUBLIC_CHAT + if is_public + else RoomCreationPreset.PRIVATE_CHAT + ) + + raw_initial_state = config.get("initial_state", []) + + initial_state = OrderedDict() + for val in raw_initial_state: + initial_state[(val["type"], val.get("state_key", ""))] = val["content"] + user = UserID.from_string(user_id) creation_events = self._create_events_for_new_room( - user, room_id, is_public=is_public + user, room_id, + preset_config=preset_config, + invite_list=invite_list, + initial_state=initial_state, ) msg_handler = self.hs.get_handlers().message_handler @@ -170,7 +202,10 @@ class RoomCreationHandler(BaseHandler): defer.returnValue(result) - def _create_events_for_new_room(self, creator, room_id, is_public=False): + def _create_events_for_new_room(self, creator, room_id, preset_config, + invite_list, initial_state): + config = RoomCreationHandler.PRESETS_DICT[preset_config] + creator_id = creator.to_string() event_keys = { @@ -203,9 +238,10 @@ class RoomCreationHandler(BaseHandler): }, ) - power_levels_event = create( - etype=EventTypes.PowerLevels, - content={ + returned_events = [creation_event, join_event] + + if (EventTypes.PowerLevels, '') not in initial_state: + power_level_content = { "users": { creator.to_string(): 100, }, @@ -221,21 +257,43 @@ class RoomCreationHandler(BaseHandler): "kick": 50, "redact": 50, "invite": 0, - }, - ) + } - join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE - join_rules_event = create( - etype=EventTypes.JoinRules, - content={"join_rule": join_rule}, - ) + if config["original_invitees_have_ops"]: + for invitee in invite_list: + power_level_content["users"][invitee] = 100 - return [ - creation_event, - join_event, - power_levels_event, - join_rules_event, - ] + power_levels_event = create( + etype=EventTypes.PowerLevels, + content=power_level_content, + ) + + returned_events.append(power_levels_event) + + if (EventTypes.JoinRules, '') not in initial_state: + join_rules_event = create( + etype=EventTypes.JoinRules, + content={"join_rule": config["join_rules"]}, + ) + + returned_events.append(join_rules_event) + + if (EventTypes.RoomHistoryVisibility, '') not in initial_state: + history_event = create( + etype=EventTypes.RoomHistoryVisibility, + content={"history_visibility": config["history_visibility"]} + ) + + returned_events.append(history_event) + + for (etype, state_key), content in initial_state.items(): + returned_events.append(create( + etype=etype, + state_key=state_key, + content=content, + )) + + return returned_events class RoomMemberHandler(BaseHandler): diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 72dfb876c5..0c737d73b8 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -57,10 +57,19 @@ class RegisterRestServlet(RestServlet): yield run_on_reactor() body = parse_request_allow_empty(request) - if 'password' not in body: - raise SynapseError(400, "", Codes.MISSING_PARAM) + # we do basic sanity checks here because the auth + # layer will store these in sessions + if 'password' in body: + if ((not isinstance(body['password'], str) and + not isinstance(body['password'], unicode)) or + len(body['password']) > 512): + raise SynapseError(400, "Invalid password") if 'username' in body: + if ((not isinstance(body['username'], str) and + not isinstance(body['username'], unicode)) or + len(body['username']) > 512): + raise SynapseError(400, "Invalid username") desired_username = body['username'] yield self.registration_handler.check_username(desired_username) diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 28404f2b7b..1e965c363a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -82,7 +82,7 @@ class Thumbnailer(object): def save_image(self, output_image, output_type, output_path): output_bytes_io = BytesIO() - output_image.save(output_bytes_io, self.FORMATS[output_type], quality=70) + output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) output_bytes = output_bytes_io.getvalue() with open(output_path, "wb") as output_file: output_file.write(output_bytes) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index c71019d93b..45b86c94e8 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -282,8 +282,7 @@ class EventFederationStore(SQLBaseStore): }, ) - def _handle_prev_events(self, txn, outlier, event_id, prev_events, - room_id): + def _handle_mult_prev_events(self, txn, events): """ For the given event, update the event edges table and forward and backward extremities tables. @@ -293,68 +292,75 @@ class EventFederationStore(SQLBaseStore): table="event_edges", values=[ { - "event_id": event_id, + "event_id": ev.event_id, "prev_event_id": e_id, - "room_id": room_id, + "room_id": ev.room_id, "is_state": False, } - for e_id, _ in prev_events + for ev in events + for e_id, _ in ev.prev_events ], ) - # Update the extremities table if this is not an outlier. - if not outlier: - for e_id, _ in prev_events: - # TODO (erikj): This could be done as a bulk insert - self._simple_delete_txn( - txn, - table="event_forward_extremities", - keyvalues={ - "event_id": e_id, - "room_id": room_id, - } - ) + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) - # We only insert as a forward extremity the new event if there are - # no other events that reference it as a prev event - query = ( - "SELECT 1 FROM event_edges WHERE prev_event_id = ?" - ) + for room_id, room_events in events_by_room.items(): + prevs = [ + e_id for ev in room_events for e_id, _ in ev.prev_events + if not ev.internal_metadata.is_outlier() + ] + if prevs: + txn.execute( + "DELETE FROM event_forward_extremities" + " WHERE room_id = ?" + " AND event_id in (%s)" % ( + ",".join(["?"] * len(prevs)), + ), + [room_id] + prevs, + ) - txn.execute(query, (event_id,)) + query = ( + "INSERT INTO event_forward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_edges WHERE prev_event_id = ?" + " )" + ) - if not txn.fetchone(): - query = ( - "INSERT INTO event_forward_extremities" - " (event_id, room_id)" - " VALUES (?, ?)" - ) + txn.executemany( + query, + [(ev.event_id, ev.room_id, ev.event_id) for ev in events] + ) - txn.execute(query, (event_id, room_id)) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" + ) - txn.executemany(query, [ - (e_id, room_id, e_id, room_id, e_id, room_id, False) - for e_id, _ in prev_events - ]) + txn.executemany(query, [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events for e_id, _ in ev.prev_events + if not ev.internal_metadata.is_outlier() + ]) - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.execute(query, (event_id, room_id)) + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [(ev.event_id, ev.room_id) for ev in events] + ) + for room_id in events_by_room: txn.call_after( self.get_latest_event_ids_in_room.invalidate, room_id ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2caf0aae80..ed7ea38804 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -23,9 +23,7 @@ from synapse.events.utils import prune_event from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logutils import log_function from synapse.api.constants import EventTypes -from synapse.crypto.event_signing import compute_event_reference_hash -from syutil.base64util import decode_base64 from syutil.jsonutil import encode_json from contextlib import contextmanager @@ -47,6 +45,48 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events class EventsStore(SQLBaseStore): @defer.inlineCallbacks + def persist_events(self, events_and_contexts, backfilled=False, + is_new_state=True): + if not events_and_contexts: + return + + if backfilled: + if not self.min_token_deferred.called: + yield self.min_token_deferred + start = self.min_token - 1 + self.min_token -= len(events_and_contexts) + 1 + stream_orderings = range(start, self.min_token, -1) + + @contextmanager + def stream_ordering_manager(): + yield stream_orderings + stream_ordering_manager = stream_ordering_manager() + else: + stream_ordering_manager = yield self._stream_id_gen.get_next_mult( + self, len(events_and_contexts) + ) + + with stream_ordering_manager as stream_orderings: + for (event, _), stream in zip(events_and_contexts, stream_orderings): + event.internal_metadata.stream_ordering = stream + + chunks = [ + events_and_contexts[x:x+100] + for x in xrange(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=chunk, + backfilled=backfilled, + is_new_state=is_new_state, + ) + + @defer.inlineCallbacks @log_function def persist_event(self, event, context, backfilled=False, is_new_state=True, current_state=None): @@ -67,13 +107,13 @@ class EventsStore(SQLBaseStore): try: with stream_ordering_manager as stream_ordering: + event.internal_metadata.stream_ordering = stream_ordering yield self.runInteraction( "persist_event", self._persist_event_txn, event=event, context=context, backfilled=backfilled, - stream_ordering=stream_ordering, is_new_state=is_new_state, current_state=current_state, ) @@ -116,12 +156,7 @@ class EventsStore(SQLBaseStore): @log_function def _persist_event_txn(self, txn, event, context, backfilled, - stream_ordering=None, is_new_state=True, - current_state=None): - - # Remove the any existing cache entries for the event_id - txn.call_after(self._invalidate_get_event_cache, event.event_id) - + is_new_state=True, current_state=None): # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: @@ -149,37 +184,78 @@ class EventsStore(SQLBaseStore): } ) - outlier = event.internal_metadata.is_outlier() + return self._persist_events_txn( + txn, + [(event, context)], + backfilled=backfilled, + is_new_state=is_new_state, + ) - if not outlier: - self._update_min_depth_for_room_txn( - txn, - event.room_id, - event.depth + @log_function + def _persist_events_txn(self, txn, events_and_contexts, backfilled, + is_new_state=True): + + # Remove the any existing cache entries for the event_ids + for event, _ in events_and_contexts: + txn.call_after(self._invalidate_get_event_cache, event.event_id) + + depth_updates = {} + for event, _ in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + depth_updates[event.room_id] = max( + event.depth, depth_updates.get(event.room_id, event.depth) ) - have_persisted = self._simple_select_one_txn( - txn, - table="events", - keyvalues={"event_id": event.event_id}, - retcols=["event_id", "outlier"], - allow_none=True, + for room_id, depth in depth_updates.items(): + self._update_min_depth_for_room_txn(txn, room_id, depth) + + txn.execute( + "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( + ",".join(["?"] * len(events_and_contexts)), + ), + [event.event_id for event, _ in events_and_contexts] ) + have_persisted = { + event_id: outlier + for event_id, outlier in txn.fetchall() + } + + event_map = {} + to_remove = set() + for event, context in events_and_contexts: + # Handle the case of the list including the same event multiple + # times. The tricky thing here is when they differ by whether + # they are an outlier. + if event.event_id in event_map: + other = event_map[event.event_id] + + if not other.internal_metadata.is_outlier(): + to_remove.add(event) + continue + elif not event.internal_metadata.is_outlier(): + to_remove.add(event) + continue + else: + to_remove.add(other) + + event_map[event.event_id] = event - metadata_json = encode_json( - event.internal_metadata.get_dict(), - using_frozen_dicts=USE_FROZEN_DICTS - ).decode("UTF-8") - - # If we have already persisted this event, we don't need to do any - # more processing. - # The processing above must be done on every call to persist event, - # since they might not have happened on previous calls. For example, - # if we are persisting an event that we had persisted as an outlier, - # but is no longer one. - if have_persisted: - if not outlier and have_persisted["outlier"]: - self._store_state_groups_txn(txn, event, context) + if event.event_id not in have_persisted: + continue + + to_remove.add(event) + + outlier_persisted = have_persisted[event.event_id] + if not event.internal_metadata.is_outlier() and outlier_persisted: + self._store_state_groups_txn( + txn, event, context, + ) + + metadata_json = encode_json( + event.internal_metadata.get_dict(), + using_frozen_dicts=USE_FROZEN_DICTS + ).decode("UTF-8") sql = ( "UPDATE event_json SET internal_metadata = ?" @@ -198,94 +274,91 @@ class EventsStore(SQLBaseStore): sql, (False, event.event_id,) ) - return - if not outlier: - self._store_state_groups_txn(txn, event, context) - - self._handle_prev_events( - txn, - outlier=outlier, - event_id=event.event_id, - prev_events=event.prev_events, - room_id=event.room_id, + events_and_contexts = filter( + lambda ec: ec[0] not in to_remove, + events_and_contexts ) - if event.type == EventTypes.Member: - self._store_room_member_txn(txn, event) - elif event.type == EventTypes.Name: - self._store_room_name_txn(txn, event) - elif event.type == EventTypes.Topic: - self._store_room_topic_txn(txn, event) - elif event.type == EventTypes.Redaction: - self._store_redaction(txn, event) - - event_dict = { - k: v - for k, v in event.get_dict().items() - if k not in [ - "redacted", - "redacted_because", - ] - } + if not events_and_contexts: + return + + self._store_mult_state_groups_txn(txn, [ + (event, context) + for event, context in events_and_contexts + if not event.internal_metadata.is_outlier() + ]) - self._simple_insert_txn( + self._handle_mult_prev_events( txn, - table="event_json", - values={ - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": metadata_json, - "json": encode_json( - event_dict, using_frozen_dicts=USE_FROZEN_DICTS - ).decode("UTF-8"), - }, + events=[event for event, _ in events_and_contexts], ) - content = encode_json( - event.content, using_frozen_dicts=USE_FROZEN_DICTS - ).decode("UTF-8") - - vals = { - "topological_ordering": event.depth, - "event_id": event.event_id, - "type": event.type, - "room_id": event.room_id, - "content": content, - "processed": True, - "outlier": outlier, - "depth": event.depth, - } + for event, _ in events_and_contexts: + if event.type == EventTypes.Name: + self._store_room_name_txn(txn, event) + elif event.type == EventTypes.Topic: + self._store_room_topic_txn(txn, event) + elif event.type == EventTypes.Redaction: + self._store_redaction(txn, event) - unrec = { - k: v - for k, v in event.get_dict().items() - if k not in vals.keys() and k not in [ - "redacted", - "redacted_because", - "signatures", - "hashes", - "prev_events", + self._store_room_members_txn( + txn, + [ + event + for event, _ in events_and_contexts + if event.type == EventTypes.Member ] - } + ) - vals["unrecognized_keys"] = encode_json( - unrec, using_frozen_dicts=USE_FROZEN_DICTS - ).decode("UTF-8") + def event_dict(event): + return { + k: v + for k, v in event.get_dict().items() + if k not in [ + "redacted", + "redacted_because", + ] + } - sql = ( - "INSERT INTO events" - " (stream_ordering, topological_ordering, event_id, type," - " room_id, content, processed, outlier, depth)" - " VALUES (?,?,?,?,?,?,?,?,?)" + self._simple_insert_many_txn( + txn, + table="event_json", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "internal_metadata": encode_json( + event.internal_metadata.get_dict(), + using_frozen_dicts=USE_FROZEN_DICTS + ).decode("UTF-8"), + "json": encode_json( + event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS + ).decode("UTF-8"), + } + for event, _ in events_and_contexts + ], ) - txn.execute( - sql, - ( - stream_ordering, event.depth, event.event_id, event.type, - event.room_id, content, True, outlier, event.depth - ) + self._simple_insert_many_txn( + txn, + table="events", + values=[ + { + "stream_ordering": event.internal_metadata.stream_ordering, + "topological_ordering": event.depth, + "depth": event.depth, + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "processed": True, + "outlier": event.internal_metadata.is_outlier(), + "content": encode_json( + event.content, using_frozen_dicts=USE_FROZEN_DICTS + ).decode("UTF-8"), + } + for event, _ in events_and_contexts + ], ) if context.rejected: @@ -293,20 +366,6 @@ class EventsStore(SQLBaseStore): txn, event.event_id, context.rejected ) - for hash_alg, hash_base64 in event.hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_event_content_hash_txn( - txn, event.event_id, hash_alg, hash_bytes, - ) - - for prev_event_id, prev_hashes in event.prev_events: - for alg, hash_base64 in prev_hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_prev_event_hash_txn( - txn, event.event_id, prev_event_id, alg, - hash_bytes - ) - self._simple_insert_many_txn( txn, table="event_auth", @@ -316,16 +375,22 @@ class EventsStore(SQLBaseStore): "room_id": event.room_id, "auth_id": auth_id, } + for event, _ in events_and_contexts for auth_id, _ in event.auth_events ], ) - (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) - self._store_event_reference_hash_txn( - txn, event.event_id, ref_alg, ref_hash_bytes + self._store_event_reference_hashes_txn( + txn, [event for event, _ in events_and_contexts] + ) + + state_events_and_contexts = filter( + lambda i: i[0].is_state(), + events_and_contexts, ) - if event.is_state(): + state_values = [] + for event, context in state_events_and_contexts: vals = { "event_id": event.event_id, "room_id": event.room_id, @@ -337,52 +402,56 @@ class EventsStore(SQLBaseStore): if hasattr(event, "replaces_state"): vals["prev_state"] = event.replaces_state - self._simple_insert_txn( - txn, - "state_events", - vals, - ) + state_values.append(vals) - self._simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": event.event_id, - "prev_event_id": e_id, - "room_id": event.room_id, - "is_state": True, - } - for e_id, h in event.prev_state - ], - ) + self._simple_insert_many_txn( + txn, + table="state_events", + values=state_values, + ) - if is_new_state and not context.rejected: - txn.call_after( - self.get_current_state_for_key.invalidate, - event.room_id, event.type, event.state_key - ) + self._simple_insert_many_txn( + txn, + table="event_edges", + values=[ + { + "event_id": event.event_id, + "prev_event_id": prev_id, + "room_id": event.room_id, + "is_state": True, + } + for event, _ in state_events_and_contexts + for prev_id, _ in event.prev_state + ], + ) - if (event.type == EventTypes.Name - or event.type == EventTypes.Aliases): + if is_new_state: + for event, _ in state_events_and_contexts: + if not context.rejected: txn.call_after( - self.get_room_name_and_aliases.invalidate, - event.room_id + self.get_current_state_for_key.invalidate, + event.room_id, event.type, event.state_key + ) + + if event.type in [EventTypes.Name, EventTypes.Aliases]: + txn.call_after( + self.get_room_name_and_aliases.invalidate, + event.room_id + ) + + self._simple_upsert_txn( + txn, + "current_state_events", + keyvalues={ + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + values={ + "event_id": event.event_id, + } ) - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) - return def _store_redaction(self, txn, event): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index d36a6c18a8..4db07f6fb4 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -35,38 +35,28 @@ RoomsForUser = namedtuple( class RoomMemberStore(SQLBaseStore): - def _store_room_member_txn(self, txn, event): + def _store_room_members_txn(self, txn, events): """Store a room member in the database. """ - try: - target_user_id = event.state_key - except: - logger.exception( - "Failed to parse target_user_id=%s", target_user_id - ) - raise - - logger.debug( - "_store_room_member_txn: target_user_id=%s, membership=%s", - target_user_id, - event.membership, - ) - - self._simple_insert_txn( + self._simple_insert_many_txn( txn, - "room_memberships", - { - "event_id": event.event_id, - "user_id": target_user_id, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - } + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + } + for event in events + ] ) - txn.call_after(self.get_rooms_for_user.invalidate, target_user_id) - txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) - txn.call_after(self.get_users_in_room.invalidate, event.room_id) + for event in events: + txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) + txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) + txn.call_after(self.get_users_in_room.invalidate, event.room_id) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index f051828630..4f15e534b4 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -18,6 +18,7 @@ from twisted.internet import defer from _base import SQLBaseStore from syutil.base64util import encode_base64 +from synapse.crypto.event_signing import compute_event_reference_hash class SignatureStore(SQLBaseStore): @@ -101,23 +102,26 @@ class SignatureStore(SQLBaseStore): txn.execute(query, (event_id, )) return {k: v for k, v in txn.fetchall()} - def _store_event_reference_hash_txn(self, txn, event_id, algorithm, - hash_bytes): + def _store_event_reference_hashes_txn(self, txn, events): """Store a hash for a PDU Args: txn (cursor): - event_id (str): Id for the Event. - algorithm (str): Hashing algorithm. - hash_bytes (bytes): Hash function output bytes. + events (list): list of Events. """ - self._simple_insert_txn( + + vals = [] + for event in events: + ref_alg, ref_hash_bytes = compute_event_reference_hash(event) + vals.append({ + "event_id": event.event_id, + "algorithm": ref_alg, + "hash": buffer(ref_hash_bytes), + }) + + self._simple_insert_many_txn( txn, - "event_reference_hashes", - { - "event_id": event_id, - "algorithm": algorithm, - "hash": buffer(hash_bytes), - }, + table="event_reference_hashes", + values=vals, ) def _get_event_signatures_txn(self, txn, event_id): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index d7844edee3..47bec65497 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -100,16 +100,23 @@ class StateStore(SQLBaseStore): ) def _store_state_groups_txn(self, txn, event, context): - if context.current_state is None: - return + return self._store_mult_state_groups_txn(txn, [(event, context)]) - state_events = dict(context.current_state) + def _store_mult_state_groups_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if context.current_state is None: + continue - if event.is_state(): - state_events[(event.type, event.state_key)] = event + if context.state_group is not None: + state_groups[event.event_id] = context.state_group + continue + + state_events = dict(context.current_state) + + if event.is_state(): + state_events[(event.type, event.state_key)] = event - state_group = context.state_group - if not state_group: state_group = self._state_groups_id_gen.get_next_txn(txn) self._simple_insert_txn( txn, @@ -135,14 +142,19 @@ class StateStore(SQLBaseStore): for state in state_events.values() ], ) + state_groups[event.event_id] = state_group - self._simple_insert_txn( + self._simple_insert_many_txn( txn, table="event_to_state_groups", - values={ - "state_group": state_group, - "event_id": event.event_id, - }, + values=[ + { + "state_group": state_groups[event.event_id], + "event_id": event.event_id, + } + for event, context in events_and_contexts + if context.current_state is not None + ], ) @defer.inlineCallbacks diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b39006315d..e956df62c7 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -111,6 +111,37 @@ class StreamIdGenerator(object): defer.returnValue(manager()) @defer.inlineCallbacks + def get_next_mult(self, store, n): + """ + Usage: + with yield stream_id_gen.get_next(store, n) as stream_ids: + # ... persist events ... + """ + if not self._current_max: + yield store.runInteraction( + "_compute_current_max", + self._get_or_compute_current_max, + ) + + with self._lock: + next_ids = range(self._current_max + 1, self._current_max + n + 1) + self._current_max += n + + for next_id in next_ids: + self._unfinished_ids.append(next_id) + + @contextlib.contextmanager + def manager(): + try: + yield next_ids + finally: + with self._lock: + for next_id in next_ids: + self._unfinished_ids.remove(next_id) + + defer.returnValue(manager()) + + @defer.inlineCallbacks def get_max_token(self, store): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. |