diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 7 | ||||
-rw-r--r-- | synapse/storage/deviceinbox.py | 184 | ||||
-rw-r--r-- | synapse/storage/events.py | 78 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 23 | ||||
-rw-r--r-- | synapse/storage/receipts.py | 2 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 106 | ||||
-rw-r--r-- | synapse/storage/schema/delta/34/device_inbox.sql | 24 | ||||
-rw-r--r-- | synapse/storage/schema/delta/34/sent_txn_purge.py | 32 | ||||
-rw-r--r-- | synapse/storage/state.py | 128 | ||||
-rw-r--r-- | synapse/storage/transactions.py | 4 |
10 files changed, 487 insertions, 101 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 7efc5bfeef..6c32773f25 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -36,6 +36,7 @@ from .push_rule import PushRuleStore from .media_repository import MediaRepositoryStore from .rejections import RejectionsStore from .event_push_actions import EventPushActionsStore +from .deviceinbox import DeviceInboxStore from .state import StateStore from .signatures import SignatureStore @@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore, OpenIdStore, ClientIpStore, DeviceStore, + DeviceInboxStore, ): def __init__(self, db_conn, hs): @@ -108,9 +110,12 @@ class DataStore(RoomMemberStore, RoomStore, self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) + self._device_inbox_id_gen = StreamIdGenerator( + db_conn, "device_inbox", "stream_id" + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") - self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") + self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py new file mode 100644 index 0000000000..68116b0394 --- /dev/null +++ b/synapse/storage/deviceinbox.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import ujson + +from twisted.internet import defer + +from ._base import SQLBaseStore + + +logger = logging.getLogger(__name__) + + +class DeviceInboxStore(SQLBaseStore): + + @defer.inlineCallbacks + def add_messages_to_device_inbox(self, messages_by_user_then_device): + """ + Args: + messages_by_user_and_device(dict): + Dictionary of user_id to device_id to message. + Returns: + A deferred stream_id that resolves when the messages have been + inserted. + """ + + def select_devices_txn(txn, user_id, devices): + if not devices: + return [] + sql = ( + "SELECT user_id, device_id FROM devices" + " WHERE user_id = ? AND device_id IN (" + + ",".join("?" * len(devices)) + + ")" + ) + # TODO: Maybe this needs to be done in batches if there are + # too many local devices for a given user. + args = [user_id] + devices + txn.execute(sql, args) + return [tuple(row) for row in txn.fetchall()] + + def add_messages_to_device_inbox_txn(txn, stream_id): + local_users_and_devices = set() + for user_id, messages_by_device in messages_by_user_then_device.items(): + local_users_and_devices.update( + select_devices_txn(txn, user_id, messages_by_device.keys()) + ) + + sql = ( + "INSERT INTO device_inbox" + " (user_id, device_id, stream_id, message_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for user_id, messages_by_device in messages_by_user_then_device.items(): + for device_id, message in messages_by_device.items(): + message_json = ujson.dumps(message) + # Only insert into the local inbox if the device exists on + # this server + if (user_id, device_id) in local_users_and_devices: + rows.append((user_id, device_id, stream_id, message_json)) + + txn.executemany(sql, rows) + + with self._device_inbox_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_messages_to_device_inbox", + add_messages_to_device_inbox_txn, + stream_id + ) + + defer.returnValue(self._device_inbox_id_gen.get_current_token()) + + def get_new_messages_for_device( + self, user_id, device_id, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + current_stream_id(int): The current position of the to device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + def get_new_messages_for_device_txn(txn): + sql = ( + "SELECT stream_id, message_json FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + user_id, device_id, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn.fetchall(): + stream_pos = row[0] + messages.append(ujson.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_messages_for_device", get_new_messages_for_device_txn, + ) + + def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + def delete_messages_for_device_txn(txn): + sql = ( + "DELETE FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (user_id, device_id, up_to_stream_id)) + + return self.runInteraction( + "delete_messages_for_device", delete_messages_for_device_txn + ) + + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + sql = ( + "SELECT stream_id FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY stream_id" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_pos, current_pos, limit)) + stream_ids = txn.fetchall() + if not stream_ids: + return [] + max_stream_id_in_limit = stream_ids[-1] + + sql = ( + "SELECT stream_id, user_id, device_id, message_json" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + txn.execute(sql, (last_pos, max_stream_id_in_limit)) + return txn.fetchall() + + return self.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + + def get_to_device_stream_token(self): + return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 57e5005285..1a7d4c5199 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -271,39 +271,28 @@ class EventsStore(SQLBaseStore): len(events_and_contexts) ) - state_group_id_manager = self._state_groups_id_gen.get_next_mult( - len(events_and_contexts) - ) with stream_ordering_manager as stream_orderings: - with state_group_id_manager as state_group_ids: - for (event, context), stream, state_group_id in zip( - events_and_contexts, stream_orderings, state_group_ids - ): - event.internal_metadata.stream_ordering = stream - # Assign a state group_id in case a new id is needed for - # this context. In theory we only need to assign this - # for contexts that have current_state and aren't outliers - # but that make the code more complicated. Assigning an ID - # per event only causes the state_group_ids to grow as fast - # as the stream_ordering so in practise shouldn't be a problem. - context.new_state_group_id = state_group_id - - chunks = [ - events_and_contexts[x:x + 100] - for x in xrange(0, len(events_and_contexts), 100) - ] + for (event, context), stream, in zip( + events_and_contexts, stream_orderings + ): + event.internal_metadata.stream_ordering = stream - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - yield self.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=chunk, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc_by(len(chunk)) + chunks = [ + events_and_contexts[x:x + 100] + for x in xrange(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=chunk, + backfilled=backfilled, + delete_existing=delete_existing, + ) + persist_event_counter.inc_by(len(chunk)) @_retry_on_integrity_error @defer.inlineCallbacks @@ -312,19 +301,17 @@ class EventsStore(SQLBaseStore): delete_existing=False): try: with self._stream_id_gen.get_next() as stream_ordering: - with self._state_groups_id_gen.get_next() as state_group_id: - event.internal_metadata.stream_ordering = stream_ordering - context.new_state_group_id = state_group_id - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - current_state=current_state, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc() + event.internal_metadata.stream_ordering = stream_ordering + yield self.runInteraction( + "persist_event", + self._persist_event_txn, + event=event, + context=context, + current_state=current_state, + backfilled=backfilled, + delete_existing=delete_existing, + ) + persist_event_counter.inc() except _RollbackButIsFineException: pass @@ -393,7 +380,6 @@ class EventsStore(SQLBaseStore): txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) - txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) # Add an entry to the current_state_resets table to record the point # where we clobbered the current state @@ -529,7 +515,7 @@ class EventsStore(SQLBaseStore): # Add an entry to the ex_outlier_stream table to replicate the # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering - state_group_id = context.state_group or context.new_state_group_id + state_group_id = context.state_group self._simple_insert_txn( txn, table="ex_outlier_stream", diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 78334a98cf..49721656b6 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -16,7 +16,6 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.push.baserules import list_with_base_rules -from synapse.api.constants import EventTypes, Membership from twisted.internet import defer import logging @@ -124,7 +123,8 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(results) - def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): + def bulk_get_push_rules_for_room(self, event, context): + state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -132,11 +132,13 @@ class PushRuleStore(SQLBaseStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) + return self._bulk_get_push_rules_for_room( + event.room_id, state_group, context.current_state_ids, event=event + ) @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, - cache_context): + def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, + cache_context, event=None): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. @@ -147,12 +149,15 @@ class PushRuleStore(SQLBaseStore): # their unread countss are correct in the event stream, but to avoid # generating them for bot / AS users etc, we only do so for people who've # sent a read receipt into the room. - local_users_in_room = set( - e.state_key for e in current_state.values() - if e.type == EventTypes.Member and e.membership == Membership.JOIN - and self.hs.is_mine_id(e.state_key) + + users_in_room = yield self._get_joined_users_from_context( + room_id, state_group, current_state_ids, + on_invalidate=cache_context.invalidate, + event=event, ) + local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u)) + # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index ccc3811e84..9747a04a9a 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue([ev for res in results.values() for ev in res]) - @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True) + @cachedInlineCallbacks(num_args=3, tree=True) def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index a422ddf633..6ab10db328 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -20,7 +20,7 @@ from collections import namedtuple from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.api.constants import Membership +from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id import logging @@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore): for event in events: txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) - txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after( self._membership_stream_cache.entity_has_changed, @@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore): return results - @cachedInlineCallbacks(max_entries=5000) - def get_joined_hosts_for_room(self, room_id): - user_ids = yield self.get_users_in_room(room_id) - defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids)) - def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): where_clause = "c.room_id = ?" where_values = [room_id] @@ -325,7 +319,8 @@ class RoomMemberStore(SQLBaseStore): @cachedInlineCallbacks(num_args=3) def was_forgotten_at(self, user_id, room_id, event_id): - """Returns whether user_id has elected to discard history for room_id at event_id. + """Returns whether user_id has elected to discard history for room_id at + event_id. event_id must be a membership event.""" def f(txn): @@ -358,3 +353,98 @@ class RoomMemberStore(SQLBaseStore): }, desc="who_forgot" ) + + def get_joined_users_from_context(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + event.room_id, state_group, context.current_state_ids, event=event, + ) + + def get_joined_users_from_state(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids, + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, + cache_context, event=None): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + member_event_ids = [ + e_id + for key, e_id in current_state_ids.iteritems() + if key[0] == EventTypes.Member + ] + + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=['user_id'], + keyvalues={ + "membership": Membership.JOIN, + }, + batch_size=1000, + desc="_get_joined_users_from_context", + ) + + users_in_room = set(row["user_id"] for row in rows) + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room.add(event.state_key) + + defer.returnValue(users_in_room) + + def is_host_joined(self, room_id, host, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._is_host_joined( + room_id, host, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _is_host_joined(self, room_id, host, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + event = yield self.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/schema/delta/34/device_inbox.sql new file mode 100644 index 0000000000..e68844c74a --- /dev/null +++ b/synapse/storage/schema/delta/34/device_inbox.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_inbox ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, + message_json TEXT NOT NULL -- {"type":, "sender":, "content",} +); + +CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); +CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id); diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py new file mode 100644 index 0000000000..81948e3431 --- /dev/null +++ b/synapse/storage/schema/delta/34/sent_txn_purge.py @@ -0,0 +1,32 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + +import logging + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute("TRUNCATE sent_transactions") + else: + cur.execute("DELETE FROM sent_transactions") + + cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)") + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0e8fa93e1f..ec551b0b4f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -44,11 +44,7 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - The return value is a dict mapping group names to lists of events. - """ + def get_state_groups_ids(self, room_id, event_ids): if not event_ids: defer.returnValue({}) @@ -59,36 +55,64 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups) + defer.returnValue(group_to_state) + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + + The return value is a dict mapping group names to lists of events. + """ + if not event_ids: + defer.returnValue({}) + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.get_events( + [ + ev_id for group_ids in group_to_ids.values() + for ev_id in group_ids.values() + ], + get_prev_content=False + ) + defer.returnValue({ - group: state_map.values() - for group, state_map in group_to_state.items() + group: [ + state_event_map[v] for v in event_id_map.values() if v in state_event_map + ] + for group, event_id_map in group_to_ids.items() }) + def _have_persisted_state_group_txn(self, txn, state_group): + txn.execute( + "SELECT count(*) FROM state_groups WHERE id = ?", + (state_group,) + ) + row = txn.fetchone() + return row and row[0] + def _store_mult_state_groups_txn(self, txn, events_and_contexts): state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): continue - if context.current_state is None: - continue - - if context.state_group is not None: - state_groups[event.event_id] = context.state_group + if context.current_state_ids is None: continue - state_events = dict(context.current_state) + state_groups[event.event_id] = context.state_group - if event.is_state(): - state_events[(event.type, event.state_key)] = event + if self._have_persisted_state_group_txn(txn, context.state_group): + logger.info("Already persisted state_group: %r", context.state_group) + continue - state_group = context.new_state_group_id + state_event_ids = dict(context.current_state_ids) self._simple_insert_txn( txn, table="state_groups", values={ - "id": state_group, + "id": context.state_group, "room_id": event.room_id, "event_id": event.event_id, }, @@ -99,16 +123,15 @@ class StateStore(SQLBaseStore): table="state_groups_state", values=[ { - "state_group": state_group, - "room_id": state.room_id, - "type": state.type, - "state_key": state.state_key, - "event_id": state.event_id, + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, } - for state in state_events.values() + for key, state_id in state_event_ids.items() ], ) - state_groups[event.event_id] = state_group self._simple_insert_many_txn( txn, @@ -248,6 +271,31 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups, types) + state_event_map = yield self.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + defer.returnValue({event: event_to_state[event] for event in event_ids}) + + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, types): + event_to_groups = yield self._get_state_group_for_events( + event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups, types) + event_to_state = { event_id: group_to_state[group] for event_id, group in event_to_groups.items() @@ -272,6 +320,23 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, types=None): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], types) + defer.returnValue(state_map[event_id]) + @cached(num_args=2, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( @@ -428,20 +493,13 @@ class StateStore(SQLBaseStore): full=(types is None), ) - state_events = yield self._get_events( - [ev_id for sd in results.values() for ev_id in sd.values()], - get_prev_content=False - ) - - state_events = {e.event_id: e for e in state_events} - # Remove all the entries with None values. The None values were just # used for bookkeeping in the cache. for group, state_dict in results.items(): results[group] = { - key: state_events[event_id] + key: event_id for key, event_id in state_dict.items() - if event_id and event_id in state_events + if event_id } defer.returnValue(results) @@ -473,5 +531,5 @@ class StateStore(SQLBaseStore): "get_all_new_state_groups", get_all_new_state_groups_txn ) - def get_state_stream_token(self): - return self._state_groups_id_gen.get_current_token() + def get_next_state_group(self): + return self._state_groups_id_gen.get_next() diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 58d4de4f1d..5055c04b24 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -245,7 +245,7 @@ class TransactionStore(SQLBaseStore): return self.cursor_to_dict(txn) - @cached() + @cached(max_entries=10000) def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore): def _cleanup_transactions(self): now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 + six_hours_ago = now - 6 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,)) return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn) |