From 61407986b40e28b590961d364f5618bbe7d44e94 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 16:18:46 +0100 Subject: Add a entry to current_state_resets table when the current state is reset --- synapse/storage/events.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index a4b8995496..bd4d503b6d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -205,6 +205,15 @@ class EventsStore(SQLBaseStore): txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_room_name_and_aliases, event.room_id) + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + stream_order = event.internal_metadata.stream_ordering + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": stream_order} + ) + self._simple_delete_txn( txn, table="current_state_events", -- cgit 1.4.1 From 1fbb094c6fbaab33ef8e17802e37057e83718e7e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 17:19:56 +0100 Subject: Add replication streams for ex outliers and current state resets --- synapse/replication/resource.py | 17 ++++++- synapse/storage/events.py | 60 +++++++++++++++++++++++- synapse/storage/schema/delta/30/state_stream.sql | 38 +++++++++++++++ 3 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 synapse/storage/schema/delta/30/state_stream.sql (limited to 'synapse/storage') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 096a79a7a4..7afa1242d5 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -204,7 +204,11 @@ class ReplicationResource(Resource): request_events = current_token.events if request_backfill is None: request_backfill = current_token.backfill - events_rows, backfill_rows = yield self.store.get_all_new_events( + ( + events_rows, backfill_rows, + forward_ex_outliers, backward_ex_outliers, + state_resets + ) = yield self.store.get_all_new_events( request_backfill, request_events, current_token.backfill, current_token.events, limit @@ -215,6 +219,17 @@ class ReplicationResource(Resource): writer.write_header_and_rows("backfill", backfill_rows, ( "position", "internal", "json", "state_group" )) + writer.write_header_and_rows( + "forward_ex_outliers", forward_ex_outliers, + ("position", "event_id", "state_group") + ) + writer.write_header_and_rows( + "backward_ex_outliers", backward_ex_outliers, + ("position", "event_id", "state_group") + ) + writer.write_header_and_rows( + "state_resets", state_resets, ("position",) + ) @defer.inlineCallbacks def presence(self, writer, current_token): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index bd4d503b6d..9725a3fed7 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -323,6 +323,18 @@ class EventsStore(SQLBaseStore): (metadata_json, event.event_id,) ) + stream_order = event.internal_metadata.stream_ordering + state_group_id = context.state_group or context.new_state_group_id + self._simple_insert_txn( + txn, + table="ex_outlier_stream", + values={ + "event_stream_ordering": stream_order, + "event_id": event.event_id, + "state_group": state_group_id, + } + ) + sql = ( "UPDATE events SET outlier = ?" " WHERE event_id = ?" @@ -1119,8 +1131,34 @@ class EventsStore(SQLBaseStore): if last_forward_id != current_forward_id: txn.execute(sql, (last_forward_id, current_forward_id, limit)) new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT -event_stream_ordering FROM current_state_resets" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering ASC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + state_resets = txn.fetchall() + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() else: new_forward_events = [] + state_resets = [] + forward_ex_outliers = [] sql = ( "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" @@ -1136,8 +1174,28 @@ class EventsStore(SQLBaseStore): if last_backfill_id != current_backfill_id: txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() else: new_backfill_events = [] + backward_ex_outliers = [] - return (new_forward_events, new_backfill_events) + return ( + new_forward_events, new_backfill_events, + forward_ex_outliers, backward_ex_outliers, + state_resets, + ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/schema/delta/30/state_stream.sql new file mode 100644 index 0000000000..706fe1dcf4 --- /dev/null +++ b/synapse/storage/schema/delta/30/state_stream.sql @@ -0,0 +1,38 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** + * The positions in the event stream_ordering when the current_state was + * replaced by the state at the event. + */ + +CREATE TABLE IF NOT EXISTS current_state_resets( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL +); + +/* The outlier events that have aquired a state group typically through + * backfill. This is tracked separately to the events table, as assigning a + * state group change the position of the existing event in the stream + * ordering. + * However since a stream_ordering is assigned in persist_event for the + * (event, state) pair, we can use that stream_ordering to identify when + * the new state was assigned for the event. + */ +CREATE TABLE IF NOT EXISTS ex_outlier_stream( + event_stream_ordering BIGINT PRIMARY KEY NOT NULL, + event_id TEXT NOT NULL, + state_group BIGINT NOT NULL +); -- cgit 1.4.1 From 2ec54260350b46c937527bd566b713cf3544f1d2 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 10:33:02 +0100 Subject: Use a namedtuple rather than tuple unpacking --- synapse/replication/resource.py | 16 ++++++---------- synapse/storage/events.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 12 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 7afa1242d5..69afcb03d2 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -204,31 +204,27 @@ class ReplicationResource(Resource): request_events = current_token.events if request_backfill is None: request_backfill = current_token.backfill - ( - events_rows, backfill_rows, - forward_ex_outliers, backward_ex_outliers, - state_resets - ) = yield self.store.get_all_new_events( + res = yield self.store.get_all_new_events( request_backfill, request_events, current_token.backfill, current_token.events, limit ) - writer.write_header_and_rows("events", events_rows, ( + writer.write_header_and_rows("events", res.new_forward_events, ( "position", "internal", "json", "state_group" )) - writer.write_header_and_rows("backfill", backfill_rows, ( + writer.write_header_and_rows("backfill", res.new_backfill_events, ( "position", "internal", "json", "state_group" )) writer.write_header_and_rows( - "forward_ex_outliers", forward_ex_outliers, + "forward_ex_outliers", res.forward_ex_outliers, ("position", "event_id", "state_group") ) writer.write_header_and_rows( - "backward_ex_outliers", backward_ex_outliers, + "backward_ex_outliers", res.backward_ex_outliers, ("position", "event_id", "state_group") ) writer.write_header_and_rows( - "state_resets", state_resets, ("position",) + "state_resets", res.state_resets, ("position",) ) @defer.inlineCallbacks diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 9725a3fed7..b7ad045e41 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json from contextlib import contextmanager - +from collections import namedtuple import logging import math @@ -1193,9 +1193,16 @@ class EventsStore(SQLBaseStore): new_backfill_events = [] backward_ex_outliers = [] - return ( + return AllNewEventsResult( new_forward_events, new_backfill_events, forward_ex_outliers, backward_ex_outliers, state_resets, ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) + + +AllNewEventsResult = namedtuple("AllNewEventsResult", [ + "new_forward_events", "new_backfill_events", + "forward_ex_outliers", "backward_ex_outliers", + "state_resets" +]) -- cgit 1.4.1 From 76503f95ed2b618a3ff97c6b04868d8e440ef9d4 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 15:00:42 +0100 Subject: Remove the is_new_state argument to persist event. Move the checks for whether an event is new state inside persist event itself. This was harder than expected because there wasn't enough information passed to persist event to correctly handle invites from remote servers for new rooms. --- synapse/events/__init__.py | 3 ++ synapse/handlers/federation.py | 20 ++-------- synapse/storage/events.py | 90 +++++++++++++++++++++++------------------- 3 files changed, 57 insertions(+), 56 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 925a83c645..13154b1723 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -33,6 +33,9 @@ class _EventInternalMetadata(object): def is_outlier(self): return getattr(self, "outlier", False) + def is_invite_from_remote(self): + return getattr(self, "invite_from_remote", False) + def _event_dict_property(key): def getter(self): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 267fedf114..4a35344d32 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -102,8 +102,7 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, state=None, - auth_chain=None): + def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): """ Called by the ReplicationLayer when we have a new pdu. We need to do auth checks and put it through the StateHandler. """ @@ -174,11 +173,7 @@ class FederationHandler(BaseHandler): }) seen_ids.add(e.event_id) - yield self._handle_new_events( - origin, - event_infos, - outliers=True - ) + yield self._handle_new_events(origin, event_infos) try: context, event_stream_id, max_stream_id = yield self._handle_new_event( @@ -761,6 +756,7 @@ class FederationHandler(BaseHandler): event = pdu event.internal_metadata.outlier = True + event.internal_metadata.invite_from_remote = True event.signatures.update( compute_event_signature( @@ -1069,9 +1065,6 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function def _handle_new_event(self, origin, event, state=None, auth_events=None): - - outlier = event.internal_metadata.is_outlier() - context = yield self._prep_event( origin, event, state=state, @@ -1087,14 +1080,12 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, - is_new_state=not outlier, ) defer.returnValue((context, event_stream_id, max_stream_id)) @defer.inlineCallbacks - def _handle_new_events(self, origin, event_infos, backfilled=False, - outliers=False): + def _handle_new_events(self, origin, event_infos, backfilled=False): contexts = yield defer.gatherResults( [ self._prep_event( @@ -1113,7 +1104,6 @@ class FederationHandler(BaseHandler): for ev_info, context in itertools.izip(event_infos, contexts) ], backfilled=backfilled, - is_new_state=(not outliers and not backfilled), ) @defer.inlineCallbacks @@ -1176,7 +1166,6 @@ class FederationHandler(BaseHandler): (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) ], - is_new_state=False, ) new_event_context = yield self.state_handler.compute_event_context( @@ -1185,7 +1174,6 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, new_event_context, - is_new_state=True, current_state=state, ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index dc3e994de9..b5f9b3b900 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -61,8 +61,7 @@ class EventsStore(SQLBaseStore): ) @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False, - is_new_state=True): + def persist_events(self, events_and_contexts, backfilled=False): if not events_and_contexts: return @@ -110,13 +109,11 @@ class EventsStore(SQLBaseStore): self._persist_events_txn, events_and_contexts=chunk, backfilled=backfilled, - is_new_state=is_new_state, ) @defer.inlineCallbacks @log_function - def persist_event(self, event, context, - is_new_state=True, current_state=None): + def persist_event(self, event, context, current_state=None): try: with self._stream_id_gen.get_next() as stream_ordering: @@ -128,7 +125,6 @@ class EventsStore(SQLBaseStore): self._persist_event_txn, event=event, context=context, - is_new_state=is_new_state, current_state=current_state, ) except _RollbackButIsFineException: @@ -194,8 +190,7 @@ class EventsStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @log_function - def _persist_event_txn(self, txn, event, context, - is_new_state, current_state): + def _persist_event_txn(self, txn, event, context, current_state): # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: @@ -236,12 +231,10 @@ class EventsStore(SQLBaseStore): txn, [(event, context)], backfilled=False, - is_new_state=is_new_state, ) @log_function - def _persist_events_txn(self, txn, events_and_contexts, backfilled, - is_new_state): + def _persist_events_txn(self, txn, events_and_contexts, backfilled): depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids @@ -452,10 +445,9 @@ class EventsStore(SQLBaseStore): txn, [event for event, _ in events_and_contexts] ) - state_events_and_contexts = filter( - lambda i: i[0].is_state(), - events_and_contexts, - ) + state_events_and_contexts = [ + ec for ec in events_and_contexts if ec[0].is_state() + ] state_values = [] for event, context in state_events_and_contexts: @@ -493,32 +485,50 @@ class EventsStore(SQLBaseStore): ], ) - if is_new_state: - for event, _ in state_events_and_contexts: - if not context.rejected: - txn.call_after( - self._get_current_state_for_key.invalidate, - (event.room_id, event.type, event.state_key,) - ) + for event, _ in state_events_and_contexts: + if backfilled: + # Backfilled events come before the current state so shouldn't + # clobber it. + continue - if event.type in [EventTypes.Name, EventTypes.Aliases]: - txn.call_after( - self.get_room_name_and_aliases.invalidate, - (event.room_id,) - ) - - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) + if (not event.internal_metadata.is_invite_from_remote() + and event.internal_metadata.is_outlier()): + # Outlier events generally shouldn't clobber the current state. + # However invites from remote severs for rooms we aren't in + # are a bit special: they don't come with any associated + # state so are technically an outlier, however all the + # client-facing code assumes that they are in the current + # state table so we insert the event anyway. + continue + + if context.rejected: + # If the event failed it's auth checks then it shouldn't + # clobbler the current state. + continue + + txn.call_after( + self._get_current_state_for_key.invalidate, + (event.room_id, event.type, event.state_key,) + ) + + if event.type in [EventTypes.Name, EventTypes.Aliases]: + txn.call_after( + self.get_room_name_and_aliases.invalidate, + (event.room_id,) + ) + + self._simple_upsert_txn( + txn, + "current_state_events", + keyvalues={ + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + }, + values={ + "event_id": event.event_id, + } + ) return -- cgit 1.4.1 From 5d069291696c62c94d795506dd5a22a3dbd019e1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 31 Mar 2016 15:09:09 +0100 Subject: Move the check for backfilled outside the for loop --- synapse/storage/events.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index b5f9b3b900..83279d65fa 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -485,12 +485,12 @@ class EventsStore(SQLBaseStore): ], ) - for event, _ in state_events_and_contexts: - if backfilled: - # Backfilled events come before the current state so shouldn't - # clobber it. - continue + if backfilled: + # Backfilled events come before the current state so we don't need + # to update the current state table + return + for event, _ in state_events_and_contexts: if (not event.internal_metadata.is_invite_from_remote() and event.internal_metadata.is_outlier()): # Outlier events generally shouldn't clobber the current state. -- cgit 1.4.1 From 7753fc65702dc2b98885f462fe0e6ac9f8e27b06 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 10:34:51 +0100 Subject: Fix the invalidation of the names and aliases cache --- synapse/storage/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 83279d65fa..7468e6e00c 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -198,7 +198,7 @@ class EventsStore(SQLBaseStore): txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) - txn.call_after(self.get_room_name_and_aliases, event.room_id) + txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,)) # Add an entry to the current_state_resets table to record the point # where we clobbered the current state -- cgit 1.4.1 From e36bfbab38def70e0fcc1bafcecb6e666dbbc1ad Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 13:29:05 +0100 Subject: Use a stream id generator for backfilled ids --- synapse/storage/__init__.py | 20 ++++-------- synapse/storage/account_data.py | 4 +-- synapse/storage/events.py | 19 +++-------- synapse/storage/presence.py | 6 ++-- synapse/storage/push_rule.py | 2 +- synapse/storage/pusher.py | 2 +- synapse/storage/receipts.py | 6 ++-- synapse/storage/state.py | 2 +- synapse/storage/stream.py | 2 +- synapse/storage/tags.py | 6 ++-- synapse/storage/util/id_generators.py | 61 +++++++++++++++++++++++------------ 11 files changed, 69 insertions(+), 61 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index aaad38039e..f87e907cd8 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -88,15 +88,6 @@ class DataStore(RoomMemberStore, RoomStore, self.hs = hs self.database_engine = hs.database_engine - cur = db_conn.cursor() - try: - cur.execute("SELECT MIN(stream_ordering) FROM events",) - rows = cur.fetchall() - self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1 - self.min_stream_token = min(self.min_stream_token, -1) - finally: - cur.close() - self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, @@ -105,6 +96,9 @@ class DataStore(RoomMemberStore, RoomStore, self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering" ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", direction=-1 + ) self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) @@ -129,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore, extra_tables=[("deleted_pushers", "stream_id")], ) - events_max = self._stream_id_gen.get_max_token() + events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -145,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token() + account_max = self._account_data_id_gen.get_current_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) @@ -156,7 +150,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._presence_id_gen.get_max_token(), + max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, @@ -167,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_max_token()[0], + max_value=self._push_rules_stream_id_gen.get_current_token()[0], ) self.push_rules_stream_cache = StreamChangeCache( diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index faddefe219..7a7fbf1e52 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore): "add_room_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore): "add_user_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 83279d65fa..4ab23c1597 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -24,7 +24,6 @@ from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json -from contextlib import contextmanager from collections import namedtuple import logging @@ -66,14 +65,9 @@ class EventsStore(SQLBaseStore): return if backfilled: - start = self.min_stream_token - 1 - self.min_stream_token -= len(events_and_contexts) + 1 - stream_orderings = range(start, self.min_stream_token, -1) - - @contextmanager - def stream_ordering_manager(): - yield stream_orderings - stream_ordering_manager = stream_ordering_manager() + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) else: stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) @@ -130,7 +124,7 @@ class EventsStore(SQLBaseStore): except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token() + max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((stream_ordering, max_persisted_id)) @defer.inlineCallbacks @@ -1117,10 +1111,7 @@ class EventsStore(SQLBaseStore): def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" - - # TODO: Fix race with the persit_event txn by using one of the - # stream id managers - return -self.min_stream_token + return -self._backfill_id_gen.get_current_token() def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4cec31e316..59b4ef5ce6 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore): self._update_presence_txn, stream_orderings, presence_states, ) - defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) + defer.returnValue(( + stream_orderings[-1], self._presence_id_gen.get_current_token() + )) def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): @@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore): defer.returnValue([UserPresenceState(**row) for row in rows]) def get_current_presence_token(self): - return self._presence_id_gen.get_max_token() + return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9dbad2fd5f..d2bf7f2aec 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore): """Get the position of the push rules stream. Returns a pair of a stream id for the push_rules stream and the room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_max_token() + return self._push_rules_stream_id_gen.get_current_token() def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 87b2ac5773..d1669c778a 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) def get_pushers_stream_token(self): - return self._pushers_id_gen.get_max_token() + return self._pushers_id_gen.get_current_token() def get_all_updated_pushers(self, last_id, current_id, limit): def get_all_updated_pushers_txn(txn): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 6b9d848eaa..4befebc8e2 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() ) @cached(num_args=2) @@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token() + return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = self._stream_id_gen.get_max_token() + max_persisted_id = self._stream_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7fc9a4f264..8644830657 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -458,4 +458,4 @@ class StateStore(SQLBaseStore): ) def get_state_stream_token(self): - return self._state_groups_id_gen.get_max_token() + return self._state_groups_id_gen.get_current_token() diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index cf84938be5..76bcd9cd00 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token() + token = yield self._stream_id_gen.get_current_token() if direction != 'b': defer.returnValue("s%d" % (token,)) else: diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index a0e6b42b30..9da23f34cb 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. """ - return self._account_data_id_gen.get_max_token() + return self._account_data_id_gen.get_current_token() @cached() def get_tags_for_user(self, user_id): @@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index a02dfc7d58..03f2aa6a5c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -21,7 +21,7 @@ import threading class IdGenerator(object): def __init__(self, db_conn, table, column): self._lock = threading.Lock() - self._next_id = _load_max_id(db_conn, table, column) + self._next_id = _load_current_id(db_conn, table, column) def get_next(self): with self._lock: @@ -29,12 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_max_id(db_conn, table, column): +def _load_current_id(db_conn, table, column, direction=1): cur = db_conn.cursor() - cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + if direction == 1: + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + else: + cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - return int(val) if val else 1 + current_id = int(val) if val else direction + return (max if direction == 1 else min)(current_id, direction) class StreamIdGenerator(object): @@ -45,17 +49,30 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. + :param connection db_conn: A database connection to use to fetch the + initial value of the generator from. + :param str table: A database table to read the initial value of the id + generator from. + :param str column: The column of the database table to read the initial + value from the id generator from. + :param list extra_tables: List of pairs of database tables and columns to + use to source the initial value of the generator from. The value with + the largest magnitude is used. + :param int direction: which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. + Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[]): + def __init__(self, db_conn, table, column, extra_tables=[], direction=1): self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._direction = direction + self._current = _load_current_id(db_conn, table, column, direction) for table, column in extra_tables: - self._current_max = max( - self._current_max, - _load_max_id(db_conn, table, column) + self._current = (max if direction > 0 else min)( + self._current, + _load_current_id(db_conn, table, column, direction) ) self._unfinished_ids = deque() @@ -66,8 +83,8 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current_max += 1 - next_id = self._current_max + self._current += self._direction + next_id = self._current self._unfinished_ids.append(next_id) @@ -88,8 +105,12 @@ class StreamIdGenerator(object): # ... persist events ... """ with self._lock: - next_ids = range(self._current_max + 1, self._current_max + n + 1) - self._current_max += n + next_ids = range( + self._current + self._direction, + self._current + self._direction * (n + 1), + self._direction + ) + self._current += n for next_id in next_ids: self._unfinished_ids.append(next_id) @@ -105,15 +126,15 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - 1 + return self._unfinished_ids[0] - self._direction - return self._current_max + return self._current class ChainedIdGenerator(object): @@ -125,7 +146,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() def get_next(self): @@ -137,7 +158,7 @@ class ChainedIdGenerator(object): with self._lock: self._current_max += 1 next_id = self._current_max - chained_id = self.chained_generator.get_max_token() + chained_id = self.chained_generator.get_current_token() self._unfinished_ids.append((next_id, chained_id)) @@ -151,7 +172,7 @@ class ChainedIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -160,4 +181,4 @@ class ChainedIdGenerator(object): stream_id, chained_id = self._unfinished_ids[0] return (stream_id - 1, chained_id) - return (self._current_max, self.chained_generator.get_max_token()) + return (self._current_max, self.chained_generator.get_current_token()) -- cgit 1.4.1 From a2866e2e6a8fa60a538a98f62e1733ab062020aa Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 13:50:54 +0100 Subject: Rename direction to step, apply checks consistently --- synapse/storage/__init__.py | 2 +- synapse/storage/util/id_generators.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f87e907cd8..57863bba4d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -97,7 +97,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "events", "stream_ordering" ) self._backfill_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", direction=-1 + db_conn, "events", "stream_ordering", step=-1 ) self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 03f2aa6a5c..310b7dc6ee 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -29,16 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_current_id(db_conn, table, column, direction=1): +def _load_current_id(db_conn, table, column, step=1): cur = db_conn.cursor() - if direction == 1: + if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) else: cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - current_id = int(val) if val else direction - return (max if direction == 1 else min)(current_id, direction) + current_id = int(val) if val else step + return (max if step > 0 else min)(current_id, step) class StreamIdGenerator(object): @@ -58,21 +58,21 @@ class StreamIdGenerator(object): :param list extra_tables: List of pairs of database tables and columns to use to source the initial value of the generator from. The value with the largest magnitude is used. - :param int direction: which direction the stream ids grow in. +1 to grow + :param int step: which direction the stream ids grow in. +1 to grow upwards, -1 to grow downwards. Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[], direction=1): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): self._lock = threading.Lock() - self._direction = direction - self._current = _load_current_id(db_conn, table, column, direction) + self._step = step + self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self._current = (max if direction > 0 else min)( + self._current = (max if step > 0 else min)( self._current, - _load_current_id(db_conn, table, column, direction) + _load_current_id(db_conn, table, column, step) ) self._unfinished_ids = deque() @@ -83,7 +83,7 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current += self._direction + self._current += self._step next_id = self._current self._unfinished_ids.append(next_id) @@ -106,9 +106,9 @@ class StreamIdGenerator(object): """ with self._lock: next_ids = range( - self._current + self._direction, - self._current + self._direction * (n + 1), - self._direction + self._current + self._step, + self._current + self._step * (n + 1), + self._step ) self._current += n @@ -132,7 +132,7 @@ class StreamIdGenerator(object): """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - self._direction + return self._unfinished_ids[0] - self._step return self._current -- cgit 1.4.1 From 35b5c4ba1b1892fde18f531c96d71aa58de649e1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 15:07:01 +0100 Subject: use google style doc strings --- synapse/storage/util/id_generators.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 310b7dc6ee..58ea54cf67 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -49,17 +49,18 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. - :param connection db_conn: A database connection to use to fetch the - initial value of the generator from. - :param str table: A database table to read the initial value of the id - generator from. - :param str column: The column of the database table to read the initial - value from the id generator from. - :param list extra_tables: List of pairs of database tables and columns to - use to source the initial value of the generator from. The value with - the largest magnitude is used. - :param int step: which direction the stream ids grow in. +1 to grow - upwards, -1 to grow downwards. + Args: + db_conn(connection): A database connection to use to fetch the + initial value of the generator from. + table(str): A database table to read the initial value of the id + generator from. + column(str): The column of the database table to read the initial + value from the id generator from. + extra_tables(list): List of pairs of database tables and columns to + use to source the initial value of the generator from. The value + with the largest magnitude is used. + step(int): which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. Usage: with stream_id_gen.get_next() as stream_id: -- cgit 1.4.1 From 9bc5b4c663ed4bb35ef74a820c108765c7ca0f67 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 15:08:20 +0100 Subject: Assert that the step != 0 --- synapse/storage/util/id_generators.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage') diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 58ea54cf67..f69f1cdad4 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -67,6 +67,7 @@ class StreamIdGenerator(object): # ... persist event ... """ def __init__(self, db_conn, table, column, extra_tables=[], step=1): + assert step != 0 self._lock = threading.Lock() self._step = step self._current = _load_current_id(db_conn, table, column, step) -- cgit 1.4.1 From 2a37467fa1358eb41513893efe44cbd294dca36c Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 16:08:59 +0100 Subject: Use google style doc strings. pycharm supports them so there is no need to use the other format. Might as well convert the existing strings to reduce the risk of people accidentally cargo culting the wrong doc string format. --- setup.cfg | 3 ++ synapse/handlers/_base.py | 27 +++++++----- synapse/handlers/auth.py | 26 +++++++---- synapse/handlers/federation.py | 23 +++++----- synapse/handlers/room_member.py | 48 ++++++++++----------- synapse/handlers/sync.py | 49 +++++++++++++-------- synapse/http/servlet.py | 81 ++++++++++++++++++++++------------- synapse/notifier.py | 15 ++++--- synapse/push/baserules.py | 8 ++-- synapse/rest/client/v2_alpha/sync.py | 79 ++++++++++++++++++---------------- synapse/state.py | 19 ++++---- synapse/storage/event_push_actions.py | 5 ++- synapse/storage/registration.py | 15 ++++--- synapse/storage/state.py | 13 +++--- 14 files changed, 242 insertions(+), 169 deletions(-) (limited to 'synapse/storage') diff --git a/setup.cfg b/setup.cfg index f8cc13c840..5ebce1c56b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,3 +17,6 @@ ignore = [flake8] max-line-length = 90 ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. + +[pep8] +max-line-length = 90 diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90eabb6eb7..5601ecea6e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -41,8 +41,9 @@ class BaseHandler(object): """ Common base class for the event handlers. - :type store: synapse.storage.events.StateStore - :type state_handler: synapse.state.StateHandler + Attributes: + store (synapse.storage.events.StateStore): + state_handler (synapse.state.StateHandler): """ def __init__(self, hs): @@ -65,11 +66,12 @@ class BaseHandler(object): """ Returns dict of user_id -> list of events that user is allowed to see. - :param (str, bool) user_tuples: (user id, is_peeking) for each - user to be checked. is_peeking should be true if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the given - events + Args: + user_tuples (str, bool): (user id, is_peeking) for each user to be + checked. is_peeking should be true if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the + given events """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -165,13 +167,16 @@ class BaseHandler(object): """ Check which events a user is allowed to see - :param str user_id: user id to be checked - :param [synapse.events.EventBase] events: list of events to be checked - :param bool is_peeking should be True if: + Args: + user_id(str): user id to be checked + events([synapse.events.EventBase]): list of events to be checked + is_peeking(bool): should be True if: * the user is not currently a member of the room, and: * the user has not been a member of the room since the given events - :rtype [synapse.events.EventBase] + + Returns: + [synapse.events.EventBase] """ types = ( (EventTypes.RoomHistoryVisibility, ""), diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 82d458b424..d5d6faa85f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -163,9 +163,13 @@ class AuthHandler(BaseHandler): def get_session_id(self, clientdict): """ Gets the session ID for a client given the client dictionary - :param clientdict: The dictionary sent by the client in the request - :return: The string session ID the client sent. If the client did not - send a session ID, returns None. + + Args: + clientdict: The dictionary sent by the client in the request + + Returns: + str|None: The string session ID the client sent. If the client did + not send a session ID, returns None. """ sid = None if clientdict and 'auth' in clientdict: @@ -179,9 +183,11 @@ class AuthHandler(BaseHandler): Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param value: (any) The data to store + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + value (any): The data to store """ sess = self._get_session_info(session_id) sess.setdefault('serverdict', {})[key] = value @@ -190,9 +196,11 @@ class AuthHandler(BaseHandler): def get_session_data(self, session_id, key, default=None): """ Retrieve data stored with set_session_data - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param default: (any) Value to return if the key has not been set + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4a35344d32..092802b973 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1706,13 +1706,15 @@ class FederationHandler(BaseHandler): def _check_signature(self, event, auth_events): """ Checks that the signature in the event is consistent with its invite. - :param event (Event): The m.room.member event to check - :param auth_events (dict<(event type, state_key), event>) - :raises - AuthError if signature didn't match any keys, or key has been + Args: + event (Event): The m.room.member event to check + auth_events (dict<(event type, state_key), event>): + + Raises: + AuthError: if signature didn't match any keys, or key has been revoked, - SynapseError if a transient error meant a key couldn't be checked + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ signed = event.content["third_party_invite"]["signed"] @@ -1754,12 +1756,13 @@ class FederationHandler(BaseHandler): """ Checks whether public_key has been revoked. - :param public_key (str): base-64 encoded public key. - :param url (str): Key revocation URL. + Args: + public_key (str): base-64 encoded public key. + url (str): Key revocation URL. - :raises - AuthError if they key has been revoked. - SynapseError if a transient error meant a key couldn't be checked + Raises: + AuthError: if they key has been revoked. + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ try: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5fdbd3adcc..01f833c371 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -411,7 +411,7 @@ class RoomMemberHandler(BaseHandler): address (str): The third party identifier (e.g. "foo@example.com"). Returns: - (str) the matrix ID of the 3pid, or None if it is not recognized. + str: the matrix ID of the 3pid, or None if it is not recognized. """ try: data = yield self.hs.get_simple_http_client().get_json( @@ -545,29 +545,29 @@ class RoomMemberHandler(BaseHandler): """ Asks an identity server for a third party invite. - :param id_server (str): hostname + optional port for the identity server. - :param medium (str): The literal string "email". - :param address (str): The third party address being invited. - :param room_id (str): The ID of the room to which the user is invited. - :param inviter_user_id (str): The user ID of the inviter. - :param room_alias (str): An alias for the room, for cosmetic - notifications. - :param room_avatar_url (str): The URL of the room's avatar, for cosmetic - notifications. - :param room_join_rules (str): The join rules of the email - (e.g. "public"). - :param room_name (str): The m.room.name of the room. - :param inviter_display_name (str): The current display name of the - inviter. - :param inviter_avatar_url (str): The URL of the inviter's avatar. - - :return: A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. - fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. + Args: + id_server (str): hostname + optional port for the identity server. + medium (str): The literal string "email". + address (str): The third party address being invited. + room_id (str): The ID of the room to which the user is invited. + inviter_user_id (str): The user ID of the inviter. + room_alias (str): An alias for the room, for cosmetic notifications. + room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + room_join_rules (str): The join rules of the email (e.g. "public"). + room_name (str): The m.room.name of the room. + inviter_display_name (str): The current display name of the + inviter. + inviter_avatar_url (str): The URL of the inviter's avatar. + + Returns: + A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. """ is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 48ab5707e1..20a0626574 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -671,7 +671,8 @@ class SyncHandler(BaseHandler): def load_filtered_recents(self, room_id, sync_config, now_token, since_token=None, recents=None, newly_joined_room=False): """ - :returns a Deferred TimelineBatch + Returns: + a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): filtering_factor = 2 @@ -838,8 +839,11 @@ class SyncHandler(BaseHandler): """ Get the room state after the given event - :param synapse.events.EventBase event: event of interest - :return: A Deferred map from ((type, state_key)->Event) + Args: + event(synapse.events.EventBase): event of interest + + Returns: + A Deferred map from ((type, state_key)->Event) """ state = yield self.store.get_state_for_event(event.event_id) if event.is_state(): @@ -850,9 +854,13 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def get_state_at(self, room_id, stream_position): """ Get the room state at a particular stream position - :param str room_id: room for which to get state - :param StreamToken stream_position: point at which to get state - :returns: A Deferred map from ((type, state_key)->Event) + + Args: + room_id(str): room for which to get state + stream_position(StreamToken): point at which to get state + + Returns: + A Deferred map from ((type, state_key)->Event) """ last_events, token = yield self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1, @@ -873,15 +881,18 @@ class SyncHandler(BaseHandler): """ Works out the differnce in state between the start of the timeline and the previous sync. - :param str room_id - :param TimelineBatch batch: The timeline batch for the room that will - be sent to the user. - :param sync_config - :param str since_token: Token of the end of the previous batch. May be None. - :param str now_token: Token of the end of the current batch. - :param bool full_state: Whether to force returning the full state. + Args: + room_id(str): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + sync_config(synapse.handlers.sync.SyncConfig): + since_token(str|None): Token of the end of the previous batch. May + be None. + now_token(str): Token of the end of the current batch. + full_state(bool): Whether to force returning the full state. - :returns A new event dictionary + Returns: + A deferred new event dictionary """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -953,11 +964,13 @@ class SyncHandler(BaseHandler): Check if the user has just joined the given room (so should be given the full state) - :param sync_config: - :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the - difference in state since the last sync + Args: + sync_config(synapse.handlers.sync.SyncConfig): + state_delta(dict[(str,str), synapse.events.FrozenEvent]): the + difference in state since the last sync - :returns A deferred Tuple (state_delta, limited) + Returns: + A deferred Tuple (state_delta, limited) """ join_event = state_delta.get(( EventTypes.Member, sync_config.user.to_string()), None) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 1c8bd8666f..e41afeab8e 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -26,14 +26,19 @@ logger = logging.getLogger(__name__) def parse_integer(request, name, default=None, required=False): """Parse an integer parameter from the request string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: An int value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (int|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + int|None: An int value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not an integer. """ if name in request.args: @@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False): """Parse a boolean parameter from the request query string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: A bool value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (bool|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + bool|None: A bool value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not one of "true" or "false". """ @@ -88,15 +98,20 @@ def parse_string(request, name, default=None, required=False, allowed_values=None, param_type="string"): """Parse a string parameter from the request query string. - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :param allowed_values (list): List of allowed values for the string, - or None if any value is allowed, defaults to None - :return: A string value or the default. - :raises + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (str|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + allowed_values (list[str]): List of allowed values for the string, + or None if any value is allowed, defaults to None + + Returns: + str|None: A string value or the default. + + Raises: SynapseError if the parameter is absent and required, or if the parameter is present, must be one of a list of allowed values and is not one of those allowed values. @@ -122,9 +137,13 @@ def parse_string(request, name, default=None, required=False, def parse_json_value_from_request(request): """Parse a JSON value from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :returns: The JSON value. - :raises + Args: + request: the twisted HTTP request. + + Returns: + The JSON value. + + Raises: SynapseError if the request body couldn't be decoded as JSON. """ try: @@ -143,8 +162,10 @@ def parse_json_value_from_request(request): def parse_json_object_from_request(request): """Parse a JSON object from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :raises + Args: + request: the twisted HTTP request. + + Raises: SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ diff --git a/synapse/notifier.py b/synapse/notifier.py index f00cd8c588..6af7a8f424 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -503,13 +503,14 @@ class Notifier(object): def wait_for_replication(self, callback, timeout): """Wait for an event to happen. - :param callback: - Gets called whenever an event happens. If this returns a truthy - value then ``wait_for_replication`` returns, otherwise it waits - for another event. - :param int timeout: - How many milliseconds to wait for callback return a truthy value. - :returns: + Args: + callback: Gets called whenever an event happens. If this returns a + truthy value then ``wait_for_replication`` returns, otherwise + it waits for another event. + timeout: How many milliseconds to wait for callback return a truthy + value. + + Returns: A deferred that resolves with the value returned by the callback. """ listener = _NotificationListener(None) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 792af70eb7..6add94beeb 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -19,9 +19,11 @@ import copy def list_with_base_rules(rawrules): """Combine the list of rules set by the user with the default push rules - :param list rawrules: The rules the user has modified or set. - :returns: A new list with the rules set by the user combined with the - defaults. + Args: + rawrules(list): The rules the user has modified or set. + + Returns: + A new list with the rules set by the user combined with the defaults. """ ruleslist = [] diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index c5785d7074..60d3dc4030 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -199,15 +199,17 @@ class SyncRestServlet(RestServlet): """ Encode the joined rooms in a sync result - :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync - results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - - :return: the joined rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Args: + rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync + results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + + Returns: + dict[str, dict[str, object]]: the joined rooms list, in our + response format """ joined = {} for room in rooms: @@ -221,15 +223,17 @@ class SyncRestServlet(RestServlet): """ Encode the invited rooms in a sync result - :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing + Args: + rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Returns: + dict[str, dict[str, object]]: the invited rooms list, in our + response format """ invited = {} for room in rooms: @@ -251,15 +255,17 @@ class SyncRestServlet(RestServlet): """ Encode the archived rooms in a sync result - :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Args: + rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + + Returns: + dict[str, dict[str, object]]: The invited rooms list, in our + response format """ joined = {} for room in rooms: @@ -272,17 +278,18 @@ class SyncRestServlet(RestServlet): @staticmethod def encode_room(room, time_now, token_id, joined=True): """ - :param JoinedSyncResult|ArchivedSyncResult room: sync result for a - single room - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - :param joined: True if the user is joined to this room - will mean - we handle ephemeral events - - :return: the room, encoded in our response format - :rtype: dict[str, object] + Args: + room (JoinedSyncResult|ArchivedSyncResult): sync result for a + single room + time_now (int): current time - used as a baseline for age + calculations + token_id (int): ID of the user's auth token - used for namespacing + of transaction IDs + joined (bool): True if the user is joined to this room - will mean + we handle ephemeral events + + Returns: + dict[str, object]: the room, encoded in our response format """ def serialize(event): # TODO(mjark): Respect formatting requirements in the filter. diff --git a/synapse/state.py b/synapse/state.py index 41d32e664a..4a9e148de7 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -86,7 +86,8 @@ class StateHandler(object): If `event_type` is specified, then the method returns only the one event (or None) with that `event_type` and `state_key`. - :returns map from (type, state_key) to event + Returns: + map from (type, state_key) to event """ event_ids = yield self.store.get_latest_event_ids_in_room(room_id) @@ -176,10 +177,11 @@ class StateHandler(object): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. - :returns a Deferred tuple of (`state_group`, `state`, `prev_state`). - `state_group` is the name of a state group if one and only one is - involved. `state` is a map from (type, state_key) to event, and - `prev_state` is a list of event ids. + Returns: + a Deferred tuple of (`state_group`, `state`, `prev_state`). + `state_group` is the name of a state group if one and only one is + involved. `state` is a map from (type, state_key) to event, and + `prev_state` is a list of event ids. """ logger.debug("resolve_state_groups event_ids %s", event_ids) @@ -251,9 +253,10 @@ class StateHandler(object): def _resolve_events(self, state_sets, event_type=None, state_key=""): """ - :returns a tuple (new_state, prev_states). new_state is a map - from (type, state_key) to event. prev_states is a list of event_ids. - :rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str]) + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. """ with Measure(self.clock, "state._resolve_events"): state = {} diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index dc5830450a..3933b6e2c5 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -26,8 +26,9 @@ logger = logging.getLogger(__name__) class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ - :param event: the event set actions for - :param tuples: list of tuples of (user_id, actions) + Args: + event: the event set actions for + tuples: list of tuples of (user_id, actions) """ values = [] for uid, actions in tuples: diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bd4eb88a92..d46a963bb8 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -458,12 +458,15 @@ class RegistrationStore(SQLBaseStore): """ Gets the 3pid's guest access token if exists, else saves access_token. - :param medium (str): Medium of the 3pid. Must be "email". - :param address (str): 3pid address. - :param access_token (str): The access token to persist if none is - already persisted. - :param inviter_user_id (str): User ID of the inviter. - :return (deferred str): Whichever access token is persisted at the end + Args: + medium (str): Medium of the 3pid. Must be "email". + address (str): 3pid address. + access_token (str): The access token to persist if none is + already persisted. + inviter_user_id (str): User ID of the inviter. + + Returns: + deferred str: Whichever access token is persisted at the end of this function call. """ def insert(txn): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7fc9a4f264..f84fd0e30a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -249,11 +249,14 @@ class StateStore(SQLBaseStore): """ Get the state dict corresponding to a particular event - :param str event_id: event whose state should be returned - :param list[(str, str)]|None types: List of (type, state_key) tuples - which are used to filter the state fetched. May be None, which - matches any key - :return: a deferred dict from (type, state_key) -> state_event + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) -- cgit 1.4.1 From d76d89323c4b962b4f3ff72a3c9b40f2d2d347b3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 Apr 2016 17:39:32 +0100 Subject: Use computed prev event ids --- synapse/handlers/_base.py | 29 +++++++++++++++++------------ synapse/handlers/message.py | 6 +++++- synapse/handlers/room_member.py | 3 +++ synapse/storage/event_federation.py | 16 ++++++++++++++++ 4 files changed, 41 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 2d8ff79a9e..242afa1564 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -199,20 +199,25 @@ class BaseHandler(object): ) @defer.inlineCallbacks - def _create_new_client_event(self, builder): - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, - ) - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 + def _create_new_client_event(self, builder, prev_event_ids=None): + if prev_event_ids: + prev_events = yield self.store.add_event_hashes(prev_event_ids) + prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) + depth = prev_max_depth + 1 else: - depth = 1 + latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( + builder.room_id, + ) - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] + if latest_ret: + depth = max([d for _, _, d in latest_ret]) + 1 + else: + depth = 1 + + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in latest_ret + ] builder.prev_events = prev_events builder.depth = depth diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0bb111d047..10608c0dd9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -176,7 +176,7 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_event(self, event_dict, token_id=None, txn_id=None): + def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None): """ Given a dict from a client, create a new event. @@ -187,6 +187,9 @@ class MessageHandler(BaseHandler): Args: event_dict (dict): An entire event + token_id (str) + txn_id (str) + prev_event_ids (list): The prev event ids to use when creating the event Returns: Tuple of created event (FrozenEvent), Context @@ -225,6 +228,7 @@ class MessageHandler(BaseHandler): event, context = yield self._create_new_client_event( builder=builder, + prev_event_ids=prev_event_ids, ) defer.returnValue((event, context)) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 09b8f6217a..7c1bb8cfe4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -98,6 +98,7 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _local_membership_update( self, requester, target, room_id, membership, + prev_event_ids, txn_id=None, ratelimit=True, ): @@ -120,6 +121,7 @@ class RoomMemberHandler(BaseHandler): }, token_id=requester.access_token_id, txn_id=txn_id, + prev_event_ids=prev_event_ids, ) yield self.handle_new_client_event( @@ -268,6 +270,7 @@ class RoomMemberHandler(BaseHandler): membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, + prev_event_ids=latest_event_ids, ) @defer.inlineCallbacks diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 3489315e0d..0827946207 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -163,6 +163,22 @@ class EventFederationStore(SQLBaseStore): room_id, ) + @defer.inlineCallbacks + def get_max_depth_of_events(self, event_ids): + sql = ( + "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" + ) % (",".join(["?"] * len(event_ids)),) + + rows = yield self._execute( + "get_max_depth_of_events", None, + sql, *event_ids + ) + + if rows: + defer.returnValue(rows[0][0]) + else: + defer.returnValue(1) + def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, -- cgit 1.4.1 From 3d76b7cb2ba05fbf17be0a6647f39c419f428c16 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2016 15:52:01 +0100 Subject: Store invites in a separate table. --- synapse/handlers/room_member.py | 2 +- synapse/storage/events.py | 13 +--- synapse/storage/prepare_database.py | 2 +- synapse/storage/roommember.py | 111 ++++++++++++++++++++++------ synapse/storage/schema/delta/31/invites.sql | 28 +++++++ 5 files changed, 124 insertions(+), 32 deletions(-) create mode 100644 synapse/storage/schema/delta/31/invites.sql (limited to 'synapse/storage') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7c1bb8cfe4..98e346d48e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -400,7 +400,7 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def get_inviter(self, user_id, room_id): - invite = yield self.store.get_room_member(user_id=user_id, room_id=room_id) + invite = yield self.store.get_inviter(user_id=user_id, room_id=room_id) if invite: defer.returnValue(UserID.from_string(invite.sender)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c4dc3b3d51..5d299a1132 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -367,7 +367,8 @@ class EventsStore(SQLBaseStore): event for event, _ in events_and_contexts if event.type == EventTypes.Member - ] + ], + backfilled=backfilled, ) def event_dict(event): @@ -485,14 +486,8 @@ class EventsStore(SQLBaseStore): return for event, _ in state_events_and_contexts: - if (not event.internal_metadata.is_invite_from_remote() - and event.internal_metadata.is_outlier()): - # Outlier events generally shouldn't clobber the current state. - # However invites from remote severs for rooms we aren't in - # are a bit special: they don't come with any associated - # state so are technically an outlier, however all the - # client-facing code assumes that they are in the current - # state table so we insert the event anyway. + if event.internal_metadata.is_outlier(): + # Outlier events shouldn't clobber the current state. continue if context.rejected: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 3f29aad1e8..4099387ba7 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 30 +SCHEMA_VERSION = 31 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 430b49c12e..4c026b33ae 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -36,7 +36,7 @@ RoomsForUser = namedtuple( class RoomMemberStore(SQLBaseStore): - def _store_room_members_txn(self, txn, events): + def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ self._simple_insert_many_txn( @@ -62,6 +62,41 @@ class RoomMemberStore(SQLBaseStore): self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) + + is_mine = self.hs.is_mine_id(event.state_key) + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_invite_from_remote() + ) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + } + ) + else: + sql = ( + "UPDATE invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + )) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. @@ -127,6 +162,14 @@ class RoomMemberStore(SQLBaseStore): user_id, [Membership.INVITE] ) + @defer.inlineCallbacks + def get_inviter(self, user_id, room_id): + invites = yield self.get_invited_rooms_for_user(user_id) + for invite in invites: + if invite.room_id == room_id: + defer.returnValue(invite) + defer.returnValue(None) + def get_leave_and_ban_events_for_user(self, user_id): """ Get all the leave events for a user Args: @@ -163,29 +206,55 @@ class RoomMemberStore(SQLBaseStore): def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, membership_list): - where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( - " OR ".join(["membership = ?" for _ in membership_list]), - ) - args = [user_id] - args.extend(membership_list) + do_invite = Membership.INVITE in membership_list + membership_list = [m for m in membership_list if m != Membership.INVITE] - sql = ( - "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" - " FROM current_state_events as c" - " INNER JOIN room_memberships as m" - " ON m.event_id = c.event_id" - " INNER JOIN events as e" - " ON e.event_id = c.event_id" - " AND m.room_id = c.room_id" - " AND m.user_id = c.state_key" - " WHERE %s" - ) % (where_clause,) + results = [] + if membership_list: + where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( + " OR ".join(["membership = ?" for _ in membership_list]), + ) + + args = [user_id] + args.extend(membership_list) + + sql = ( + "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" + " FROM current_state_events as c" + " INNER JOIN room_memberships as m" + " ON m.event_id = c.event_id" + " INNER JOIN events as e" + " ON e.event_id = c.event_id" + " AND m.room_id = c.room_id" + " AND m.user_id = c.state_key" + " WHERE %s" + ) % (where_clause,) + + txn.execute(sql, args) + results = [ + RoomsForUser(**r) for r in self.cursor_to_dict(txn) + ] + + if do_invite: + sql = ( + "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" + " FROM invites as i" + " INNER JOIN events as e USING (event_id)" + " WHERE invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, (user_id,)) + results.extend(RoomsForUser( + room_id=r["room_id"], + sender=r["inviter"], + event_id=r["event_id"], + stream_ordering=r["stream_ordering"], + membership=Membership.INVITE, + ) for r in self.cursor_to_dict(txn)) - txn.execute(sql, args) - return [ - RoomsForUser(**r) for r in self.cursor_to_dict(txn) - ] + return results @cached(max_entries=5000) def get_joined_hosts_for_room(self, room_id): diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql new file mode 100644 index 0000000000..4f6fb9ea63 --- /dev/null +++ b/synapse/storage/schema/delta/31/invites.sql @@ -0,0 +1,28 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE invites( + stream_id BIGINT NOT NULL, + inviter TEXT NOT NULL, + invitee TEXT NOT NULL, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + locally_rejected TEXT, + replaced_by TEXT +); + +CREATE INDEX invites_id ON invites(stream_id); +CREATE INDEX invites_for_user_idx ON invites(invitee, locally_rejected, replaced_by, room_id); -- cgit 1.4.1 From 92ab45a330c2d6c4e896786135e93b6cabfad1ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2016 17:07:43 +0100 Subject: Add upgrade path, rename table --- synapse/storage/roommember.py | 6 +++--- synapse/storage/schema/delta/31/invites.sql | 20 +++++++++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 4c026b33ae..abe5942744 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -75,7 +75,7 @@ class RoomMemberStore(SQLBaseStore): if event.membership == Membership.INVITE: self._simple_insert_txn( txn, - table="invites", + table="local_invites", values={ "event_id": event.event_id, "invitee": event.state_key, @@ -86,7 +86,7 @@ class RoomMemberStore(SQLBaseStore): ) else: sql = ( - "UPDATE invites SET stream_id = ?, replaced_by = ? WHERE" + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" " room_id = ? AND invitee = ? AND locally_rejected is NULL" " AND replaced_by is NULL" ) @@ -239,7 +239,7 @@ class RoomMemberStore(SQLBaseStore): if do_invite: sql = ( "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" - " FROM invites as i" + " FROM local_invites as i" " INNER JOIN events as e USING (event_id)" " WHERE invitee = ? AND locally_rejected is NULL" " AND replaced_by is NULL" diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql index 4f6fb9ea63..1c83430da4 100644 --- a/synapse/storage/schema/delta/31/invites.sql +++ b/synapse/storage/schema/delta/31/invites.sql @@ -14,7 +14,7 @@ */ -CREATE TABLE invites( +CREATE TABLE local_invites( stream_id BIGINT NOT NULL, inviter TEXT NOT NULL, invitee TEXT NOT NULL, @@ -24,5 +24,19 @@ CREATE TABLE invites( replaced_by TEXT ); -CREATE INDEX invites_id ON invites(stream_id); -CREATE INDEX invites_for_user_idx ON invites(invitee, locally_rejected, replaced_by, room_id); +-- Insert all invites for local users into new `invites` table +INSERT INTO local_invites SELECT + stream_ordering as stream_id, + sender as inviter, + state_key as invitee, + event_id, + room_id, + NULL as locally_rejected, + NULL as replaced_by +FROM events +NATURAL JOIN current_state_events +NATURAL JOIN room_memberships +WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); + +CREATE INDEX local_invites_id ON local_invites(stream_id); +CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); -- cgit 1.4.1 From 0c53d750e7145f57ed97c544efdb846cc9e37b67 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2016 18:02:48 +0100 Subject: Docs and indents --- synapse/handlers/room_member.py | 5 ++++- synapse/storage/roommember.py | 18 ++++++++++++++++-- synapse/storage/schema/delta/31/invites.sql | 22 +++++++++++----------- 3 files changed, 31 insertions(+), 14 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 98e346d48e..f1c3e90ecd 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -400,7 +400,10 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def get_inviter(self, user_id, room_id): - invite = yield self.store.get_inviter(user_id=user_id, room_id=room_id) + invite = yield self.store.get_invite_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) if invite: defer.returnValue(UserID.from_string(invite.sender)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index abe5942744..36456a75fc 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -66,11 +66,15 @@ class RoomMemberStore(SQLBaseStore): self.get_invited_rooms_for_user.invalidate, (event.state_key,) ) - is_mine = self.hs.is_mine_id(event.state_key) + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. + # The only current event that can also be an outlier is if its an + # invite that has come in across federation. is_new_state = not backfilled and ( not event.internal_metadata.is_outlier() or event.internal_metadata.is_invite_from_remote() ) + is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: self._simple_insert_txn( @@ -163,7 +167,17 @@ class RoomMemberStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_inviter(self, user_id, room_id): + def get_invite_for_user_in_room(self, user_id, room_id): + """Gets the invite for the given user and room + + Args: + user_id (str) + room_id (str) + + Returns: + Deferred: Resolves to either a RoomsForUser or None if no invite was + found. + """ invites = yield self.get_invited_rooms_for_user(user_id) for invite in invites: if invite.room_id == room_id: diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/schema/delta/31/invites.sql index 1c83430da4..2c57846d5a 100644 --- a/synapse/storage/schema/delta/31/invites.sql +++ b/synapse/storage/schema/delta/31/invites.sql @@ -26,17 +26,17 @@ CREATE TABLE local_invites( -- Insert all invites for local users into new `invites` table INSERT INTO local_invites SELECT - stream_ordering as stream_id, - sender as inviter, - state_key as invitee, - event_id, - room_id, - NULL as locally_rejected, - NULL as replaced_by -FROM events -NATURAL JOIN current_state_events -NATURAL JOIN room_memberships -WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); + stream_ordering as stream_id, + sender as inviter, + state_key as invitee, + event_id, + room_id, + NULL as locally_rejected, + NULL as replaced_by + FROM events + NATURAL JOIN current_state_events + NATURAL JOIN room_memberships + WHERE membership = 'invite' AND state_key IN (SELECT name FROM users); CREATE INDEX local_invites_id ON local_invites(stream_id); CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id); -- cgit 1.4.1 From df727f212606b771b1410c8e322fb8a99d159de4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Apr 2016 11:13:24 +0100 Subject: Fix stuck invites If rejecting a remote invite fails with an error response don't fail the entire request; instead mark the invite as locally rejected. This fixes the bug where users can get stuck invites which they can neither accept nor reject. --- synapse/handlers/federation.py | 34 +++++++++++++++++++++++----------- synapse/handlers/room_member.py | 18 ++++++++++++++---- synapse/storage/__init__.py | 3 ++- synapse/storage/roommember.py | 19 +++++++++++++++++++ 4 files changed, 58 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4049c01d26..19769eecd7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -784,13 +784,19 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): - origin, event = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" - ) - signed_event = self._sign_event(event) + try: + origin, event = yield self._make_and_verify_event( + target_hosts, + room_id, + user_id, + "leave" + ) + signed_event = self._sign_event(event) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") # Try the host we successfully got a response to /make_join/ # request first. @@ -800,10 +806,16 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.replication_layer.send_leave( - target_hosts, - signed_event - ) + try: + yield self.replication_layer.send_leave( + target_hosts, + signed_event + ) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") context = yield self.state_handler.compute_event_context(event) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index f1c3e90ecd..6c7409215a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -258,10 +258,20 @@ class RoomMemberHandler(BaseHandler): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - ret = yield self.reject_remote_invite( - target.to_string(), room_id, remote_room_hosts - ) - defer.returnValue(ret) + + try: + ret = yield self.reject_remote_invite( + target.to_string(), room_id, remote_room_hosts + ) + defer.returnValue(ret) + except SynapseError as e: + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + target.to_string(), room_id + ) + + defer.returnValue({}) yield self._local_membership_update( requester=requester, diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 57863bba4d..07916b292d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -94,7 +94,8 @@ class DataStore(RoomMemberStore, RoomStore, ) self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering" + db_conn, "events", "stream_ordering", + extra_tables=[("local_invites", "stream_id")] ) self._backfill_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", step=-1 diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 36456a75fc..66e7a40e3c 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -102,6 +102,25 @@ class RoomMemberStore(SQLBaseStore): event.state_key, )) + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, ( + stream_ordering, + True, + room_id, + user_id, + )) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. -- cgit 1.4.1 From a1e0d316ea354fce07939073d9afc9c5d1013939 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 13:05:19 +0100 Subject: Move _get_cache_dict into the SQLBaseStore --- synapse/storage/__init__.py | 33 --------------------------------- synapse/storage/_base.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 33 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 07916b292d..045ae6c03f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -177,39 +177,6 @@ class DataStore(RoomMemberStore, RoomStore, self.__presence_on_startup = None return active_on_startup - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - 100000" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - } - - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - rows = txn.fetchall() - txn.close() - - cache = { - row[0]: int(row[1]) - for row in rows - } - - if cache: - min_val = min(cache.values()) - else: - min_val = max_value - - return cache, min_val - def _get_active_presence(self, db_conn): """Fetch non-offline presence from the database so that we can register the appropriate time outs. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b75b79df36..04d7fcf6d6 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -816,6 +816,40 @@ class SQLBaseStore(object): self._next_stream_id += 1 return i + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, + max_value): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - 100000" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + } + + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + rows = txn.fetchall() + txn.close() + + cache = { + row[0]: int(row[1]) + for row in rows + } + + if cache: + min_val = min(cache.values()) + else: + min_val = max_value + + return cache, min_val + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying -- cgit 1.4.1 From 87f2dec8d475f038beb138bc56e3ef76fcb83ec6 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 13:08:05 +0100 Subject: Make the cache objects be per instance rather than being global --- synapse/storage/receipts.py | 4 ++-- synapse/storage/registration.py | 2 +- synapse/storage/state.py | 4 ++-- synapse/util/caches/descriptors.py | 45 ++++++++++++++++++++------------------ 4 files changed, 29 insertions(+), 26 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4befebc8e2..7fdd84bbdc 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", - num_args=3, inlineCallbacks=True) + @cachedList(cached_method_name="get_linearized_receipts_for_room", + list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: defer.returnValue({}) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d46a963bb8..1f71773aaa 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) - @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, + @cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1, inlineCallbacks=True) def are_guests(self, user_ids): sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e9f9406014..c5d2a3a6df 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -273,8 +273,8 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1, inlineCallbacks=True) + @cachedList(cached_method_name="_get_state_group_for_event", + list_name="event_ids", num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 35544b19fd..758f5982b0 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -167,7 +167,8 @@ class CacheDescriptor(object): % (orig.__name__,) ) - self.cache = Cache( + def __get__(self, obj, objtype=None): + cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, @@ -175,14 +176,12 @@ class CacheDescriptor(object): tree=self.tree, ) - def __get__(self, obj, objtype=None): - @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = self.cache.get(cache_key) + cached_result_d = cache.get(cache_key) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -204,7 +203,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence + sequence = cache.sequence ret = defer.maybeDeferred( preserve_context_over_fn, @@ -213,20 +212,21 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(cache_key) + cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret) return preserve_context_over_deferred(ret.observe()) - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.invalidate_many = self.cache.invalidate_many - wrapped.prefill = self.cache.prefill + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -240,11 +240,12 @@ class CacheListDescriptor(object): the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + def __init__(self, orig, cached_method_name, list_name, num_args=1, + inlineCallbacks=False): """ Args: orig (function) - cache (Cache) + method_name (str); The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int) inlineCallbacks (bool): Whether orig is a generator that should @@ -263,7 +264,7 @@ class CacheListDescriptor(object): self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cache = cache + self.cached_method_name = cached_method_name self.sentinel = object() @@ -277,11 +278,13 @@ class CacheListDescriptor(object): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) + % (self.list_name, cached_method_name,) ) def __get__(self, obj, objtype=None): + cache = getattr(obj, self.cached_method_name).cache + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) @@ -297,14 +300,14 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = self.cache.get(tuple(key)).observe() + res = cache.get(tuple(key)).observe() res.addCallback(lambda r, arg: (arg, r), arg) cached[arg] = res except KeyError: missing.append(arg) if missing: - sequence = self.cache.sequence + sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -327,10 +330,10 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) + cache.update(sequence, tuple(key), observer) def invalidate(f, key): - self.cache.invalidate(key) + cache.invalidate(key) return f observer.addErrback(invalidate, tuple(key)) @@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): ) -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): """ return lambda orig: CacheListDescriptor( orig, - cache=cache, + cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, inlineCallbacks=inlineCallbacks, -- cgit 1.4.1 From 8aab9d87fa6739345810f0edf3982fe7f898ee30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Apr 2016 14:08:18 +0100 Subject: Don't require config to create database --- scripts/synapse_port_db | 13 ++--- synapse/app/homeserver.py | 15 +++-- synapse/storage/engines/__init__.py | 6 +- synapse/storage/engines/postgres.py | 8 +-- synapse/storage/engines/sqlite3.py | 13 +---- synapse/storage/prepare_database.py | 64 +++++++--------------- .../schema/delta/14/upgrade_appservice_db.py | 6 +- synapse/storage/schema/delta/20/pushers.py | 6 +- synapse/storage/schema/delta/25/fts.py | 6 +- synapse/storage/schema/delta/27/ts.py | 6 +- synapse/storage/schema/delta/30/as_users.py | 4 +- tests/storage/test_base.py | 2 +- tests/utils.py | 6 +- 13 files changed, 69 insertions(+), 86 deletions(-) (limited to 'synapse/storage') diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index a2a0f364cf..253a6ef6c7 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -19,6 +19,7 @@ from twisted.enterprise import adbapi from synapse.storage._base import LoggingTransaction, SQLBaseStore from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database import argparse import curses @@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = { "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], + "presence_stream": ["currently_active"], } @@ -292,7 +294,7 @@ class Porter(object): } ) - database_engine.prepare_database(db_conn) + prepare_database(db_conn, database_engine, config=None) db_conn.commit() @@ -309,8 +311,8 @@ class Porter(object): **self.postgres_config["args"] ) - sqlite_engine = create_engine(FakeConfig(sqlite_config)) - postgres_engine = create_engine(FakeConfig(postgres_config)) + sqlite_engine = create_engine(sqlite_config) + postgres_engine = create_engine(postgres_config) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine) @@ -792,8 +794,3 @@ if __name__ == "__main__": if end_error_exec_info: exc_type, exc_value, exc_traceback = end_error_exec_info traceback.print_exception(exc_type, exc_value, exc_traceback) - - -class FakeConfig: - def __init__(self, database_config): - self.database_config = database_config diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index fcdc8e6e10..2b4473b9ac 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -33,7 +33,7 @@ from synapse.python_dependencies import ( from synapse.rest import ClientRestResource from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import are_all_users_on_domain -from synapse.storage.prepare_database import UpgradeDatabaseException +from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.server import HomeServer @@ -245,7 +245,7 @@ class SynapseHomeServer(HomeServer): except IncorrectDatabaseSetup as e: quit_with_error(e.message) - def get_db_conn(self): + def get_db_conn(self, run_new_connection=True): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. db_params = { @@ -254,7 +254,8 @@ class SynapseHomeServer(HomeServer): } db_conn = self.database_engine.module.connect(**db_params) - self.database_engine.on_new_connection(db_conn) + if run_new_connection: + self.database_engine.on_new_connection(db_conn) return db_conn @@ -386,7 +387,7 @@ def setup(config_options): tls_server_context_factory = context_factory.ServerContextFactory(config) - database_engine = create_engine(config) + database_engine = create_engine(config.database_config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection hs = SynapseHomeServer( @@ -402,8 +403,10 @@ def setup(config_options): logger.info("Preparing database: %s...", config.database_config['name']) try: - db_conn = hs.get_db_conn() - database_engine.prepare_database(db_conn) + db_conn = hs.get_db_conn(run_new_connection=False) + prepare_database(db_conn, database_engine, config=config) + database_engine.on_new_connection(db_conn) + hs.run_startup_checks(db_conn, database_engine) db_conn.commit() diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index a48230b93f..7bb5de1fe7 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -26,13 +26,13 @@ SUPPORTED_MODULE = { } -def create_engine(config): - name = config.database_config["name"] +def create_engine(database_config): + name = database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: module = importlib.import_module(name) - return engine_class(module, config=config) + return engine_class(module) raise RuntimeError( "Unsupported database engine '%s'" % (name,) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a09685b4df..c2290943b4 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import prepare_database - from ._base import IncorrectDatabaseSetup class PostgresEngine(object): single_threaded = False - def __init__(self, database_module, config): + def __init__(self, database_module): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) - self.config = config def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -44,9 +41,6 @@ class PostgresEngine(object): self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) - def prepare_database(self, db_conn): - prepare_database(db_conn, self, config=self.config) - def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): return error.pgcode in ["40001", "40P01"] diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 522b905949..14203aa500 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.prepare_database import ( - prepare_database, prepare_sqlite3_database -) +from synapse.storage.prepare_database import prepare_database import struct @@ -23,9 +21,8 @@ import struct class Sqlite3Engine(object): single_threaded = True - def __init__(self, database_module, config): + def __init__(self, database_module): self.module = database_module - self.config = config def check_database(self, txn): pass @@ -34,13 +31,9 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - self.prepare_database(db_conn) + prepare_database(db_conn, self, config=None) db_conn.create_function("rank", 1, _rank) - def prepare_database(self, db_conn): - prepare_sqlite3_database(db_conn) - prepare_database(db_conn, self, config=self.config) - def is_deadlock(self, error): return False diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 4099387ba7..00833422af 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -53,6 +53,9 @@ class UpgradeDatabaseException(PrepareDatabaseException): def prepare_database(db_conn, database_engine, config): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. + + If `config` is None then prepare_database will assert that no upgrade is + necessary, *or* will create a fresh database if the database is empty. """ try: cur = db_conn.cursor() @@ -60,13 +63,18 @@ def prepare_database(db_conn, database_engine, config): if version_info: user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine, config - ) - else: - _setup_new_database(cur, database_engine, config) - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + if config is None: + if user_version != SCHEMA_VERSION: + # If we don't pass in a config file then we are expecting to + # have already upgraded the DB. + raise UpgradeDatabaseException("Database needs to be upgraded") + else: + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine, config + ) + else: + _setup_new_database(cur, database_engine) cur.close() db_conn.commit() @@ -75,7 +83,7 @@ def prepare_database(db_conn, database_engine, config): raise -def _setup_new_database(cur, database_engine, config): +def _setup_new_database(cur, database_engine): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -148,12 +156,13 @@ def _setup_new_database(cur, database_engine, config): applied_delta_files=[], upgraded=False, database_engine=database_engine, - config=config, + config=None, + is_empty=True, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine, config): + upgraded, database_engine, config, is_empty=False): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -246,7 +255,9 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module_name, absolute_path, python_file ) logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine, config=config) + module.run_create(cur, database_engine) + if not is_empty: + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package @@ -361,36 +372,3 @@ def _get_or_create_schema_state(txn, database_engine): return current_version, applied_deltas, upgraded return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 5c40a77757..8755bb2e49 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -18,7 +18,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, *args, **kwargs): +def run_create(cur, *args, **kwargs): cur.execute("SELECT id, regex FROM application_services_regex") for row in cur.fetchall(): try: @@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs): "UPDATE application_services_regex SET regex=? WHERE id=?", (new_regex, row[0]) ) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py index 29164732af..147496a38b 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/schema/delta/20/pushers.py @@ -27,7 +27,7 @@ import logging logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table...") cur.execute(""" CREATE TABLE IF NOT EXISTS pushers2 ( @@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute("DROP TABLE pushers") cur.execute("ALTER TABLE pushers2 RENAME TO pushers") logger.info("Moved %d pushers to new table", count) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index d3ff2b1779..4269ac69ad 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -43,7 +43,7 @@ SQLITE_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): for statement in get_statements(POSTGRES_TABLE.splitlines()): cur.execute(statement) @@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_search", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index f8c16391a2..71b12a2731 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -27,7 +27,7 @@ ALTER_TABLE = ( ) -def run_upgrade(cur, database_engine, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): for statement in get_statements(ALTER_TABLE.splitlines()): cur.execute(statement) @@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): sql = database_engine.convert_param_style(sql) cur.execute(sql, ("event_origin_server_ts", progress_json)) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 4f6e9dd540..b417e3ac08 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore logger = logging.getLogger(__name__) -def run_upgrade(cur, database_engine, config, *args, **kwargs): +def run_create(cur, database_engine, *args, **kwargs): # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") @@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): # Maybe we already added the column? Hope so... pass + +def run_upgrade(cur, database_engine, config, *args, **kwargs): cur.execute("SELECT name FROM users") rows = cur.fetchall() diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 2e33beb07c..afbefb2e2d 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -53,7 +53,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): "test", db_pool=self.db_pool, config=config, - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), ) self.datastore = SQLBaseStore(hs) diff --git a/tests/utils.py b/tests/utils.py index 52405502e9..c179df31ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -64,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=db_pool, config=config, version_string="Synapse/tests", - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), get_db_conn=db_pool.get_db_conn, **kargs ) @@ -73,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine(config), + database_engine=create_engine(config.database_config), **kargs ) @@ -298,7 +298,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object): return conn def create_engine(self): - return create_engine(self.config) + return create_engine(self.config.database_config) class MemoryDataStore(object): -- cgit 1.4.1 From 75fb9ac1be0fada60cdde38153ac0e3fe2b19a0a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 14:12:51 +0100 Subject: Add a slaved events store class Add a test to check that get_room_names_and_aliases does the same thing on both the master and on the slave data store. --- synapse/replication/slave/__init__.py | 14 ++ synapse/replication/slave/storage/__init__.py | 14 ++ synapse/replication/slave/storage/_base.py | 28 +++ .../slave/storage/_slaved_id_tracker.py | 30 ++++ synapse/replication/slave/storage/events.py | 198 +++++++++++++++++++++ synapse/storage/events.py | 4 +- tests/replication/slave/__init__.py | 14 ++ tests/replication/slave/storage/__init__.py | 14 ++ tests/replication/slave/storage/_base.py | 57 ++++++ tests/replication/slave/storage/test_events.py | 114 ++++++++++++ 10 files changed, 485 insertions(+), 2 deletions(-) create mode 100644 synapse/replication/slave/__init__.py create mode 100644 synapse/replication/slave/storage/__init__.py create mode 100644 synapse/replication/slave/storage/_base.py create mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py create mode 100644 synapse/replication/slave/storage/events.py create mode 100644 tests/replication/slave/__init__.py create mode 100644 tests/replication/slave/storage/__init__.py create mode 100644 tests/replication/slave/storage/_base.py create mode 100644 tests/replication/slave/storage/test_events.py (limited to 'synapse/storage') diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py new file mode 100644 index 0000000000..46e43ce1c7 --- /dev/null +++ b/synapse/replication/slave/storage/_base.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage._base import SQLBaseStore +from twisted.internet import defer + + +class BaseSlavedStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(BaseSlavedStore, self).__init__(hs) + + def stream_positions(self): + return {} + + def process_replication(self, result): + return defer.succeed(None) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py new file mode 100644 index 0000000000..24b5c79d4a --- /dev/null +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.util.id_generators import _load_current_id + + +class SlavedIdTracker(object): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): + self.step = step + self._current = _load_current_id(db_conn, table, column, step) + for table, column in extra_tables: + self.advance(_load_current_id(db_conn, table, column)) + + def advance(self, new_id): + self._current = (max if self.step > 0 else min)(self._current, new_id) + + def get_current_token(self): + return self._current diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py new file mode 100644 index 0000000000..68b924e37b --- /dev/null +++ b/synapse/replication/slave/storage/events.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.api.constants import EventTypes +from synapse.events import FrozenEvent +from synapse.storage import DataStore +from synapse.storage.room import RoomStore +from synapse.storage.roommember import RoomMemberStore +from synapse.storage.event_federation import EventFederationStore +from synapse.storage.state import StateStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +import ujson as json + +# So, um, we want to borrow a load of functions intended for reading from +# a DataStore, but we don't want to take functions that either write to the +# DataStore or are cached and don't have cache invalidation logic. +# +# Rather than write duplicate versions of those functions, or lift them to +# a common base class, we going to grab the underlying __func__ object from +# the method descriptor on the DataStore and chuck them into our class. + + +class SlavedEventStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedEventStore, self).__init__(db_conn, hs) + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + events_max = self._stream_id_gen.get_current_token() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + + # Cached functions can't be accessed through a class instance so we need + # to reach inside the __dict__ to extract them. + get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"] + get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] + get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_latest_event_ids_in_room = EventFederationStore.__dict__[ + "get_latest_event_ids_in_room" + ] + _get_current_state_for_key = StateStore.__dict__[ + "_get_current_state_for_key" + ] + + get_current_state = DataStore.get_current_state.__func__ + get_current_state_for_key = DataStore.get_current_state_for_key.__func__ + _get_rooms_for_user_where_membership_is_txn = ( + DataStore._get_rooms_for_user_where_membership_is_txn.__func__ + ) + get_rooms_for_user_where_membership_is = ( + DataStore.get_rooms_for_user_where_membership_is.__func__ + ) + get_membership_changes_for_user = ( + DataStore.get_membership_changes_for_user.__func__ + ) + get_room_events_max_id = DataStore.get_room_events_max_id.__func__ + get_room_events_stream_for_room = ( + DataStore.get_room_events_stream_for_room.__func__ + ) + _set_before_and_after = DataStore._set_before_and_after + + _get_events = DataStore._get_events.__func__ + _get_events_from_cache = DataStore._get_events_from_cache.__func__ + + _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ + _parse_events_txn = DataStore._parse_events_txn.__func__ + _get_events_txn = DataStore._get_events_txn.__func__ + _fetch_events_txn = DataStore._fetch_events_txn.__func__ + _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ + + def stream_positions(self): + result = super(SlavedEventStore, self).stream_positions() + result["events"] = self._stream_id_gen.get_current_token() + result["backfilled"] = self._backfill_id_gen.get_current_token() + return result + + def process_replication(self, result): + state_resets = set( + r[0] for r in result.get("state_resets", {"rows": []})["rows"] + ) + + stream = result.get("events") + if stream: + self._stream_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=False, state_resets=state_resets + ) + + stream = result.get("backfill") + if stream: + self._backfill_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=True, state_resets=state_resets + ) + + stream = result.get("forward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + stream = result.get("backward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + return super(SlavedEventStore, self).process_replication(result) + + def _process_replication_row(self, row, backfilled, state_resets): + position = row[0] + internal = json.loads(row[1]) + event_json = json.loads(row[2]) + + event = FrozenEvent(event_json, internal_metadata_dict=internal) + self._invalidate_caches_for_event( + event, backfilled, reset_state=position in state_resets + ) + + def _invalidate_caches_for_event(self, event, backfilled, reset_state): + if reset_state: + self._get_current_state_for_key.invalidate_all() + self.get_rooms_for_user.invalidate_all() + self.get_users_in_room.invalidate((event.room_id,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_room_name_and_aliases.invalidate((event.room_id,)) + + self._invalidate_get_event_cache(event.event_id) + + if not backfilled: + self._events_stream_cache.entity_has_changed( + event.room_id, event.internal_metadata.stream_ordering + ) + + # self.get_unread_event_push_actions_by_room_for_user.invalidate_many( + # (event.room_id,) + # ) + + if event.type == EventTypes.Redaction: + self._invalidate_get_event_cache(event.redacts) + + if event.type == EventTypes.Member: + self.get_rooms_for_user.invalidate((event.state_key,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_users_in_room.invalidate((event.room_id,)) + # self._membership_stream_cache.entity_has_changed( + # event.state_key, event.internal_metadata.stream_ordering + # ) + + if not event.is_state(): + return + + if backfilled: + return + + if (not event.internal_metadata.is_invite_from_remote() + and event.internal_metadata.is_outlier()): + return + + self._get_current_state_for_key.invalidate(( + event.room_id, event.type, event.state_key + )) + + if event.type in [EventTypes.Name, EventTypes.Aliases]: + self.get_room_name_and_aliases.invalidate( + (event.room_id,) + ) + pass diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5d299a1132..ee87a71719 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1134,7 +1134,7 @@ class EventsStore(SQLBaseStore): upper_bound = current_forward_id sql = ( - "SELECT -event_stream_ordering FROM current_state_resets" + "SELECT event_stream_ordering FROM current_state_resets" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering ASC" @@ -1143,7 +1143,7 @@ class EventsStore(SQLBaseStore): state_resets = txn.fetchall() sql = ( - "SELECT -event_stream_ordering, event_id, state_group" + "SELECT event_stream_ordering, event_id, state_group" " FROM ex_outlier_stream" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" diff --git a/tests/replication/slave/__init__.py b/tests/replication/slave/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/__init__.py b/tests/replication/slave/storage/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py new file mode 100644 index 0000000000..0f525a8943 --- /dev/null +++ b/tests/replication/slave/storage/_base.py @@ -0,0 +1,57 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from tests import unittest + +from synapse.replication.slave.storage.events import SlavedEventStore + +from mock import Mock, NonCallableMock +from tests.utils import setup_test_homeserver +from synapse.replication.resource import ReplicationResource + + +class BaseSlavedStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + "blue", + http_client=None, + replication_layer=Mock(), + ratelimiter=NonCallableMock(spec_set=[ + "send_message", + ]), + ) + self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + self.replication = ReplicationResource(self.hs) + + self.master_store = self.hs.get_datastore() + self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs) + self.event_id = 0 + + @defer.inlineCallbacks + def replicate(self): + streams = self.slaved_store.stream_positions() + result = yield self.replication.replicate(streams, 100) + yield self.slaved_store.process_replication(result) + + @defer.inlineCallbacks + def check(self, method, args, expected_result=None): + master_result = yield getattr(self.master_store, method)(*args) + slaved_result = yield getattr(self.slaved_store, method)(*args) + self.assertEqual(master_result, slaved_result) + if expected_result is not None: + self.assertEqual(master_result, expected_result) + self.assertEqual(slaved_result, expected_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py new file mode 100644 index 0000000000..c30c7c6063 --- /dev/null +++ b/tests/replication/slave/storage/test_events.py @@ -0,0 +1,114 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStoreTestCase + +from synapse.types import UserID +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext + +from twisted.internet import defer + +USER_ID = "@feeling:blue" +USER = UserID.from_string(USER_ID) +OUTLIER = {"outlier": True} +ROOM_ID = "!room:blue" + + +class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + + @defer.inlineCallbacks + def test_room_name_and_aliases(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + yield self.persist(type="m.room.name", key="", name="name1") + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#1:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"]) + ) + + # Set the room name. + yield self.persist(type="m.room.name", key="", name="name2") + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"]) + ) + + # Set the room aliases. + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#2:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"]) + ) + + # Leave and join the room clobbering the state. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), (None, []) + ) + + event_id = 0 + + @defer.inlineCallbacks + def persist( + self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, + internal={}, + state=None, reset_state=False, backfill=False, + depth=None, prev_events=[], auth_events=[], prev_state=[], + **content + ): + if depth is None: + depth = self.event_id + + event_dict = { + "sender": sender, + "type": type, + "content": content, + "event_id": "$%d:blue" % (self.event_id,), + "room_id": room_id, + "depth": depth, + "origin_server_ts": self.event_id, + "prev_events": prev_events, + "auth_events": auth_events, + } + if key is not None: + event_dict["state_key"] = key + event_dict["prev_state"] = prev_state + + event = FrozenEvent(event_dict, internal_metadata_dict=internal) + + self.event_id += 1 + + context = EventContext(current_state=state) + + if backfill: + yield self.master_store.persist_events( + [(event, context)], backfilled=True + ) + else: + yield self.master_store.persist_event( + event, context, current_state=reset_state + ) + defer.returnValue(event) -- cgit 1.4.1