diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7efc5bfeef..6c32773f25 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -36,6 +36,7 @@ from .push_rule import PushRuleStore
from .media_repository import MediaRepositoryStore
from .rejections import RejectionsStore
from .event_push_actions import EventPushActionsStore
+from .deviceinbox import DeviceInboxStore
from .state import StateStore
from .signatures import SignatureStore
@@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore,
OpenIdStore,
ClientIpStore,
DeviceStore,
+ DeviceInboxStore,
):
def __init__(self, db_conn, hs):
@@ -108,9 +110,12 @@ class DataStore(RoomMemberStore, RoomStore,
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
+ self._device_inbox_id_gen = StreamIdGenerator(
+ db_conn, "device_inbox", "stream_id"
+ )
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
- self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
+ self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
new file mode 100644
index 0000000000..68116b0394
--- /dev/null
+++ b/synapse/storage/deviceinbox.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import ujson
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceInboxStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def add_messages_to_device_inbox(self, messages_by_user_then_device):
+ """
+ Args:
+ messages_by_user_then_device(dict):
+ Dictionary of user_id to device_id to message.
+ Returns:
+ A deferred stream_id that resolves when the messages have been
+ inserted.
+ """
+
+ def select_devices_txn(txn, user_id, devices):
+ if not devices:
+ return []
+ sql = (
+ "SELECT user_id, device_id FROM devices"
+ " WHERE user_id = ? AND device_id IN ("
+ + ",".join("?" * len(devices))
+ + ")"
+ )
+ # TODO: Maybe this needs to be done in batches if there are
+ # too many local devices for a given user.
+ args = [user_id] + devices
+ txn.execute(sql, args)
+ return [tuple(row) for row in txn.fetchall()]
+
+ def add_messages_to_device_inbox_txn(txn, stream_id):
+ local_users_and_devices = set()
+ for user_id, messages_by_device in messages_by_user_then_device.items():
+ local_users_and_devices.update(
+ select_devices_txn(txn, user_id, messages_by_device.keys())
+ )
+
+ sql = (
+ "INSERT INTO device_inbox"
+ " (user_id, device_id, stream_id, message_json)"
+ " VALUES (?,?,?,?)"
+ )
+ rows = []
+ for user_id, messages_by_device in messages_by_user_then_device.items():
+ for device_id, message in messages_by_device.items():
+ message_json = ujson.dumps(message)
+ # Only insert into the local inbox if the device exists on
+ # this server
+ if (user_id, device_id) in local_users_and_devices:
+ rows.append((user_id, device_id, stream_id, message_json))
+
+ txn.executemany(sql, rows)
+
+ with self._device_inbox_id_gen.get_next() as stream_id:
+ yield self.runInteraction(
+ "add_messages_to_device_inbox",
+ add_messages_to_device_inbox_txn,
+ stream_id
+ )
+
+ defer.returnValue(self._device_inbox_id_gen.get_current_token())
+
+ def get_new_messages_for_device(
+ self, user_id, device_id, last_stream_id, current_stream_id, limit=100
+ ):
+ """
+ Args:
+ user_id(str): The recipient user_id.
+ device_id(str): The recipient device_id.
+ last_stream_id(int): The stream position the device has already
+ received messages up to; only later messages are returned.
+ current_stream_id(int): The current position of the to device
+ message stream.
+ limit(int): The maximum number of messages to return.
+ Returns:
+ Deferred ([dict], int): List of messages for the device and where
+ in the stream the messages got to.
+ """
+ def get_new_messages_for_device_txn(txn):
+ sql = (
+ "SELECT stream_id, message_json FROM device_inbox"
+ " WHERE user_id = ? AND device_id = ?"
+ " AND ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (
+ user_id, device_id, last_stream_id, current_stream_id, limit
+ ))
+ messages = []
+ for row in txn.fetchall():
+ stream_pos = row[0]
+ messages.append(ujson.loads(row[1]))
+ if len(messages) < limit:
+ stream_pos = current_stream_id
+ return (messages, stream_pos)
+
+ return self.runInteraction(
+ "get_new_messages_for_device", get_new_messages_for_device_txn,
+ )
+
+ def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+ """
+ Args:
+ user_id(str): The recipient user_id.
+ device_id(str): The recipient device_id.
+ up_to_stream_id(int): Where to delete messages up to.
+ Returns:
+ A deferred that resolves when the messages have been deleted.
+ """
+ def delete_messages_for_device_txn(txn):
+ sql = (
+ "DELETE FROM device_inbox"
+ " WHERE user_id = ? AND device_id = ?"
+ " AND stream_id <= ?"
+ )
+ txn.execute(sql, (user_id, device_id, up_to_stream_id))
+
+ return self.runInteraction(
+ "delete_messages_for_device", delete_messages_for_device_txn
+ )
+
+ def get_all_new_device_messages(self, last_pos, current_pos, limit):
+ """
+ Args:
+ last_pos(int): The stream position to return messages after.
+ current_pos(int): The current position of the to device message stream.
+ limit(int): The maximum number of distinct stream positions to return.
+ Returns:
+ A deferred list of rows from the device inbox
+ """
+ if last_pos == current_pos:
+ return defer.succeed([])
+
+ def get_all_new_device_messages_txn(txn):
+ sql = (
+ "SELECT stream_id FROM device_inbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " GROUP BY stream_id"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_pos, current_pos, limit))
+ stream_ids = txn.fetchall()
+ if not stream_ids:
+ return []
+ # fetchall() returns single-column rows, so unpack the stream_id itself.
+ max_stream_id_in_limit = stream_ids[-1][0]
+
+ sql = (
+ "SELECT stream_id, user_id, device_id, message_json"
+ " FROM device_inbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ )
+ txn.execute(sql, (last_pos, max_stream_id_in_limit))
+ return txn.fetchall()
+
+ return self.runInteraction(
+ "get_all_new_device_messages", get_all_new_device_messages_txn
+ )
+
+ def get_to_device_stream_token(self):
+ return self._device_inbox_id_gen.get_current_token()
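
Taken together, the new store methods form a simple produce/read/acknowledge cycle: messages are fanned out with add_messages_to_device_inbox, a reader pages them out with get_new_messages_for_device, and acknowledged rows are dropped with delete_messages_for_device. The sketch below is illustrative only: drain_device_inbox is a hypothetical helper, and `store` stands in for whatever DataStore instance mixes in DeviceInboxStore.

    from twisted.internet import defer


    @defer.inlineCallbacks
    def drain_device_inbox(store, user_id, device_id, last_stream_id):
        # Ask the store how far the to-device stream currently extends.
        current_stream_id = store.get_to_device_stream_token()

        # Page out up to 100 pending messages for this device.
        messages, stream_pos = yield store.get_new_messages_for_device(
            user_id, device_id, last_stream_id, current_stream_id, limit=100
        )

        # Once the client has acknowledged stream_pos, the rows can go.
        yield store.delete_messages_for_device(user_id, device_id, stream_pos)

        defer.returnValue((messages, stream_pos))
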
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 57e5005285..1a7d4c5199 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -271,39 +271,28 @@ class EventsStore(SQLBaseStore):
len(events_and_contexts)
)
- state_group_id_manager = self._state_groups_id_gen.get_next_mult(
- len(events_and_contexts)
- )
with stream_ordering_manager as stream_orderings:
- with state_group_id_manager as state_group_ids:
- for (event, context), stream, state_group_id in zip(
- events_and_contexts, stream_orderings, state_group_ids
- ):
- event.internal_metadata.stream_ordering = stream
- # Assign a state group_id in case a new id is needed for
- # this context. In theory we only need to assign this
- # for contexts that have current_state and aren't outliers
- # but that make the code more complicated. Assigning an ID
- # per event only causes the state_group_ids to grow as fast
- # as the stream_ordering so in practise shouldn't be a problem.
- context.new_state_group_id = state_group_id
-
- chunks = [
- events_and_contexts[x:x + 100]
- for x in xrange(0, len(events_and_contexts), 100)
- ]
+ for (event, context), stream in zip(
+ events_and_contexts, stream_orderings
+ ):
+ event.internal_metadata.stream_ordering = stream
- for chunk in chunks:
- # We can't easily parallelize these since different chunks
- # might contain the same event. :(
- yield self.runInteraction(
- "persist_events",
- self._persist_events_txn,
- events_and_contexts=chunk,
- backfilled=backfilled,
- delete_existing=delete_existing,
- )
- persist_event_counter.inc_by(len(chunk))
+ chunks = [
+ events_and_contexts[x:x + 100]
+ for x in xrange(0, len(events_and_contexts), 100)
+ ]
+
+ for chunk in chunks:
+ # We can't easily parallelize these since different chunks
+ # might contain the same event. :(
+ yield self.runInteraction(
+ "persist_events",
+ self._persist_events_txn,
+ events_and_contexts=chunk,
+ backfilled=backfilled,
+ delete_existing=delete_existing,
+ )
+ persist_event_counter.inc_by(len(chunk))
@_retry_on_integrity_error
@defer.inlineCallbacks
@@ -312,19 +301,17 @@ class EventsStore(SQLBaseStore):
delete_existing=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
- with self._state_groups_id_gen.get_next() as state_group_id:
- event.internal_metadata.stream_ordering = stream_ordering
- context.new_state_group_id = state_group_id
- yield self.runInteraction(
- "persist_event",
- self._persist_event_txn,
- event=event,
- context=context,
- current_state=current_state,
- backfilled=backfilled,
- delete_existing=delete_existing,
- )
- persist_event_counter.inc()
+ event.internal_metadata.stream_ordering = stream_ordering
+ yield self.runInteraction(
+ "persist_event",
+ self._persist_event_txn,
+ event=event,
+ context=context,
+ current_state=current_state,
+ backfilled=backfilled,
+ delete_existing=delete_existing,
+ )
+ persist_event_counter.inc()
except _RollbackButIsFineException:
pass
@@ -393,7 +380,6 @@ class EventsStore(SQLBaseStore):
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
- txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
@@ -529,7 +515,7 @@ class EventsStore(SQLBaseStore):
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
- state_group_id = context.state_group or context.new_state_group_id
+ state_group_id = context.state_group
self._simple_insert_txn(
txn,
table="ex_outlier_stream",
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 78334a98cf..49721656b6 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -16,7 +16,6 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules
-from synapse.api.constants import EventTypes, Membership
from twisted.internet import defer
import logging
@@ -124,7 +123,8 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results)
- def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
+ def bulk_get_push_rules_for_room(self, event, context):
+ state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -132,11 +132,13 @@ class PushRuleStore(SQLBaseStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
+ return self._bulk_get_push_rules_for_room(
+ event.room_id, state_group, context.current_state_ids, event=event
+ )
@cachedInlineCallbacks(num_args=2, cache_context=True)
- def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
- cache_context):
+ def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
+ cache_context, event=None):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
@@ -147,12 +149,15 @@ class PushRuleStore(SQLBaseStore):
# their unread counts are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
- local_users_in_room = set(
- e.state_key for e in current_state.values()
- if e.type == EventTypes.Member and e.membership == Membership.JOIN
- and self.hs.is_mine_id(e.state_key)
+
+ users_in_room = yield self._get_joined_users_from_context(
+ room_id, state_group, current_state_ids,
+ on_invalidate=cache_context.invalidate,
+ event=event,
)
+ local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u))
+
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index ccc3811e84..9747a04a9a 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
- @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
+ @cachedInlineCallbacks(num_args=3, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index a422ddf633..6ab10db328 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -20,7 +20,7 @@ from collections import namedtuple
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id
import logging
@@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore):
for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
- txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self._membership_stream_cache.entity_has_changed,
@@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore):
return results
- @cachedInlineCallbacks(max_entries=5000)
- def get_joined_hosts_for_room(self, room_id):
- user_ids = yield self.get_users_in_room(room_id)
- defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
-
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?"
where_values = [room_id]
@@ -325,7 +319,8 @@ class RoomMemberStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=3)
def was_forgotten_at(self, user_id, room_id, event_id):
- """Returns whether user_id has elected to discard history for room_id at event_id.
+ """Returns whether user_id has elected to discard history for room_id at
+ event_id.
event_id must be a membership event."""
def f(txn):
@@ -358,3 +353,98 @@ class RoomMemberStore(SQLBaseStore):
},
desc="who_forgot"
)
+
+ def get_joined_users_from_context(self, event, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_users_from_context(
+ event.room_id, state_group, context.current_state_ids, event=event,
+ )
+
+ def get_joined_users_from_state(self, room_id, state_group, state_ids):
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_users_from_context(
+ room_id, state_group, state_ids,
+ )
+
+ @cachedInlineCallbacks(num_args=2, cache_context=True)
+ def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
+ cache_context, event=None):
+ # We don't use `state_group`; it's there so that we can cache based
+ # on it. However, it's important that it's never None, since two
+ # current_state dicts with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ member_event_ids = [
+ e_id
+ for key, e_id in current_state_ids.iteritems()
+ if key[0] == EventTypes.Member
+ ]
+
+ rows = yield self._simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=['user_id'],
+ keyvalues={
+ "membership": Membership.JOIN,
+ },
+ batch_size=1000,
+ desc="_get_joined_users_from_context",
+ )
+
+ users_in_room = set(row["user_id"] for row in rows)
+ if event is not None and event.type == EventTypes.Member:
+ if event.membership == Membership.JOIN:
+ if event.event_id in member_event_ids:
+ users_in_room.add(event.state_key)
+
+ defer.returnValue(users_in_room)
+
+ def is_host_joined(self, room_id, host, state_group, state_ids):
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._is_host_joined(
+ room_id, host, state_group, state_ids
+ )
+
+ @cachedInlineCallbacks(num_args=3)
+ def _is_host_joined(self, room_id, host, state_group, current_state_ids):
+ # We don't use `state_group`; it's there so that we can cache based
+ # on it. However, it's important that it's never None, since two
+ # current_state dicts with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ for (etype, state_key), event_id in current_state_ids.items():
+ if etype == EventTypes.Member:
+ try:
+ if get_domain_from_id(state_key) != host:
+ continue
+ except Exception:
+ logger.warn("state_key not user_id: %s", state_key)
+ continue
+
+ event = yield self.get_event(event_id, allow_none=True)
+ if event and event.content["membership"] == Membership.JOIN:
+ defer.returnValue(True)
+
+ defer.returnValue(False)
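
Several of the wrappers above swap a None state_group for a fresh object() before calling a @cachedInlineCallbacks helper. Since an object() compares equal only to itself, un-grouped calls each get a unique cache key and can never be served another caller's stale entry, while real state groups still share cached results. A minimal stand-alone illustration of the idea, using a plain dict as the cache and hypothetical function names:

    cache = {}

    def _cached_state_summary(state_group, current_state_ids):
        # Cached on state_group alone; current_state_ids is assumed to be
        # fully determined by the group, as in _get_joined_users_from_context.
        if state_group not in cache:
            cache[state_group] = frozenset(current_state_ids.values())
        return cache[state_group]

    def state_summary(state_group, current_state_ids):
        if not state_group:
            # A None state_group would make unrelated calls share one cache
            # key, so substitute a sentinel that is equal only to itself.
            state_group = object()
        return _cached_state_summary(state_group, current_state_ids)

    # Two un-grouped calls never collide in the cache, because each call
    # manufactures a distinct sentinel key.
    a = state_summary(None, {("m.room.member", "@a:example.com"): "$event_a"})
    b = state_summary(None, {("m.room.member", "@b:example.com"): "$event_b"})
    assert a != b
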
diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/schema/delta/34/device_inbox.sql
new file mode 100644
index 0000000000..e68844c74a
--- /dev/null
+++ b/synapse/storage/schema/delta/34/device_inbox.sql
@@ -0,0 +1,24 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE device_inbox (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ stream_id BIGINT NOT NULL,
+ message_json TEXT NOT NULL -- JSON body: {"type": ..., "sender": ..., "content": ...}
+);
+
+CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id);
+CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id);
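
The two indexes back the two ways the table is read: per-device delivery filters on (user_id, device_id, stream_id), while the replication query in get_all_new_device_messages scans by stream_id alone. A stand-alone sketch of both query shapes against an in-memory SQLite database, using made-up example data:

    import json
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE device_inbox (
            user_id TEXT NOT NULL,
            device_id TEXT NOT NULL,
            stream_id BIGINT NOT NULL,
            message_json TEXT NOT NULL
        );
        CREATE INDEX device_inbox_user_stream_id
            ON device_inbox(user_id, device_id, stream_id);
        CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id);
    """)

    msg = json.dumps({"type": "m.example", "sender": "@a:example.com", "content": {}})
    conn.execute(
        "INSERT INTO device_inbox VALUES (?, ?, ?, ?)",
        ("@b:example.com", "DEVICE1", 1, msg),
    )

    # Per-device read: the shape used by get_new_messages_for_device.
    device_rows = conn.execute(
        "SELECT stream_id, message_json FROM device_inbox"
        " WHERE user_id = ? AND device_id = ?"
        " AND ? < stream_id AND stream_id <= ?"
        " ORDER BY stream_id ASC LIMIT ?",
        ("@b:example.com", "DEVICE1", 0, 1, 100),
    ).fetchall()

    # Replication read: the shape used by get_all_new_device_messages.
    all_rows = conn.execute(
        "SELECT stream_id, user_id, device_id, message_json FROM device_inbox"
        " WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC",
        (0, 1),
    ).fetchall()
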
diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py
new file mode 100644
index 0000000000..81948e3431
--- /dev/null
+++ b/synapse/storage/schema/delta/34/sent_txn_purge.py
@@ -0,0 +1,32 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute("TRUNCATE sent_transactions")
+ else:
+ cur.execute("DELETE FROM sent_transactions")
+
+ cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)")
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0e8fa93e1f..ec551b0b4f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -44,11 +44,7 @@ class StateStore(SQLBaseStore):
"""
@defer.inlineCallbacks
- def get_state_groups(self, room_id, event_ids):
- """ Get the state groups for the given list of event_ids
-
- The return value is a dict mapping group names to lists of events.
- """
+ def get_state_groups_ids(self, room_id, event_ids):
+ """ Get the state group ids for the given list of event_ids
+
+ The return value is a dict mapping state group ids to dicts of
+ (type, state_key) -> event_id.
+ """
if not event_ids:
defer.returnValue({})
@@ -59,36 +55,64 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
+ defer.returnValue(group_to_state)
+
+ @defer.inlineCallbacks
+ def get_state_groups(self, room_id, event_ids):
+ """ Get the state groups for the given list of event_ids
+
+ The return value is a dict mapping group names to lists of events.
+ """
+ if not event_ids:
+ defer.returnValue({})
+
+ group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+ state_event_map = yield self.get_events(
+ [
+ ev_id for group_ids in group_to_ids.values()
+ for ev_id in group_ids.values()
+ ],
+ get_prev_content=False
+ )
+
defer.returnValue({
- group: state_map.values()
- for group, state_map in group_to_state.items()
+ group: [
+ state_event_map[v] for v in event_id_map.values() if v in state_event_map
+ ]
+ for group, event_id_map in group_to_ids.items()
})
+ def _have_persisted_state_group_txn(self, txn, state_group):
+ txn.execute(
+ "SELECT count(*) FROM state_groups WHERE id = ?",
+ (state_group,)
+ )
+ row = txn.fetchone()
+ return row and row[0]
+
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
- if context.current_state is None:
- continue
-
- if context.state_group is not None:
- state_groups[event.event_id] = context.state_group
+ if context.current_state_ids is None:
continue
- state_events = dict(context.current_state)
+ state_groups[event.event_id] = context.state_group
- if event.is_state():
- state_events[(event.type, event.state_key)] = event
+ if self._have_persisted_state_group_txn(txn, context.state_group):
+ logger.info("Already persisted state_group: %r", context.state_group)
+ continue
- state_group = context.new_state_group_id
+ state_event_ids = dict(context.current_state_ids)
self._simple_insert_txn(
txn,
table="state_groups",
values={
- "id": state_group,
+ "id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
@@ -99,16 +123,15 @@ class StateStore(SQLBaseStore):
table="state_groups_state",
values=[
{
- "state_group": state_group,
- "room_id": state.room_id,
- "type": state.type,
- "state_key": state.state_key,
- "event_id": state.event_id,
+ "state_group": context.state_group,
+ "room_id": event.room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
}
- for state in state_events.values()
+ for key, state_id in state_event_ids.items()
],
)
- state_groups[event.event_id] = state_group
self._simple_insert_many_txn(
txn,
@@ -248,6 +271,31 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups, types)
+ state_event_map = yield self.get_events(
+ [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+ get_prev_content=False
+ )
+
+ event_to_state = {
+ event_id: {
+ k: state_event_map[v]
+ for k, v in group_to_state[group].items()
+ if v in state_event_map
+ }
+ for event_id, group in event_to_groups.items()
+ }
+
+ defer.returnValue({event: event_to_state[event] for event in event_ids})
+
+ @defer.inlineCallbacks
+ def get_state_ids_for_events(self, event_ids, types):
+ """ Get the state dicts for the given events, keyed by event_id.
+
+ Unlike get_state_for_events, each state dict maps (type, state_key)
+ to the event_id of the state event rather than to the event itself.
+ """
+ event_to_groups = yield self._get_state_group_for_events(
+ event_ids,
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = yield self._get_state_for_groups(groups, types)
+
event_to_state = {
event_id: group_to_state[group]
for event_id, group in event_to_groups.items()
@@ -272,6 +320,23 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
+ @defer.inlineCallbacks
+ def get_state_ids_for_event(self, event_id, types=None):
+ """
+ Get the state event ids corresponding to a particular event
+
+ Args:
+ event_id(str): event whose state should be returned
+ types(list[(str, str)]|None): List of (type, state_key) tuples
+ which are used to filter the state fetched. May be None, which
+ matches any key
+
+ Returns:
+ A deferred dict from (type, state_key) -> event_id
+ """
+ state_map = yield self.get_state_ids_for_events([event_id], types)
+ defer.returnValue(state_map[event_id])
+
@cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
@@ -428,20 +493,13 @@ class StateStore(SQLBaseStore):
full=(types is None),
)
- state_events = yield self._get_events(
- [ev_id for sd in results.values() for ev_id in sd.values()],
- get_prev_content=False
- )
-
- state_events = {e.event_id: e for e in state_events}
-
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
for group, state_dict in results.items():
results[group] = {
- key: state_events[event_id]
+ key: event_id
for key, event_id in state_dict.items()
- if event_id and event_id in state_events
+ if event_id
}
defer.returnValue(results)
@@ -473,5 +531,5 @@ class StateStore(SQLBaseStore):
"get_all_new_state_groups", get_all_new_state_groups_txn
)
- def get_state_stream_token(self):
- return self._state_groups_id_gen.get_current_token()
+ def get_next_state_group(self):
+ return self._state_groups_id_gen.get_next()
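
The net effect of the state store changes is that callers can choose between event ids and full events: get_state_ids_for_event(s) returns (type, state_key) -> event_id mappings straight from the state-group cache, while get_state_for_event(s) resolves those ids through get_events. A hedged usage sketch; the helper name, the `store` argument and the event id are placeholders, and get_state_for_event is assumed to keep its existing signature:

    from twisted.internet import defer


    @defer.inlineCallbacks
    def inspect_create_event(store, event_id):
        # Cheap path: only the (type, state_key) -> event_id mapping comes
        # back; no event bodies are read from the events table.
        state_ids = yield store.get_state_ids_for_event(
            event_id, types=[("m.room.create", "")]
        )
        create_event_id = state_ids.get(("m.room.create", ""))

        # Expensive path: the same keys, resolved to full events.
        state = yield store.get_state_for_event(
            event_id, types=[("m.room.create", "")]
        )
        create_event = state.get(("m.room.create", ""))

        defer.returnValue((create_event_id, create_event))
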
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 58d4de4f1d..5055c04b24 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -245,7 +245,7 @@ class TransactionStore(SQLBaseStore):
return self.cursor_to_dict(txn)
- @cached()
+ @cached(max_entries=10000)
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
@@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore):
def _cleanup_transactions(self):
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
+ six_hours_ago = now - 6 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+ txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
|