From 1e25f62ee6a8aaa65c139e264ec2be1f8831eb16 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 12:55:02 +0100 Subject: Use a stream id generator to assign state group ids --- synapse/storage/state.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 02cefdff26..30d1060ecd 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -64,12 +64,12 @@ class StateStore(SQLBaseStore): for group, state_map in group_to_state.items() }) - def _store_state_groups_txn(self, txn, event, context): - return self._store_mult_state_groups_txn(txn, [(event, context)]) - def _store_mult_state_groups_txn(self, txn, events_and_contexts): state_groups = {} for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + if context.current_state is None: continue @@ -82,7 +82,8 @@ class StateStore(SQLBaseStore): if event.is_state(): state_events[(event.type, event.state_key)] = event - state_group = self._state_groups_id_gen.get_next() + state_group = context.new_state_group_id + self._simple_insert_txn( txn, table="state_groups", @@ -114,11 +115,10 @@ class StateStore(SQLBaseStore): table="event_to_state_groups", values=[ { - "state_group": state_groups[event.event_id], - "event_id": event.event_id, + "state_group": state_group_id, + "event_id": event_id, } - for event, context in events_and_contexts - if context.current_state is not None + for event_id, state_group_id in state_groups.items() ], ) -- cgit 1.5.1 From 31a9eceda5cf00b0482baf1c8bf1e138c823f621 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 15:58:20 +0100 Subject: Add a replication stream for state groups --- synapse/replication/resource.py | 36 +++++++++++++++++++++++++++++------- synapse/storage/events.py | 6 +++++- synapse/storage/state.py | 30 ++++++++++++++++++++++++++++++ tests/replication/test_resource.py | 30 +++++++++++++++++++++++++++--- 4 files changed, 91 insertions(+), 11 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 8c1ae0fbc7..096a79a7a4 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -38,6 +38,7 @@ STREAM_NAMES = ( ("backfill",), ("push_rules",), ("pushers",), + ("state",), ) @@ -123,6 +124,7 @@ class ReplicationResource(Resource): backfill_token = yield self.store.get_current_backfill_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() pushers_token = self.store.get_pushers_stream_token() + state_token = self.store.get_state_stream_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -133,6 +135,7 @@ class ReplicationResource(Resource): backfill_token, push_rules_token, pushers_token, + state_token, )) @request_handler @@ -156,6 +159,7 @@ class ReplicationResource(Resource): yield self.receipts(writer, current_token, limit) yield self.push_rules(writer, current_token, limit) yield self.pushers(writer, current_token, limit) + yield self.state(writer, current_token, limit) self.streams(writer, current_token) logger.info("Replicated %d rows", writer.total) @@ -205,12 +209,12 @@ class ReplicationResource(Resource): current_token.backfill, current_token.events, limit ) - writer.write_header_and_rows( - "events", events_rows, ("position", "internal", "json") - ) - writer.write_header_and_rows( - "backfill", backfill_rows, ("position", "internal", "json") - ) + writer.write_header_and_rows("events", events_rows, ( + "position", "internal", "json", "state_group" + )) + writer.write_header_and_rows("backfill", backfill_rows, ( + "position", "internal", "json", "state_group" + )) @defer.inlineCallbacks def presence(self, writer, current_token): @@ -320,6 +324,24 @@ class ReplicationResource(Resource): "position", "user_id", "app_id", "pushkey" )) + @defer.inlineCallbacks + def state(self, writer, current_token, limit): + current_position = current_token.state + + state = parse_integer(writer.request, "state") + if state is not None: + state_groups, state_group_state = ( + yield self.store.get_all_new_state_groups( + state, current_position, limit + ) + ) + writer.write_header_and_rows("state_groups", state_groups, ( + "position", "room_id", "event_id" + )) + writer.write_header_and_rows("state_group_state", state_group_state, ( + "position", "type", "state_key", "event_id" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -350,7 +372,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers" + "push_rules", "pushers", "state" ))): __slots__ = [] diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5f675ab09b..a4b8995496 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1097,10 +1097,12 @@ class EventsStore(SQLBaseStore): new events or as backfilled events""" def get_all_new_events_txn(txn): sql = ( - "SELECT e.stream_ordering, ej.internal_metadata, ej.json" + "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group" " FROM events as e" " JOIN event_json as ej" " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " LEFT JOIN event_to_state_groups as eg" + " ON e.event_id = eg.event_id" " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" " LIMIT ?" @@ -1116,6 +1118,8 @@ class EventsStore(SQLBaseStore): " FROM events as e" " JOIN event_json as ej" " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " LEFT JOIN event_to_state_groups as eg" + " ON e.event_id = eg.event_id" " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" " ORDER BY e.stream_ordering DESC" " LIMIT ?" diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 30d1060ecd..7fc9a4f264 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -429,3 +429,33 @@ class StateStore(SQLBaseStore): } defer.returnValue(results) + + def get_all_new_state_groups(self, last_id, current_id, limit): + def get_all_new_state_groups_txn(txn): + sql = ( + "SELECT id, room_id, event_id FROM state_groups" + " WHERE ? < id AND id <= ? ORDER BY id LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + groups = txn.fetchall() + + if not groups: + return ([], []) + + lower_bound = groups[0][0] + upper_bound = groups[-1][0] + sql = ( + "SELECT state_group, type, state_key, event_id" + " FROM state_groups_state" + " WHERE ? <= state_group AND state_group <= ?" + ) + + txn.execute(sql, (lower_bound, upper_bound)) + state_group_state = txn.fetchall() + return (groups, state_group_state) + return self.runInteraction( + "get_all_new_state_groups", get_all_new_state_groups_txn + ) + + def get_state_stream_token(self): + return self._state_groups_id_gen.get_max_token() diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index f4b5fb3328..b1dd7b4a74 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -58,15 +58,21 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body, {}) @defer.inlineCallbacks - def test_events(self): - get = self.get(events="-1", timeout="0") + def test_events_and_state(self): + get = self.get(events="-1", state="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( Requester(self.user, "", False), {} ) code, body = yield get self.assertEquals(code, 200) self.assertEquals(body["events"]["field_names"], [ - "position", "internal", "json" + "position", "internal", "json", "state_group" + ]) + self.assertEquals(body["state_groups"]["field_names"], [ + "position", "room_id", "event_id" + ]) + self.assertEquals(body["state_group_state"]["field_names"], [ + "position", "type", "state_key", "event_id" ]) @defer.inlineCallbacks @@ -132,6 +138,7 @@ class ReplicationResourceCase(unittest.TestCase): test_timeout_backfill = _test_timeout("backfill") test_timeout_push_rules = _test_timeout("push_rules") test_timeout_pushers = _test_timeout("pushers") + test_timeout_state = _test_timeout("state") @defer.inlineCallbacks def send_text_message(self, room_id, message): @@ -182,4 +189,21 @@ class ReplicationResourceCase(unittest.TestCase): ) response_body = json.loads(response_json) + if response_code == 200: + self.check_response(response_body) + defer.returnValue((response_code, response_body)) + + def check_response(self, response_body): + for name, stream in response_body.items(): + self.assertIn("field_names", stream) + field_names = stream["field_names"] + self.assertIn("rows", stream) + self.assertTrue(stream["rows"]) + for row in stream["rows"]: + self.assertEquals( + len(row), len(field_names), + "%s: len(row = %r) == len(field_names = %r)" % ( + name, row, field_names + ) + ) -- cgit 1.5.1 From e36bfbab38def70e0fcc1bafcecb6e666dbbc1ad Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 13:29:05 +0100 Subject: Use a stream id generator for backfilled ids --- synapse/storage/__init__.py | 20 ++++-------- synapse/storage/account_data.py | 4 +-- synapse/storage/events.py | 19 +++-------- synapse/storage/presence.py | 6 ++-- synapse/storage/push_rule.py | 2 +- synapse/storage/pusher.py | 2 +- synapse/storage/receipts.py | 6 ++-- synapse/storage/state.py | 2 +- synapse/storage/stream.py | 2 +- synapse/storage/tags.py | 6 ++-- synapse/storage/util/id_generators.py | 61 +++++++++++++++++++++++------------ 11 files changed, 69 insertions(+), 61 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index aaad38039e..f87e907cd8 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -88,15 +88,6 @@ class DataStore(RoomMemberStore, RoomStore, self.hs = hs self.database_engine = hs.database_engine - cur = db_conn.cursor() - try: - cur.execute("SELECT MIN(stream_ordering) FROM events",) - rows = cur.fetchall() - self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1 - self.min_stream_token = min(self.min_stream_token, -1) - finally: - cur.close() - self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, @@ -105,6 +96,9 @@ class DataStore(RoomMemberStore, RoomStore, self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering" ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", direction=-1 + ) self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) @@ -129,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore, extra_tables=[("deleted_pushers", "stream_id")], ) - events_max = self._stream_id_gen.get_max_token() + events_max = self._stream_id_gen.get_current_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -145,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token() + account_max = self._account_data_id_gen.get_current_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) @@ -156,7 +150,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._presence_id_gen.get_max_token(), + max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, @@ -167,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_max_token()[0], + max_value=self._push_rules_stream_id_gen.get_current_token()[0], ) self.push_rules_stream_cache = StreamChangeCache( diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index faddefe219..7a7fbf1e52 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore): "add_room_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore): "add_user_account_data", add_account_data_txn, next_id ) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 83279d65fa..4ab23c1597 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -24,7 +24,6 @@ from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json -from contextlib import contextmanager from collections import namedtuple import logging @@ -66,14 +65,9 @@ class EventsStore(SQLBaseStore): return if backfilled: - start = self.min_stream_token - 1 - self.min_stream_token -= len(events_and_contexts) + 1 - stream_orderings = range(start, self.min_stream_token, -1) - - @contextmanager - def stream_ordering_manager(): - yield stream_orderings - stream_ordering_manager = stream_ordering_manager() + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) else: stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) @@ -130,7 +124,7 @@ class EventsStore(SQLBaseStore): except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token() + max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((stream_ordering, max_persisted_id)) @defer.inlineCallbacks @@ -1117,10 +1111,7 @@ class EventsStore(SQLBaseStore): def get_current_backfill_token(self): """The current minimum token that backfilled events have reached""" - - # TODO: Fix race with the persit_event txn by using one of the - # stream id managers - return -self.min_stream_token + return -self._backfill_id_gen.get_current_token() def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4cec31e316..59b4ef5ce6 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore): self._update_presence_txn, stream_orderings, presence_states, ) - defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) + defer.returnValue(( + stream_orderings[-1], self._presence_id_gen.get_current_token() + )) def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): @@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore): defer.returnValue([UserPresenceState(**row) for row in rows]) def get_current_presence_token(self): - return self._presence_id_gen.get_max_token() + return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9dbad2fd5f..d2bf7f2aec 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore): """Get the position of the push rules stream. Returns a pair of a stream id for the push_rules stream and the room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_max_token() + return self._push_rules_stream_id_gen.get_current_token() def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 87b2ac5773..d1669c778a 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) def get_pushers_stream_token(self): - return self._pushers_id_gen.get_max_token() + return self._pushers_id_gen.get_current_token() def get_all_updated_pushers(self, last_id, current_id, limit): def get_all_updated_pushers_txn(txn): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 6b9d848eaa..4befebc8e2 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() ) @cached(num_args=2) @@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token() + return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = self._stream_id_gen.get_max_token() + max_persisted_id = self._stream_id_gen.get_current_token() defer.returnValue((stream_id, max_persisted_id)) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7fc9a4f264..8644830657 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -458,4 +458,4 @@ class StateStore(SQLBaseStore): ) def get_state_stream_token(self): - return self._state_groups_id_gen.get_max_token() + return self._state_groups_id_gen.get_current_token() diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index cf84938be5..76bcd9cd00 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token() + token = yield self._stream_id_gen.get_current_token() if direction != 'b': defer.returnValue("s%d" % (token,)) else: diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index a0e6b42b30..9da23f34cb 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. """ - return self._account_data_id_gen.get_max_token() + return self._account_data_id_gen.get_current_token() @cached() def get_tags_for_user(self, user_id): @@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks @@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index a02dfc7d58..03f2aa6a5c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -21,7 +21,7 @@ import threading class IdGenerator(object): def __init__(self, db_conn, table, column): self._lock = threading.Lock() - self._next_id = _load_max_id(db_conn, table, column) + self._next_id = _load_current_id(db_conn, table, column) def get_next(self): with self._lock: @@ -29,12 +29,16 @@ class IdGenerator(object): return self._next_id -def _load_max_id(db_conn, table, column): +def _load_current_id(db_conn, table, column, direction=1): cur = db_conn.cursor() - cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + if direction == 1: + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + else: + cur.execute("SELECT MIN(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - return int(val) if val else 1 + current_id = int(val) if val else direction + return (max if direction == 1 else min)(current_id, direction) class StreamIdGenerator(object): @@ -45,17 +49,30 @@ class StreamIdGenerator(object): all ids less than or equal to it have completed. This handles the fact that persistence of events can complete out of order. + :param connection db_conn: A database connection to use to fetch the + initial value of the generator from. + :param str table: A database table to read the initial value of the id + generator from. + :param str column: The column of the database table to read the initial + value from the id generator from. + :param list extra_tables: List of pairs of database tables and columns to + use to source the initial value of the generator from. The value with + the largest magnitude is used. + :param int direction: which direction the stream ids grow in. +1 to grow + upwards, -1 to grow downwards. + Usage: with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column, extra_tables=[]): + def __init__(self, db_conn, table, column, extra_tables=[], direction=1): self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._direction = direction + self._current = _load_current_id(db_conn, table, column, direction) for table, column in extra_tables: - self._current_max = max( - self._current_max, - _load_max_id(db_conn, table, column) + self._current = (max if direction > 0 else min)( + self._current, + _load_current_id(db_conn, table, column, direction) ) self._unfinished_ids = deque() @@ -66,8 +83,8 @@ class StreamIdGenerator(object): # ... persist event ... """ with self._lock: - self._current_max += 1 - next_id = self._current_max + self._current += self._direction + next_id = self._current self._unfinished_ids.append(next_id) @@ -88,8 +105,12 @@ class StreamIdGenerator(object): # ... persist events ... """ with self._lock: - next_ids = range(self._current_max + 1, self._current_max + n + 1) - self._current_max += n + next_ids = range( + self._current + self._direction, + self._current + self._direction * (n + 1), + self._direction + ) + self._current += n for next_id in next_ids: self._unfinished_ids.append(next_id) @@ -105,15 +126,15 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - 1 + return self._unfinished_ids[0] - self._direction - return self._current_max + return self._current class ChainedIdGenerator(object): @@ -125,7 +146,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator self._lock = threading.Lock() - self._current_max = _load_max_id(db_conn, table, column) + self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() def get_next(self): @@ -137,7 +158,7 @@ class ChainedIdGenerator(object): with self._lock: self._current_max += 1 next_id = self._current_max - chained_id = self.chained_generator.get_max_token() + chained_id = self.chained_generator.get_current_token() self._unfinished_ids.append((next_id, chained_id)) @@ -151,7 +172,7 @@ class ChainedIdGenerator(object): return manager() - def get_max_token(self): + def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -160,4 +181,4 @@ class ChainedIdGenerator(object): stream_id, chained_id = self._unfinished_ids[0] return (stream_id - 1, chained_id) - return (self._current_max, self.chained_generator.get_max_token()) + return (self._current_max, self.chained_generator.get_current_token()) -- cgit 1.5.1 From 2a37467fa1358eb41513893efe44cbd294dca36c Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 1 Apr 2016 16:08:59 +0100 Subject: Use google style doc strings. pycharm supports them so there is no need to use the other format. Might as well convert the existing strings to reduce the risk of people accidentally cargo culting the wrong doc string format. --- setup.cfg | 3 ++ synapse/handlers/_base.py | 27 +++++++----- synapse/handlers/auth.py | 26 +++++++---- synapse/handlers/federation.py | 23 +++++----- synapse/handlers/room_member.py | 48 ++++++++++----------- synapse/handlers/sync.py | 49 +++++++++++++-------- synapse/http/servlet.py | 81 ++++++++++++++++++++++------------- synapse/notifier.py | 15 ++++--- synapse/push/baserules.py | 8 ++-- synapse/rest/client/v2_alpha/sync.py | 79 ++++++++++++++++++---------------- synapse/state.py | 19 ++++---- synapse/storage/event_push_actions.py | 5 ++- synapse/storage/registration.py | 15 ++++--- synapse/storage/state.py | 13 +++--- 14 files changed, 242 insertions(+), 169 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/setup.cfg b/setup.cfg index f8cc13c840..5ebce1c56b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,3 +17,6 @@ ignore = [flake8] max-line-length = 90 ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. + +[pep8] +max-line-length = 90 diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90eabb6eb7..5601ecea6e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -41,8 +41,9 @@ class BaseHandler(object): """ Common base class for the event handlers. - :type store: synapse.storage.events.StateStore - :type state_handler: synapse.state.StateHandler + Attributes: + store (synapse.storage.events.StateStore): + state_handler (synapse.state.StateHandler): """ def __init__(self, hs): @@ -65,11 +66,12 @@ class BaseHandler(object): """ Returns dict of user_id -> list of events that user is allowed to see. - :param (str, bool) user_tuples: (user id, is_peeking) for each - user to be checked. is_peeking should be true if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the given - events + Args: + user_tuples (str, bool): (user id, is_peeking) for each user to be + checked. is_peeking should be true if: + * the user is not currently a member of the room, and: + * the user has not been a member of the room since the + given events """ forgotten = yield defer.gatherResults([ self.store.who_forgot_in_room( @@ -165,13 +167,16 @@ class BaseHandler(object): """ Check which events a user is allowed to see - :param str user_id: user id to be checked - :param [synapse.events.EventBase] events: list of events to be checked - :param bool is_peeking should be True if: + Args: + user_id(str): user id to be checked + events([synapse.events.EventBase]): list of events to be checked + is_peeking(bool): should be True if: * the user is not currently a member of the room, and: * the user has not been a member of the room since the given events - :rtype [synapse.events.EventBase] + + Returns: + [synapse.events.EventBase] """ types = ( (EventTypes.RoomHistoryVisibility, ""), diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 82d458b424..d5d6faa85f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -163,9 +163,13 @@ class AuthHandler(BaseHandler): def get_session_id(self, clientdict): """ Gets the session ID for a client given the client dictionary - :param clientdict: The dictionary sent by the client in the request - :return: The string session ID the client sent. If the client did not - send a session ID, returns None. + + Args: + clientdict: The dictionary sent by the client in the request + + Returns: + str|None: The string session ID the client sent. If the client did + not send a session ID, returns None. """ sid = None if clientdict and 'auth' in clientdict: @@ -179,9 +183,11 @@ class AuthHandler(BaseHandler): Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param value: (any) The data to store + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + value (any): The data to store """ sess = self._get_session_info(session_id) sess.setdefault('serverdict', {})[key] = value @@ -190,9 +196,11 @@ class AuthHandler(BaseHandler): def get_session_data(self, session_id, key, default=None): """ Retrieve data stored with set_session_data - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param default: (any) Value to return if the key has not been set + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4a35344d32..092802b973 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1706,13 +1706,15 @@ class FederationHandler(BaseHandler): def _check_signature(self, event, auth_events): """ Checks that the signature in the event is consistent with its invite. - :param event (Event): The m.room.member event to check - :param auth_events (dict<(event type, state_key), event>) - :raises - AuthError if signature didn't match any keys, or key has been + Args: + event (Event): The m.room.member event to check + auth_events (dict<(event type, state_key), event>): + + Raises: + AuthError: if signature didn't match any keys, or key has been revoked, - SynapseError if a transient error meant a key couldn't be checked + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ signed = event.content["third_party_invite"]["signed"] @@ -1754,12 +1756,13 @@ class FederationHandler(BaseHandler): """ Checks whether public_key has been revoked. - :param public_key (str): base-64 encoded public key. - :param url (str): Key revocation URL. + Args: + public_key (str): base-64 encoded public key. + url (str): Key revocation URL. - :raises - AuthError if they key has been revoked. - SynapseError if a transient error meant a key couldn't be checked + Raises: + AuthError: if they key has been revoked. + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ try: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5fdbd3adcc..01f833c371 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -411,7 +411,7 @@ class RoomMemberHandler(BaseHandler): address (str): The third party identifier (e.g. "foo@example.com"). Returns: - (str) the matrix ID of the 3pid, or None if it is not recognized. + str: the matrix ID of the 3pid, or None if it is not recognized. """ try: data = yield self.hs.get_simple_http_client().get_json( @@ -545,29 +545,29 @@ class RoomMemberHandler(BaseHandler): """ Asks an identity server for a third party invite. - :param id_server (str): hostname + optional port for the identity server. - :param medium (str): The literal string "email". - :param address (str): The third party address being invited. - :param room_id (str): The ID of the room to which the user is invited. - :param inviter_user_id (str): The user ID of the inviter. - :param room_alias (str): An alias for the room, for cosmetic - notifications. - :param room_avatar_url (str): The URL of the room's avatar, for cosmetic - notifications. - :param room_join_rules (str): The join rules of the email - (e.g. "public"). - :param room_name (str): The m.room.name of the room. - :param inviter_display_name (str): The current display name of the - inviter. - :param inviter_avatar_url (str): The URL of the inviter's avatar. - - :return: A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. - fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. + Args: + id_server (str): hostname + optional port for the identity server. + medium (str): The literal string "email". + address (str): The third party address being invited. + room_id (str): The ID of the room to which the user is invited. + inviter_user_id (str): The user ID of the inviter. + room_alias (str): An alias for the room, for cosmetic notifications. + room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + room_join_rules (str): The join rules of the email (e.g. "public"). + room_name (str): The m.room.name of the room. + inviter_display_name (str): The current display name of the + inviter. + inviter_avatar_url (str): The URL of the inviter's avatar. + + Returns: + A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. """ is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 48ab5707e1..20a0626574 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -671,7 +671,8 @@ class SyncHandler(BaseHandler): def load_filtered_recents(self, room_id, sync_config, now_token, since_token=None, recents=None, newly_joined_room=False): """ - :returns a Deferred TimelineBatch + Returns: + a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): filtering_factor = 2 @@ -838,8 +839,11 @@ class SyncHandler(BaseHandler): """ Get the room state after the given event - :param synapse.events.EventBase event: event of interest - :return: A Deferred map from ((type, state_key)->Event) + Args: + event(synapse.events.EventBase): event of interest + + Returns: + A Deferred map from ((type, state_key)->Event) """ state = yield self.store.get_state_for_event(event.event_id) if event.is_state(): @@ -850,9 +854,13 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def get_state_at(self, room_id, stream_position): """ Get the room state at a particular stream position - :param str room_id: room for which to get state - :param StreamToken stream_position: point at which to get state - :returns: A Deferred map from ((type, state_key)->Event) + + Args: + room_id(str): room for which to get state + stream_position(StreamToken): point at which to get state + + Returns: + A Deferred map from ((type, state_key)->Event) """ last_events, token = yield self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1, @@ -873,15 +881,18 @@ class SyncHandler(BaseHandler): """ Works out the differnce in state between the start of the timeline and the previous sync. - :param str room_id - :param TimelineBatch batch: The timeline batch for the room that will - be sent to the user. - :param sync_config - :param str since_token: Token of the end of the previous batch. May be None. - :param str now_token: Token of the end of the current batch. - :param bool full_state: Whether to force returning the full state. + Args: + room_id(str): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + sync_config(synapse.handlers.sync.SyncConfig): + since_token(str|None): Token of the end of the previous batch. May + be None. + now_token(str): Token of the end of the current batch. + full_state(bool): Whether to force returning the full state. - :returns A new event dictionary + Returns: + A deferred new event dictionary """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -953,11 +964,13 @@ class SyncHandler(BaseHandler): Check if the user has just joined the given room (so should be given the full state) - :param sync_config: - :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the - difference in state since the last sync + Args: + sync_config(synapse.handlers.sync.SyncConfig): + state_delta(dict[(str,str), synapse.events.FrozenEvent]): the + difference in state since the last sync - :returns A deferred Tuple (state_delta, limited) + Returns: + A deferred Tuple (state_delta, limited) """ join_event = state_delta.get(( EventTypes.Member, sync_config.user.to_string()), None) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 1c8bd8666f..e41afeab8e 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -26,14 +26,19 @@ logger = logging.getLogger(__name__) def parse_integer(request, name, default=None, required=False): """Parse an integer parameter from the request string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: An int value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (int|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + int|None: An int value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not an integer. """ if name in request.args: @@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False): def parse_boolean(request, name, default=None, required=False): """Parse a boolean parameter from the request query string - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :return: A bool value or the default. - :raises - SynapseError if the parameter is absent and required, or if the + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (bool|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + + Returns: + bool|None: A bool value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the parameter is present and not one of "true" or "false". """ @@ -88,15 +98,20 @@ def parse_string(request, name, default=None, required=False, allowed_values=None, param_type="string"): """Parse a string parameter from the request query string. - :param request: the twisted HTTP request. - :param name (str): the name of the query parameter. - :param default: value to use if the parameter is absent, defaults to None. - :param required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. - :param allowed_values (list): List of allowed values for the string, - or None if any value is allowed, defaults to None - :return: A string value or the default. - :raises + Args: + request: the twisted HTTP request. + name (str): the name of the query parameter. + default (str|None): value to use if the parameter is absent, defaults + to None. + required (bool): whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + allowed_values (list[str]): List of allowed values for the string, + or None if any value is allowed, defaults to None + + Returns: + str|None: A string value or the default. + + Raises: SynapseError if the parameter is absent and required, or if the parameter is present, must be one of a list of allowed values and is not one of those allowed values. @@ -122,9 +137,13 @@ def parse_string(request, name, default=None, required=False, def parse_json_value_from_request(request): """Parse a JSON value from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :returns: The JSON value. - :raises + Args: + request: the twisted HTTP request. + + Returns: + The JSON value. + + Raises: SynapseError if the request body couldn't be decoded as JSON. """ try: @@ -143,8 +162,10 @@ def parse_json_value_from_request(request): def parse_json_object_from_request(request): """Parse a JSON object from the body of a twisted HTTP request. - :param request: the twisted HTTP request. - :raises + Args: + request: the twisted HTTP request. + + Raises: SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ diff --git a/synapse/notifier.py b/synapse/notifier.py index f00cd8c588..6af7a8f424 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -503,13 +503,14 @@ class Notifier(object): def wait_for_replication(self, callback, timeout): """Wait for an event to happen. - :param callback: - Gets called whenever an event happens. If this returns a truthy - value then ``wait_for_replication`` returns, otherwise it waits - for another event. - :param int timeout: - How many milliseconds to wait for callback return a truthy value. - :returns: + Args: + callback: Gets called whenever an event happens. If this returns a + truthy value then ``wait_for_replication`` returns, otherwise + it waits for another event. + timeout: How many milliseconds to wait for callback return a truthy + value. + + Returns: A deferred that resolves with the value returned by the callback. """ listener = _NotificationListener(None) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 792af70eb7..6add94beeb 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -19,9 +19,11 @@ import copy def list_with_base_rules(rawrules): """Combine the list of rules set by the user with the default push rules - :param list rawrules: The rules the user has modified or set. - :returns: A new list with the rules set by the user combined with the - defaults. + Args: + rawrules(list): The rules the user has modified or set. + + Returns: + A new list with the rules set by the user combined with the defaults. """ ruleslist = [] diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index c5785d7074..60d3dc4030 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -199,15 +199,17 @@ class SyncRestServlet(RestServlet): """ Encode the joined rooms in a sync result - :param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync - results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - - :return: the joined rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Args: + rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync + results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + + Returns: + dict[str, dict[str, object]]: the joined rooms list, in our + response format """ joined = {} for room in rooms: @@ -221,15 +223,17 @@ class SyncRestServlet(RestServlet): """ Encode the invited rooms in a sync result - :param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing + Args: + rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing of transaction IDs - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Returns: + dict[str, dict[str, object]]: the invited rooms list, in our + response format """ invited = {} for room in rooms: @@ -251,15 +255,17 @@ class SyncRestServlet(RestServlet): """ Encode the archived rooms in a sync result - :param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of - sync results for rooms this user is joined to - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - - :return: the invited rooms list, in our response format - :rtype: dict[str, dict[str, object]] + Args: + rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + + Returns: + dict[str, dict[str, object]]: The invited rooms list, in our + response format """ joined = {} for room in rooms: @@ -272,17 +278,18 @@ class SyncRestServlet(RestServlet): @staticmethod def encode_room(room, time_now, token_id, joined=True): """ - :param JoinedSyncResult|ArchivedSyncResult room: sync result for a - single room - :param int time_now: current time - used as a baseline for age - calculations - :param int token_id: ID of the user's auth token - used for namespacing - of transaction IDs - :param joined: True if the user is joined to this room - will mean - we handle ephemeral events - - :return: the room, encoded in our response format - :rtype: dict[str, object] + Args: + room (JoinedSyncResult|ArchivedSyncResult): sync result for a + single room + time_now (int): current time - used as a baseline for age + calculations + token_id (int): ID of the user's auth token - used for namespacing + of transaction IDs + joined (bool): True if the user is joined to this room - will mean + we handle ephemeral events + + Returns: + dict[str, object]: the room, encoded in our response format """ def serialize(event): # TODO(mjark): Respect formatting requirements in the filter. diff --git a/synapse/state.py b/synapse/state.py index 41d32e664a..4a9e148de7 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -86,7 +86,8 @@ class StateHandler(object): If `event_type` is specified, then the method returns only the one event (or None) with that `event_type` and `state_key`. - :returns map from (type, state_key) to event + Returns: + map from (type, state_key) to event """ event_ids = yield self.store.get_latest_event_ids_in_room(room_id) @@ -176,10 +177,11 @@ class StateHandler(object): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. - :returns a Deferred tuple of (`state_group`, `state`, `prev_state`). - `state_group` is the name of a state group if one and only one is - involved. `state` is a map from (type, state_key) to event, and - `prev_state` is a list of event ids. + Returns: + a Deferred tuple of (`state_group`, `state`, `prev_state`). + `state_group` is the name of a state group if one and only one is + involved. `state` is a map from (type, state_key) to event, and + `prev_state` is a list of event ids. """ logger.debug("resolve_state_groups event_ids %s", event_ids) @@ -251,9 +253,10 @@ class StateHandler(object): def _resolve_events(self, state_sets, event_type=None, state_key=""): """ - :returns a tuple (new_state, prev_states). new_state is a map - from (type, state_key) to event. prev_states is a list of event_ids. - :rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str]) + Returns + (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple + (new_state, prev_states). new_state is a map from (type, state_key) + to event. prev_states is a list of event_ids. """ with Measure(self.clock, "state._resolve_events"): state = {} diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index dc5830450a..3933b6e2c5 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -26,8 +26,9 @@ logger = logging.getLogger(__name__) class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ - :param event: the event set actions for - :param tuples: list of tuples of (user_id, actions) + Args: + event: the event set actions for + tuples: list of tuples of (user_id, actions) """ values = [] for uid, actions in tuples: diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bd4eb88a92..d46a963bb8 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -458,12 +458,15 @@ class RegistrationStore(SQLBaseStore): """ Gets the 3pid's guest access token if exists, else saves access_token. - :param medium (str): Medium of the 3pid. Must be "email". - :param address (str): 3pid address. - :param access_token (str): The access token to persist if none is - already persisted. - :param inviter_user_id (str): User ID of the inviter. - :return (deferred str): Whichever access token is persisted at the end + Args: + medium (str): Medium of the 3pid. Must be "email". + address (str): 3pid address. + access_token (str): The access token to persist if none is + already persisted. + inviter_user_id (str): User ID of the inviter. + + Returns: + deferred str: Whichever access token is persisted at the end of this function call. """ def insert(txn): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7fc9a4f264..f84fd0e30a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -249,11 +249,14 @@ class StateStore(SQLBaseStore): """ Get the state dict corresponding to a particular event - :param str event_id: event whose state should be returned - :param list[(str, str)]|None types: List of (type, state_key) tuples - which are used to filter the state fetched. May be None, which - matches any key - :return: a deferred dict from (type, state_key) -> state_event + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event """ state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) -- cgit 1.5.1 From 87f2dec8d475f038beb138bc56e3ef76fcb83ec6 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 13:08:05 +0100 Subject: Make the cache objects be per instance rather than being global --- synapse/storage/receipts.py | 4 ++-- synapse/storage/registration.py | 2 +- synapse/storage/state.py | 4 ++-- synapse/util/caches/descriptors.py | 45 ++++++++++++++++++++------------------ 4 files changed, 29 insertions(+), 26 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4befebc8e2..7fdd84bbdc 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore): "content": content, }]) - @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", - num_args=3, inlineCallbacks=True) + @cachedList(cached_method_name="get_linearized_receipts_for_room", + list_name="room_ids", num_args=3, inlineCallbacks=True) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: defer.returnValue({}) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d46a963bb8..1f71773aaa 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -319,7 +319,7 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) - @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, + @cachedList(cached_method_name="is_guest", list_name="user_ids", num_args=1, inlineCallbacks=True) def are_guests(self, user_ids): sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e9f9406014..c5d2a3a6df 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -273,8 +273,8 @@ class StateStore(SQLBaseStore): desc="_get_state_group_for_event", ) - @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1, inlineCallbacks=True) + @cachedList(cached_method_name="_get_state_group_for_event", + list_name="event_ids", num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 35544b19fd..758f5982b0 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -167,7 +167,8 @@ class CacheDescriptor(object): % (orig.__name__,) ) - self.cache = Cache( + def __get__(self, obj, objtype=None): + cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, @@ -175,14 +176,12 @@ class CacheDescriptor(object): tree=self.tree, ) - def __get__(self, obj, objtype=None): - @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = self.cache.get(cache_key) + cached_result_d = cache.get(cache_key) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -204,7 +203,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence + sequence = cache.sequence ret = defer.maybeDeferred( preserve_context_over_fn, @@ -213,20 +212,21 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(cache_key) + cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret) return preserve_context_over_deferred(ret.observe()) - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.invalidate_many = self.cache.invalidate_many - wrapped.prefill = self.cache.prefill + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -240,11 +240,12 @@ class CacheListDescriptor(object): the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + def __init__(self, orig, cached_method_name, list_name, num_args=1, + inlineCallbacks=False): """ Args: orig (function) - cache (Cache) + method_name (str); The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int) inlineCallbacks (bool): Whether orig is a generator that should @@ -263,7 +264,7 @@ class CacheListDescriptor(object): self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cache = cache + self.cached_method_name = cached_method_name self.sentinel = object() @@ -277,11 +278,13 @@ class CacheListDescriptor(object): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) + % (self.list_name, cached_method_name,) ) def __get__(self, obj, objtype=None): + cache = getattr(obj, self.cached_method_name).cache + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) @@ -297,14 +300,14 @@ class CacheListDescriptor(object): key[self.list_pos] = arg try: - res = self.cache.get(tuple(key)).observe() + res = cache.get(tuple(key)).observe() res.addCallback(lambda r, arg: (arg, r), arg) cached[arg] = res except KeyError: missing.append(arg) if missing: - sequence = self.cache.sequence + sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -327,10 +330,10 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) + cache.update(sequence, tuple(key), observer) def invalidate(f, key): - self.cache.invalidate(key) + cache.invalidate(key) return f observer.addErrback(invalidate, tuple(key)) @@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): ) -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): """ return lambda orig: CacheListDescriptor( orig, - cache=cache, + cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, inlineCallbacks=inlineCallbacks, -- cgit 1.5.1 From 61c7edfd34abdb9eaa7c8d3dd3dbef95b60de5de Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 19 Apr 2016 17:22:03 +0100 Subject: Add cache to _get_state_groups_from_groups --- synapse/storage/state.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) (limited to 'synapse/storage/state.py') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c5d2a3a6df..5b743db67a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -174,6 +174,12 @@ class StateStore(SQLBaseStore): return [r[0] for r in results] return self.runInteraction("get_current_state_for_key", f) + @cached(num_args=2, lru=True, max_entries=1000) + def _get_state_group_from_group(self, group, types): + raise NotImplementedError() + + @cachedList(cached_method_name="_get_state_group_from_group", + list_name="groups", num_args=2, inlineCallbacks=True) def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ @@ -201,18 +207,23 @@ class StateStore(SQLBaseStore): txn.execute(sql, args) rows = self.cursor_to_dict(txn) - results = {} + results = {group: {} for group in groups} for row in rows: key = (row["type"], row["state_key"]) - results.setdefault(row["state_group"], {})[key] = row["event_id"] + results[row["state_group"]][key] = row["event_id"] return results + results = {} + chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] for chunk in chunks: - return self.runInteraction( + res = yield self.runInteraction( "_get_state_groups_from_groups", f, chunk ) + results.update(res) + + defer.returnValue(results) @defer.inlineCallbacks def get_state_for_events(self, event_ids, types): @@ -359,6 +370,8 @@ class StateStore(SQLBaseStore): a `state_key` of None matches all state_keys. If `types` is None then all events are returned. """ + if types: + types = frozenset(types) results = {} missing_groups = [] if types is not None: -- cgit 1.5.1