diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 5a9e7720d9..f257721ea3 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -20,7 +20,7 @@ from .appservice import (
from ._base import Cache
from .directory import DirectoryStore
from .events import EventsStore
-from .presence import PresenceStore
+from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
from .registration import RegistrationStore
from .room import RoomStore
@@ -47,6 +47,7 @@ from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator
+from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -110,16 +111,19 @@ class DataStore(RoomMemberStore, RoomStore,
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
+ self._presence_id_gen = StreamIdGenerator(
+ db_conn, "presence_stream", "stream_id"
+ )
- self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
- self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
- self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
- self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
- self._pushers_id_gen = IdGenerator("pushers", "id", self)
- self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
- self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+ self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
+ self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
+ self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+ self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
+ self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
+ self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+ self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- events_max = self._stream_id_gen.get_max_token(None)
+ events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
@@ -135,13 +139,31 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max,
)
- account_max = self._account_data_id_gen.get_max_token(None)
+ account_max = self._account_data_id_gen.get_max_token()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
+ self.__presence_on_startup = self._get_active_presence(db_conn)
+
+ presence_cache_prefill, min_presence_val = self._get_cache_dict(
+ db_conn, "presence_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._presence_id_gen.get_max_token(),
+ )
+ self.presence_stream_cache = StreamChangeCache(
+ "PresenceStreamChangeCache", min_presence_val,
+ prefilled_cache=presence_cache_prefill
+ )
+
super(DataStore, self).__init__(hs)
+ def take_presence_startup_info(self):
+ active_on_startup = self.__presence_on_startup
+ self.__presence_on_startup = None
+ return active_on_startup
+
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
@@ -161,6 +183,7 @@ class DataStore(RoomMemberStore, RoomStore,
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
+ txn.close()
cache = {
row[0]: int(row[1])
@@ -174,6 +197,28 @@ class DataStore(RoomMemberStore, RoomStore,
return cache, min_val
+ def _get_active_presence(self, db_conn):
+ """Fetch non-offline presence from the database so that we can register
+ the appropriate timeouts.
+ """
+
+ sql = (
+ "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
+ " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
+ " WHERE state != ?"
+ )
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (PresenceState.OFFLINE,))
+ rows = self.cursor_to_dict(txn)
+ txn.close()
+
+ for row in rows:
+ row["currently_active"] = bool(row["currently_active"])
+
+ return [UserPresenceState(**row) for row in rows]
+
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())
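The constructor changes above move id allocation from lazy, transaction-bound lookups to a single MAX() read over the raw db_conn at startup, after which ids are handed out purely from memory. A minimal standalone sketch of that pattern, using sqlite and an illustrative `pushers` table rather than anything from this patch:

```python
import sqlite3
import threading

class SketchIdGenerator(object):
    """Simplified stand-in for IdGenerator(db_conn, table, column)."""

    def __init__(self, db_conn, table, column):
        self._lock = threading.Lock()
        cur = db_conn.cursor()
        # One read at startup; every later allocation is in-memory only.
        cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
        val, = cur.fetchone()
        cur.close()
        self._next_id = (val or 0) + 1

    def get_next(self):
        with self._lock:
            i = self._next_id
            self._next_id += 1
            return i

db_conn = sqlite3.connect(":memory:")
db_conn.execute("CREATE TABLE pushers (id INTEGER)")
db_conn.execute("INSERT INTO pushers VALUES (41)")
gen = SketchIdGenerator(db_conn, "pushers", "id")
assert gen.get_next() == 42
```

The trade-off is that only one process may write to the table, which is what makes dropping the per-transaction `get_next_txn` bookkeeping safe.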
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index b8387fc500..faddefe219 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
- def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
- """Get all the client account_data for a that's changed.
+ def get_all_updated_account_data(self, last_global_id, last_room_id,
+ current_id, limit):
+ """Get all the client account_data that has changed on the server
+ Args:
+ last_global_id(int): The position to fetch from for top level data
+ last_room_id(int): The position to fetch from for per room data
+ current_id(int): The position to fetch up to.
+ limit(int): The maximum number of rows to fetch from each table.
+ Returns:
+ A deferred pair of lists of tuples of stream_id int, user_id string,
+ room_id string, type string, and content string (the global list
+ carries no room_id).
+ """
+ def get_updated_account_data_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, account_data_type, content"
+ " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_global_id, current_id, limit))
+ global_results = txn.fetchall()
+
+ sql = (
+ "SELECT stream_id, user_id, room_id, account_data_type, content"
+ " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_room_id, current_id, limit))
+ room_results = txn.fetchall()
+ return (global_results, room_results)
+ return self.runInteraction(
+ "get_all_updated_account_data_txn", get_updated_account_data_txn
+ )
+
+ def get_updated_account_data_for_user(self, user_id, stream_id):
+ """Get all the client account_data for a that's changed for a user
Args:
user_id(str): The user to get the account_data for.
@@ -163,12 +195,12 @@ class AccountDataStore(SQLBaseStore):
)
self._update_max_stream_id(txn, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction(
"add_room_account_data", add_account_data_txn, next_id
)
- result = yield self._account_data_id_gen.get_max_token(self)
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -202,12 +234,12 @@ class AccountDataStore(SQLBaseStore):
)
self._update_max_stream_id(txn, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction(
"add_user_account_data", add_account_data_txn, next_id
)
- result = yield self._account_data_id_gen.get_max_token(self)
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):
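`get_all_updated_account_data` is meant to be consumed by pulling batches until the caller's stream positions catch up to `current_id`. A hedged usage sketch; `poll_account_data` is an illustrative name, and it reaches into the private `_account_data_id_gen` for brevity:

```python
from twisted.internet import defer

@defer.inlineCallbacks
def poll_account_data(store, last_global_id, last_room_id, limit=100):
    current_id = store._account_data_id_gen.get_max_token()
    global_rows, room_rows = yield store.get_all_updated_account_data(
        last_global_id, last_room_id, current_id, limit,
    )
    # Rows come back ordered by stream_id ASC, so the new position for each
    # stream is the last stream_id seen (or the old position if nothing new).
    new_global = global_rows[-1][0] if global_rows else last_global_id
    new_room = room_rows[-1][0] if room_rows else last_room_id
    defer.returnValue((new_global, new_room, global_rows, room_rows))
```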
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index ce2c794025..3489315e0d 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -114,10 +114,10 @@ class EventFederationStore(SQLBaseStore):
retcol="event_id",
)
- def get_latest_events_in_room(self, room_id):
+ def get_latest_event_ids_and_hashes_in_room(self, room_id):
return self.runInteraction(
- "get_latest_events_in_room",
- self._get_latest_events_in_room,
+ "get_latest_event_ids_and_hashes_in_room",
+ self._get_latest_event_ids_and_hashes_in_room,
room_id,
)
@@ -132,7 +132,7 @@ class EventFederationStore(SQLBaseStore):
desc="get_latest_event_ids_in_room",
)
- def _get_latest_events_in_room(self, txn, room_id):
+ def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
sql = (
"SELECT e.event_id, e.depth FROM events as e "
"INNER JOIN event_forward_extremities as f "
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index d77a817682..5820539a92 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore):
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
:param event: the event to set actions for
- :param tuples: list of tuples of (user_id, profile_tag, actions)
+ :param tuples: list of tuples of (user_id, actions)
"""
values = []
- for uid, profile_tag, actions in tuples:
+ for uid, actions in tuples:
values.append({
'room_id': event.room_id,
'event_id': event.event_id,
'user_id': uid,
- 'profile_tag': profile_tag,
'actions': json.dumps(actions),
'stream_ordering': event.internal_metadata.stream_ordering,
'topological_ordering': event.depth,
@@ -43,7 +42,7 @@ class EventPushActionsStore(SQLBaseStore):
'highlight': 1 if _action_has_highlight(actions) else 0,
})
- for uid, _, __ in tuples:
+ for uid, __ in tuples:
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(event.room_id, uid)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 3a5c6ee4b1..60936500d8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore):
yield stream_orderings
stream_ordering_manager = stream_ordering_manager()
else:
- stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
- self, len(events_and_contexts)
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
@@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore):
stream_ordering = self.min_stream_token
if stream_ordering is None:
- stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+ stream_ordering_manager = self._stream_id_gen.get_next()
else:
@contextmanager
def stream_ordering_manager():
@@ -131,7 +131,7 @@ class EventsStore(SQLBaseStore):
except _RollbackButIsFineException:
pass
- max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+ max_persisted_id = yield self._stream_id_gen.get_max_token()
defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks
@@ -1064,3 +1064,48 @@ class EventsStore(SQLBaseStore):
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
defer.returnValue(result)
+
+ def get_current_backfill_token(self):
+ """The current minimum token that backfilled events have reached"""
+
+ # TODO: Fix race with the persist_event txn by using one of the
+ # stream id managers
+ return -self.min_stream_token
+
+ def get_all_new_events(self, last_backfill_id, last_forward_id,
+ current_backfill_id, current_forward_id, limit):
+ """Get all the new events that have arrived at the server either as
+ new events or as backfilled events"""
+ def get_all_new_events_txn(txn):
+ sql = (
+ "SELECT e.stream_ordering, ej.internal_metadata, ej.json"
+ " FROM events as e"
+ " JOIN event_json as ej"
+ " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+ " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ " LIMIT ?"
+ )
+ if last_forward_id != current_forward_id:
+ txn.execute(sql, (last_forward_id, current_forward_id, limit))
+ new_forward_events = txn.fetchall()
+ else:
+ new_forward_events = []
+
+ sql = (
+ "SELECT -e.stream_ordering, ej.internal_metadata, ej.json"
+ " FROM events as e"
+ " JOIN event_json as ej"
+ " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+ " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
+ " ORDER BY e.stream_ordering DESC"
+ " LIMIT ?"
+ )
+ if last_backfill_id != current_backfill_id:
+ txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
+ new_backfill_events = txn.fetchall()
+ else:
+ new_backfill_events = []
+
+ return (new_forward_events, new_backfill_events)
+ return self.runInteraction("get_all_new_events", get_all_new_events_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 850736c85e..0fd5d497ab 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 29
+SCHEMA_VERSION = 30
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index ef525f34c5..4cec31e316 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -14,73 +14,148 @@
# limitations under the License.
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.api.constants import PresenceState
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from collections import namedtuple
from twisted.internet import defer
+class UserPresenceState(namedtuple("UserPresenceState",
+ ("user_id", "state", "last_active_ts",
+ "last_federation_update_ts", "last_user_sync_ts",
+ "status_msg", "currently_active"))):
+ """Represents the current presence state of the user.
+
+ user_id (str)
+ state (str): The presence state, one of the PresenceState constants.
+ last_active_ts (int): Time in msec that the user last interacted with server.
+ last_federation_update_ts (int): Time in msec since either a) we sent a presence
+ update to other servers or b) we received a presence update, depending
+ on whether it is a local user or not.
+ last_user_sync_ts (int): Time in msec that the user last *completed* a sync
+ (or event stream).
+ status_msg (str): User set status message.
+ currently_active (bool): Whether the user is considered currently active.
+ """
+
+ def copy_and_replace(self, **kwargs):
+ return self._replace(**kwargs)
+
+ @classmethod
+ def default(cls, user_id):
+ """Returns a default presence state.
+ """
+ return cls(
+ user_id=user_id,
+ state=PresenceState.OFFLINE,
+ last_active_ts=0,
+ last_federation_update_ts=0,
+ last_user_sync_ts=0,
+ status_msg=None,
+ currently_active=False,
+ )
+
+
class PresenceStore(SQLBaseStore):
- def create_presence(self, user_localpart):
- res = self._simple_insert(
- table="presence",
- values={"user_id": user_localpart},
- desc="create_presence",
+ @defer.inlineCallbacks
+ def update_presence(self, presence_states):
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ len(presence_states)
)
- self.get_presence_state.invalidate((user_localpart,))
- return res
+ with stream_ordering_manager as stream_orderings:
+ yield self.runInteraction(
+ "update_presence",
+ self._update_presence_txn, stream_orderings, presence_states,
+ )
- def has_presence_state(self, user_localpart):
- return self._simple_select_one(
- table="presence",
- keyvalues={"user_id": user_localpart},
- retcols=["user_id"],
- allow_none=True,
- desc="has_presence_state",
+ defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
+
+ def _update_presence_txn(self, txn, stream_orderings, presence_states):
+ for stream_id, state in zip(stream_orderings, presence_states):
+ txn.call_after(
+ self.presence_stream_cache.entity_has_changed,
+ state.user_id, stream_id,
+ )
+
+ # Actually insert new rows
+ self._simple_insert_many_txn(
+ txn,
+ table="presence_stream",
+ values=[
+ {
+ "stream_id": stream_id,
+ "user_id": state.user_id,
+ "state": state.state,
+ "last_active_ts": state.last_active_ts,
+ "last_federation_update_ts": state.last_federation_update_ts,
+ "last_user_sync_ts": state.last_user_sync_ts,
+ "status_msg": state.status_msg,
+ "currently_active": state.currently_active,
+ }
+ for state in presence_states
+ ],
)
- @cached(max_entries=2000)
- def get_presence_state(self, user_localpart):
- return self._simple_select_one(
- table="presence",
- keyvalues={"user_id": user_localpart},
- retcols=["state", "status_msg", "mtime"],
- desc="get_presence_state",
+ # Delete old rows to stop database from getting really big
+ sql = (
+ "DELETE FROM presence_stream WHERE"
+ " stream_id < ?"
+ " AND user_id IN (%s)"
)
- @cachedList(get_presence_state.cache, list_name="user_localparts",
- inlineCallbacks=True)
- def get_presence_states(self, user_localparts):
- rows = yield self._simple_select_many_batch(
- table="presence",
- column="user_id",
- iterable=user_localparts,
- retcols=("user_id", "state", "status_msg", "mtime",),
- desc="get_presence_states",
+ batches = (
+ presence_states[i:i + 50]
+ for i in xrange(0, len(presence_states), 50)
)
+ for states in batches:
+ # "stream_id" has leaked from the zip loop above and holds the newest
+ # allocated id; all rows below it for these users are removed.
+ args = [stream_id]
+ args.extend(s.user_id for s in states)
+ txn.execute(
+ sql % (",".join("?" for _ in states),),
+ args
+ )
+
+ def get_all_presence_updates(self, last_id, current_id):
+ def get_all_presence_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, state, last_active_ts,"
+ " last_federation_update_ts, last_user_sync_ts, status_msg,"
+ " currently_active"
+ " FROM presence_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ )
+ txn.execute(sql, (last_id, current_id))
+ return txn.fetchall()
- defer.returnValue({
- row["user_id"]: {
- "state": row["state"],
- "status_msg": row["status_msg"],
- "mtime": row["mtime"],
- }
- for row in rows
- })
+ return self.runInteraction(
+ "get_all_presence_updates", get_all_presence_updates_txn
+ )
@defer.inlineCallbacks
- def set_presence_state(self, user_localpart, new_state):
- res = yield self._simple_update_one(
- table="presence",
- keyvalues={"user_id": user_localpart},
- updatevalues={"state": new_state["state"],
- "status_msg": new_state["status_msg"],
- "mtime": self._clock.time_msec()},
- desc="set_presence_state",
+ def get_presence_for_users(self, user_ids):
+ rows = yield self._simple_select_many_batch(
+ table="presence_stream",
+ column="user_id",
+ iterable=user_ids,
+ keyvalues={},
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
)
- self.get_presence_state.invalidate((user_localpart,))
- defer.returnValue(res)
+ for row in rows:
+ row["currently_active"] = bool(row["currently_active"])
+
+ defer.returnValue([UserPresenceState(**row) for row in rows])
+
+ def get_current_presence_token(self):
+ return self._presence_id_gen.get_max_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
@@ -128,6 +203,7 @@ class PresenceStore(SQLBaseStore):
desc="set_presence_list_accepted",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
+ self.get_presence_list_observers_accepted.invalidate((observed_userid,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
@@ -154,6 +230,19 @@ class PresenceStore(SQLBaseStore):
desc="get_presence_list_accepted",
)
+ @cachedInlineCallbacks()
+ def get_presence_list_observers_accepted(self, observed_userid):
+ user_localparts = yield self._simple_select_onecol(
+ table="presence_list",
+ keyvalues={"observed_user_id": observed_userid, "accepted": True},
+ retcol="user_id",
+ desc="get_presence_list_accepted",
+ )
+
+ defer.returnValue([
+ "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts
+ ])
+
@defer.inlineCallbacks
def del_presence_list(self, observer_localpart, observed_userid):
yield self._simple_delete_one(
@@ -163,3 +252,4 @@ class PresenceStore(SQLBaseStore):
desc="del_presence_list",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
+ self.get_presence_list_observers_accepted.invalidate((observed_userid,))
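Putting the new pieces together, the intended round trip is: read the current `UserPresenceState` (or the default), derive the new state with `copy_and_replace`, and persist it with `update_presence`, which allocates stream ids and feeds `presence_stream_cache`. A hedged sketch assuming the `presence.py` namespace; `bump_presence` and `now_ms` are illustrative:

```python
from twisted.internet import defer

from synapse.api.constants import PresenceState

@defer.inlineCallbacks
def bump_presence(store, user_id, now_ms):
    prev = yield store.get_presence_for_users([user_id])
    state = prev[0] if prev else UserPresenceState.default(user_id)
    state = state.copy_and_replace(
        state=PresenceState.ONLINE,
        last_active_ts=now_ms,
        currently_active=True,
    )
    # Returns (stream id of this update, current max presence token).
    stream_id, max_token = yield store.update_presence([state])
    defer.returnValue(stream_id)
```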
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index f9a48171ba..56e69495b1 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -99,38 +99,36 @@ class PushRuleStore(SQLBaseStore):
results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
defer.returnValue(results)
- @defer.inlineCallbacks
- def add_push_rule(self, before, after, **kwargs):
- vals = kwargs
- if 'conditions' in vals:
- vals['conditions'] = json.dumps(vals['conditions'])
- if 'actions' in vals:
- vals['actions'] = json.dumps(vals['actions'])
-
- # we could check the rest of the keys are valid column names
- # but sqlite will do that anyway so I think it's just pointless.
- vals.pop("id", None)
+ def add_push_rule(
+ self, user_id, rule_id, priority_class, conditions, actions,
+ before=None, after=None
+ ):
+ conditions_json = json.dumps(conditions)
+ actions_json = json.dumps(actions)
if before or after:
- ret = yield self.runInteraction(
+ return self.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
- before=before,
- after=after,
- **vals
+ user_id, rule_id, priority_class,
+ conditions_json, actions_json, before, after,
)
- defer.returnValue(ret)
else:
- ret = yield self.runInteraction(
+ return self.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
- **vals
+ user_id, rule_id, priority_class,
+ conditions_json, actions_json,
)
- defer.returnValue(ret)
- def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
- after = kwargs.pop("after", None)
- before = kwargs.pop("before", None)
+ def _add_push_rule_relative_txn(
+ self, txn, user_id, rule_id, priority_class,
+ conditions_json, actions_json, before, after
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
+
relative_to_rule = before or after
res = self._simple_select_one_txn(
@@ -149,69 +147,45 @@ class PushRuleStore(SQLBaseStore):
"before/after rule not found: %s" % (relative_to_rule,)
)
- priority_class = res["priority_class"]
+ base_priority_class = res["priority_class"]
base_rule_priority = res["priority"]
- if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
+ if base_priority_class != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
- new_rule = kwargs
- new_rule.pop("before", None)
- new_rule.pop("after", None)
- new_rule['priority_class'] = priority_class
- new_rule['user_name'] = user_id
- new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
-
- # check if the priority before/after is free
- new_rule_priority = base_rule_priority
- if after:
- new_rule_priority -= 1
+ if before:
+ # Higher priority rules are executed first, so adding a rule before
+ # a rule means giving it a higher priority than that rule.
+ new_rule_priority = base_rule_priority + 1
else:
- new_rule_priority += 1
-
- new_rule['priority'] = new_rule_priority
+ # We increment the priority of the existing rules to make space for
+ # the new rule. Therefore if we want this rule to appear after
+ # an existing rule we give it the priority of the existing rule,
+ # and then increment the priority of the existing rule.
+ new_rule_priority = base_rule_priority
sql = (
- "SELECT COUNT(*) FROM push_rules"
- " WHERE user_name = ? AND priority_class = ? AND priority = ?"
+ "UPDATE push_rules SET priority = priority + 1"
+ " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
)
- txn.execute(sql, (user_id, priority_class, new_rule_priority))
- res = txn.fetchall()
- num_conflicting = res[0][0]
-
- # if there are conflicting rules, bump everything
- if num_conflicting:
- sql = "UPDATE push_rules SET priority = priority "
- if after:
- sql += "-1"
- else:
- sql += "+1"
- sql += " WHERE user_name = ? AND priority_class = ? AND priority "
- if after:
- sql += "<= ?"
- else:
- sql += ">= ?"
- txn.execute(sql, (user_id, priority_class, new_rule_priority))
-
- txn.call_after(
- self.get_push_rules_for_user.invalidate, (user_id,)
- )
+ txn.execute(sql, (user_id, priority_class, new_rule_priority))
- txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, (user_id,)
+ self._upsert_push_rule_txn(
+ txn, user_id, rule_id, priority_class, new_rule_priority,
+ conditions_json, actions_json,
)
- self._simple_insert_txn(
- txn,
- table="push_rules",
- values=new_rule,
- )
+ def _add_push_rule_highest_priority_txn(
+ self, txn, user_id, rule_id, priority_class,
+ conditions_json, actions_json
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
- def _add_push_rule_highest_priority_txn(self, txn, user_id,
- priority_class, **kwargs):
# find the highest priority rule in that class
sql = (
"SELECT COUNT(*), MAX(priority) FROM push_rules"
@@ -225,12 +199,48 @@ class PushRuleStore(SQLBaseStore):
if how_many > 0:
new_prio = highest_prio + 1
- # and insert the new rule
- new_rule = kwargs
- new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
- new_rule['user_name'] = user_id
- new_rule['priority_class'] = priority_class
- new_rule['priority'] = new_prio
+ self._upsert_push_rule_txn(
+ txn,
+ user_id, rule_id, priority_class, new_prio,
+ conditions_json, actions_json,
+ )
+
+ def _upsert_push_rule_txn(
+ self, txn, user_id, rule_id, priority_class,
+ priority, conditions_json, actions_json
+ ):
+ """Specialised version of _simple_upsert_txn that picks a push_rule_id
+ using the _push_rule_id_gen if it needs to insert the rule. It assumes
+ that the "push_rules" table is locked"""
+
+ sql = (
+ "UPDATE push_rules"
+ " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
+ " WHERE user_name = ? AND rule_id = ?"
+ )
+
+ txn.execute(sql, (
+ priority_class, priority, conditions_json, actions_json,
+ user_id, rule_id,
+ ))
+
+ if txn.rowcount == 0:
+ # We didn't update a row with the given rule_id so insert one
+ push_rule_id = self._push_rule_id_gen.get_next()
+
+ self._simple_insert_txn(
+ txn,
+ table="push_rules",
+ values={
+ "id": push_rule_id,
+ "user_name": user_id,
+ "rule_id": rule_id,
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ },
+ )
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
@@ -239,12 +249,6 @@ class PushRuleStore(SQLBaseStore):
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
- self._simple_insert_txn(
- txn,
- table="push_rules",
- values=new_rule,
- )
-
@defer.inlineCallbacks
def delete_push_rule(self, user_id, rule_id):
"""
@@ -275,7 +279,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(ret)
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
- new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
+ new_id = self._push_rules_enable_id_gen.get_next()
self._simple_upsert_txn(
txn,
"push_rules_enable",
@@ -290,6 +294,31 @@ class PushRuleStore(SQLBaseStore):
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
+ def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+ actions_json = json.dumps(actions)
+
+ def set_push_rule_actions_txn(txn):
+ if is_default_rule:
+ # Add a dummy rule to the rules table with the user specified
+ # actions.
+ priority_class = -1
+ priority = 1
+ self._upsert_push_rule_txn(
+ txn, user_id, rule_id, priority_class, priority,
+ "[]", actions_json
+ )
+ else:
+ self._simple_update_one_txn(
+ txn,
+ "push_rules",
+ {'user_name': user_id, 'rule_id': rule_id},
+ {'actions': actions_json},
+ )
+
+ return self.runInteraction(
+ "set_push_rule_actions", set_push_rule_actions_txn,
+ )
+
class RuleNotFoundException(Exception):
pass
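The relative-insert logic is easier to see in isolation. A self-contained model of the priority arithmetic (not the store code): higher priority executes first, "before" means base priority + 1, and "after" takes the base priority while existing rules at or above the new priority are shifted up, mirroring the UPDATE above:

```python
def place_rule(rules, new_rule_id, relative_to, before):
    # rules: list of (rule_id, priority) within one priority class.
    base = dict(rules)[relative_to]
    new_priority = base + 1 if before else base
    # Mirror of "UPDATE push_rules SET priority = priority + 1
    #            WHERE ... AND priority >= ?"
    shifted = [
        (rid, prio + 1 if prio >= new_priority else prio)
        for rid, prio in rules
    ]
    return shifted + [(new_rule_id, new_priority)]

rules = [("a", 2), ("b", 1)]               # "a" executes before "b"
placed = dict(place_rule(rules, "c", "b", before=True))
assert placed == {"a": 3, "c": 2, "b": 1}  # now a, then c, then b
```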
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8ec706178a..7693ab9082 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -80,11 +80,11 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows)
@defer.inlineCallbacks
- def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
+ def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name,
- pushkey, pushkey_ts, lang, data):
+ pushkey, pushkey_ts, lang, data, profile_tag=""):
try:
- next_id = yield self._pushers_id_gen.get_next()
+ next_id = self._pushers_id_gen.get_next()
yield self._simple_upsert(
"pushers",
dict(
@@ -95,12 +95,12 @@ class PusherStore(SQLBaseStore):
dict(
access_token=access_token,
kind=kind,
- profile_tag=profile_tag,
app_display_name=app_display_name,
device_display_name=device_display_name,
ts=pushkey_ts,
lang=lang,
data=encode_canonical_json(data),
+ profile_tag=profile_tag,
),
insertion_values=dict(
id=next_id,
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 4202a6b3dc..dbc074d6b5 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
+ "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
)
@cached(num_args=2)
@@ -222,7 +222,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue(results)
def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_max_token(self)
+ return self._receipts_id_gen.get_max_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
@@ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear
)
- stream_id_manager = yield self._receipts_id_gen.get_next(self)
+ stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
@@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data
)
- max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+ max_persisted_id = self._stream_id_gen.get_max_token()
defer.returnValue((stream_id, max_persisted_id))
@@ -390,3 +390,19 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
+
+ def get_all_updated_receipts(self, last_id, current_id, limit):
+ def get_all_updated_receipts_txn(txn):
+ sql = (
+ "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_receipts", get_all_updated_receipts_txn
+ )
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 967c732bda..ad1157f979 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
- next_id = yield self._access_tokens_id_gen.get_next()
+ next_id = self._access_tokens_id_gen.get_next()
yield self._simple_insert(
"access_tokens",
@@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
- next_id = yield self._refresh_tokens_id_gen.get_next()
+ next_id = self._refresh_tokens_id_gen.get_next()
yield self._simple_insert(
"refresh_tokens",
@@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
now = int(self.clock.time())
- next_id = self._access_tokens_id_gen.get_next_txn(txn)
+ next_id = self._access_tokens_id_gen.get_next()
try:
if was_guest:
@@ -387,3 +387,47 @@ class RegistrationStore(SQLBaseStore):
"find_next_generated_user_id",
_find_next_generated_user_id
)))
+
+ @defer.inlineCallbacks
+ def get_3pid_guest_access_token(self, medium, address):
+ ret = yield self._simple_select_one(
+ "threepid_guest_access_tokens",
+ {
+ "medium": medium,
+ "address": address
+ },
+ ["guest_access_token"], True, 'get_3pid_guest_access_token'
+ )
+ if ret:
+ defer.returnValue(ret["guest_access_token"])
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def save_or_get_3pid_guest_access_token(
+ self, medium, address, access_token, inviter_user_id
+ ):
+ """
+ Gets the 3pid's guest access token if one exists, else saves access_token.
+
+ :param medium (str): Medium of the 3pid. Must be "email".
+ :param address (str): 3pid address.
+ :param access_token (str): The access token to persist if none is
+ already persisted.
+ :param inviter_user_id (str): User ID of the inviter.
+ :return (deferred str): Whichever access token is persisted at the end
+ of this function call.
+ """
+ def insert(txn):
+ txn.execute(
+ "INSERT INTO threepid_guest_access_tokens "
+ "(medium, address, guest_access_token, first_inviter) "
+ "VALUES (?, ?, ?, ?)",
+ (medium, address, access_token, inviter_user_id)
+ )
+
+ try:
+ yield self.runInteraction("save_3pid_guest_access_token", insert)
+ defer.returnValue(access_token)
+ except self.database_engine.module.IntegrityError:
+ ret = yield self.get_3pid_guest_access_token(medium, address)
+ defer.returnValue(ret)
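`save_or_get_3pid_guest_access_token` leans on the unique index (added in the delta below) rather than a lock: the INSERT either wins or raises IntegrityError, in which case we read whichever token won the race. A minimal sqlite demonstration of the idiom, standalone and not Synapse code:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE threepid_guest_access_tokens"
    " (medium TEXT, address TEXT, guest_access_token TEXT, first_inviter TEXT)"
)
conn.execute(
    "CREATE UNIQUE INDEX threepid_guest_access_tokens_index"
    " ON threepid_guest_access_tokens(medium, address)"
)

def save_or_get(medium, address, token, inviter):
    try:
        conn.execute(
            "INSERT INTO threepid_guest_access_tokens VALUES (?, ?, ?, ?)",
            (medium, address, token, inviter),
        )
        return token
    except sqlite3.IntegrityError:
        row = conn.execute(
            "SELECT guest_access_token FROM threepid_guest_access_tokens"
            " WHERE medium = ? AND address = ?",
            (medium, address),
        ).fetchone()
        return row[0]

assert save_or_get("email", "a@b.com", "tok1", "@x:ex") == "tok1"
assert save_or_get("email", "a@b.com", "tok2", "@y:ex") == "tok1"  # raced; kept first
```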
diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql
new file mode 100644
index 0000000000..606bbb037d
--- /dev/null
+++ b/synapse/storage/schema/delta/30/presence_stream.sql
@@ -0,0 +1,30 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE presence_stream(
+ stream_id BIGINT,
+ user_id TEXT,
+ state TEXT,
+ last_active_ts BIGINT,
+ last_federation_update_ts BIGINT,
+ last_user_sync_ts BIGINT,
+ status_msg TEXT,
+ currently_active BOOLEAN
+);
+
+CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id);
+CREATE INDEX presence_stream_user_id ON presence_stream(user_id);
+CREATE INDEX presence_stream_state ON presence_stream(state);
diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
new file mode 100644
index 0000000000..0dd2f1360c
--- /dev/null
+++ b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
@@ -0,0 +1,24 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Stores guest account access tokens generated for unbound 3pids.
+CREATE TABLE threepid_guest_access_tokens(
+ medium TEXT, -- The medium of the 3pid. Must be "email".
+ address TEXT, -- The 3pid address.
+ guest_access_token TEXT, -- The access token for a guest user for this 3pid.
+ first_inviter TEXT -- User ID of the first user to invite this 3pid to a room.
+);
+
+CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 372b540002..8ed8a21b0a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -83,7 +83,7 @@ class StateStore(SQLBaseStore):
if event.is_state():
state_events[(event.type, event.state_key)] = event
- state_group = self._state_groups_id_gen.get_next_txn(txn)
+ state_group = self._state_groups_id_gen.get_next()
self._simple_insert_txn(
txn,
table="state_groups",
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index c236dafafb..8908d5b5da 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -531,7 +531,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'):
- token = yield self._stream_id_gen.get_max_token(self)
+ token = yield self._stream_id_gen.get_max_token()
if direction != 'b':
defer.returnValue("s%d" % (token,))
else:
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index e1a9c0c261..a0e6b42b30 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
Returns:
A deferred int.
"""
- return self._account_data_id_gen.get_max_token(self)
+ return self._account_data_id_gen.get_max_token()
@cached()
def get_tags_for_user(self, user_id):
@@ -59,6 +59,59 @@ class TagsStore(SQLBaseStore):
return deferred
@defer.inlineCallbacks
+ def get_all_updated_tags(self, last_id, current_id, limit):
+ """Get all the client tags that have changed on the server
+ Args:
+ last_id(int): The position to fetch from.
+ current_id(int): The position to fetch up to.
+ limit(int): The maximum number of changed tags to fetch.
+ Returns:
+ A deferred list of tuples of stream_id int, user_id string,
+ room_id string, tag string and content string.
+ """
+ def get_all_updated_tags_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, room_id"
+ " FROM room_tags_revisions as r"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ tag_ids = yield self.runInteraction(
+ "get_all_updated_tags", get_all_updated_tags_txn
+ )
+
+ def get_tag_content(txn, tag_ids):
+ sql = (
+ "SELECT tag, content"
+ " FROM room_tags"
+ " WHERE user_id=? AND room_id=?"
+ )
+ results = []
+ for stream_id, user_id, room_id in tag_ids:
+ txn.execute(sql, (user_id, room_id))
+ tags = []
+ for tag, content in txn.fetchall():
+ tags.append(json.dumps(tag) + ":" + content)
+ tag_json = "{" + ",".join(tags) + "}"
+ results.append((stream_id, user_id, room_id, tag_json))
+
+ return results
+
+ batch_size = 50
+ results = []
+ for i in xrange(0, len(tag_ids), batch_size):
+ tags = yield self.runInteraction(
+ "get_all_updated_tag_content",
+ get_tag_content,
+ tag_ids[i:i + batch_size],
+ )
+ results.extend(tags)
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
"""Get all the tags for the rooms where the tags have changed since the
given version
@@ -142,12 +195,12 @@ class TagsStore(SQLBaseStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = yield self._account_data_id_gen.get_max_token(self)
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -164,12 +217,12 @@ class TagsStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = yield self._account_data_id_gen.get_max_token(self)
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
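The string assembly in `get_tag_content` works because `content` is stored in `room_tags` as an already-JSON-encoded value, so joining `json.dumps(tag) + ":" + content` pairs inside braces yields a valid JSON object. A quick illustrative check:

```python
import json

stored = [("m.favourite", json.dumps({"order": 0.5}))]  # (tag, content) rows
tags = [json.dumps(tag) + ":" + content for tag, content in stored]
tag_json = "{" + ",".join(tags) + "}"
assert json.loads(tag_json) == {"m.favourite": {"order": 0.5}}
```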
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 4475c451c1..d338dfcf0a 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore):
def _prep_send_transaction(self, txn, transaction_id, destination,
origin_server_ts):
- next_id = self._transaction_id_gen.get_next_txn(txn)
+ next_id = self._transaction_id_gen.get_next()
# First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time,
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 5c522f4ab9..efe3f68e6e 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -13,51 +13,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from collections import deque
import contextlib
import threading
class IdGenerator(object):
- def __init__(self, table, column, store):
+ def __init__(self, db_conn, table, column):
self.table = table
self.column = column
- self.store = store
self._lock = threading.Lock()
- self._next_id = None
+ cur = db_conn.cursor()
+ self._next_id = self._load_next_id(cur)
+ cur.close()
- @defer.inlineCallbacks
- def get_next(self):
- if self._next_id is None:
- yield self.store.runInteraction(
- "IdGenerator_%s" % (self.table,),
- self.get_next_txn,
- )
+ def _load_next_id(self, txn):
+ txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
+ val, = txn.fetchone()
+ return val + 1 if val else 1
+ def get_next(self):
with self._lock:
i = self._next_id
self._next_id += 1
- defer.returnValue(i)
-
- def get_next_txn(self, txn):
- with self._lock:
- if self._next_id:
- i = self._next_id
- self._next_id += 1
- return i
- else:
- txn.execute(
- "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
- )
-
- val, = txn.fetchone()
- cur = val or 0
- cur += 1
- self._next_id = cur + 1
-
- return cur
+ return i
class StreamIdGenerator(object):
@@ -69,7 +48,7 @@ class StreamIdGenerator(object):
persistence of events can complete out of order.
Usage:
- with stream_id_gen.get_next_txn(txn) as stream_id:
+ with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
def __init__(self, db_conn, table, column):
@@ -79,15 +58,21 @@ class StreamIdGenerator(object):
self._lock = threading.Lock()
cur = db_conn.cursor()
- self._current_max = self._get_or_compute_current_max(cur)
+ self._current_max = self._load_current_max(cur)
cur.close()
self._unfinished_ids = deque()
- def get_next(self, store):
+ def _load_current_max(self, txn):
+ txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
+ rows = txn.fetchall()
+ val, = rows[0]
+ return int(val) if val else 1
+
+ def get_next(self):
"""
Usage:
- with yield stream_id_gen.get_next as stream_id:
+ with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -106,10 +91,10 @@ class StreamIdGenerator(object):
return manager()
- def get_next_mult(self, store, n):
+ def get_next_mult(self, n):
"""
Usage:
- with yield stream_id_gen.get_next(store, n) as stream_ids:
+ with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -130,7 +115,7 @@ class StreamIdGenerator(object):
return manager()
- def get_max_token(self, store):
+ def get_max_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
@@ -139,13 +124,3 @@ class StreamIdGenerator(object):
return self._unfinished_ids[0] - 1
return self._current_max
-
- def _get_or_compute_current_max(self, txn):
- with self._lock:
- txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
- rows = txn.fetchall()
- val, = rows[0]
-
- self._current_max = int(val) if val else 1
-
- return self._current_max
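The reason `get_max_token` reports `_unfinished_ids[0] - 1` is so that readers never see a stream position past an id whose write is still in flight, even when later ids complete first. A simplified stand-in making that visible, same shape as the class above but trimmed down:

```python
from collections import deque
import contextlib
import threading

class SketchStreamIdGenerator(object):
    def __init__(self, current_max=0):
        self._lock = threading.Lock()
        self._current_max = current_max
        self._unfinished_ids = deque()

    def get_next(self):
        with self._lock:
            self._current_max += 1
            next_id = self._current_max
            self._unfinished_ids.append(next_id)

        @contextlib.contextmanager
        def manager():
            try:
                yield next_id
            finally:
                with self._lock:
                    self._unfinished_ids.remove(next_id)
        return manager()

    def get_max_token(self):
        with self._lock:
            if self._unfinished_ids:
                return self._unfinished_ids[0] - 1
            return self._current_max

gen = SketchStreamIdGenerator(5)
first = gen.get_next()
second = gen.get_next()
with second:
    pass                          # id 7 persisted before id 6
assert gen.get_max_token() == 5   # still gated on unfinished id 6
with first:
    pass
assert gen.get_max_token() == 7
```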