diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 349f96e24b..5e72985cda 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -49,6 +49,7 @@ from .tags import TagsStore
from .account_data import AccountDataStore
from .openid import OpenIdStore
from .client_ips import ClientIpStore
+from .user_directory import UserDirectoryStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine
@@ -86,6 +87,7 @@ class DataStore(RoomMemberStore, RoomStore,
ClientIpStore,
DeviceStore,
DeviceInboxStore,
+ UserDirectoryStore,
):
def __init__(self, db_conn, hs):
@@ -221,6 +223,18 @@ class DataStore(RoomMemberStore, RoomStore,
"DeviceListFederationStreamChangeCache", device_list_max,
)
+ curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+ db_conn, "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+            max_value=events_max,  # As we share the stream id with the events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
+
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f214b9d4c4..51730a88bf 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -438,6 +438,11 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
+ def _simple_insert_many(self, table, values, desc):
+ return self.runInteraction(
+ desc, self._simple_insert_many_txn, table, values
+ )
+
@staticmethod
def _simple_insert_many_txn(txn, table, values):
if not values:
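The new wrapper simply lifts the existing `_simple_insert_many_txn` helper into its own database interaction. For orientation, an illustrative (simplified) version of what that txn helper does: a single INSERT statement driven by `executemany`, assuming every row dict has the same keys:

```python
# Simplified sketch of a batched-insert txn helper; the real
# _simple_insert_many_txn is the static method shown in the context above.
def simple_insert_many_txn(txn, table, values):
    if not values:
        return
    keys = list(values[0])
    if any(set(row) != set(keys) for row in values):
        raise RuntimeError("All rows must have the same keys")
    sql = "INSERT INTO %s (%s) VALUES (%s)" % (
        table, ", ".join(keys), ", ".join("?" for _ in keys),
    )
    txn.executemany(sql, [tuple(row[k] for k in keys) for row in values])
```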
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 514570561f..532df736a5 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
import simplejson as json
from twisted.internet import defer
@@ -36,16 +37,31 @@ class ApplicationServiceStore(SQLBaseStore):
hs.config.app_service_config_files
)
+    # We precompile a regex constructed from all the regexes that the ASes
+    # have registered for exclusive users.
+ exclusive_user_regexes = [
+ regex.pattern
+ for service in self.services_cache
+ for regex in service.get_exlusive_user_regexes()
+ ]
+ if exclusive_user_regexes:
+ exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
+ self.exclusive_user_regex = re.compile(exclusive_user_regex)
+ else:
+            # We handle this case specially, since otherwise the constructed
+            # regex would always match.
+ self.exclusive_user_regex = None
+
def get_app_services(self):
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id):
- """Check if the user is one associated with an app service
+ """Check if the user is one associated with an app service (exclusively)
"""
- for service in self.services_cache:
- if service.is_interested_in_user(user_id):
- return True
- return False
+ if self.exclusive_user_regex:
+ return bool(self.exclusive_user_regex.match(user_id))
+ else:
+ return False
def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 747d2df622..014ab635b7 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -20,6 +20,8 @@ from twisted.internet import defer
from ._base import Cache
from . import background_updates
+import os
+
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -28,12 +30,15 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
- max_entries=5000,
+ max_entries=50000 * CACHE_SIZE_FACTOR,
)
super(ClientIpStore, self).__init__(hs)
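This brings the cache in line with the `SYNAPSE_CACHE_FACTOR` convention used elsewhere in Synapse: one environment variable scales the cache, and the base size is chosen so the default factor reproduces the old hard-coded limit. The arithmetic, as a quick check:

```python
import os

# With the default factor of 0.1, 50000 * 0.1 == 5000, i.e. the previous
# hard-coded max_entries, so unconfigured deployments are unaffected.
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
assert int(50000 * 0.1) == 5000
```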
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d9936c88bb..bb27fd1f70 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -368,7 +368,7 @@ class DeviceStore(SQLBaseStore):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
- FROM device_lists_outbound_pokes
+ FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
@@ -510,32 +510,43 @@ class DeviceStore(SQLBaseStore):
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
- # First we DELETE all rows such that only the latest row for each
- # (destination, user_id is left. We do this by selecting first and
- # deleting.
+        # We update device_lists_outbound_last_success with the successfully
+        # poked users. We do a join so that we can tell which users need to
+        # be inserted and which merely updated.
sql = """
- SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
- WHERE destination = ? AND stream_id <= ?
+ SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ FROM device_lists_outbound_pokes as o
+ LEFT JOIN device_lists_outbound_last_success as s
+ USING (destination, user_id)
+ WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
- HAVING count(*) > 1
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND user_id = ? AND stream_id < ?
+ UPDATE device_lists_outbound_last_success
+ SET stream_id = ?
+ WHERE destination = ? AND user_id = ?
"""
txn.executemany(
- sql, ((destination, row[0], row[1],) for row in rows)
+ sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
- # Mark everything that is left as sent
sql = """
- UPDATE device_lists_outbound_pokes SET sent = ?
+ INSERT INTO device_lists_outbound_last_success
+ (destination, user_id, stream_id) VALUES (?, ?, ?)
+ """
+ txn.executemany(
+ sql, ((destination, row[0], row[1],) for row in rows if not row[2])
+ )
+
+ # Delete all sent outbound pokes
+ sql = """
+ DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
- txn.execute(sql, (True, destination, stream_id,))
+ txn.execute(sql, (destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
@@ -670,6 +681,14 @@ class DeviceStore(SQLBaseStore):
)
)
+            # Since we've deleted unsent deltas, we need to remove the
+            # last-successful-sent entries so that the prev_ids are
+            # calculated correctly.
+ sql = """
+ DELETE FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
return self.runInteraction(
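The rewritten transaction is essentially a hand-rolled portable upsert (presumably because `INSERT ... ON CONFLICT` cannot be assumed across the supported database versions): the LEFT JOIN reports whether a `(destination, user_id)` row already exists, existing rows are UPDATEd, missing ones INSERTed, and only then are the sent pokes deleted. Reduced to its shape as a hypothetical helper:

```python
def mark_as_sent(txn, destination, stream_id, rows):
    # rows are (user_id, max_stream_id, has_last_success_row) tuples, as
    # produced by the SELECT ... LEFT JOIN above.
    txn.executemany(
        "UPDATE device_lists_outbound_last_success SET stream_id = ?"
        " WHERE destination = ? AND user_id = ?",
        [(s, destination, u) for u, s, exists in rows if exists],
    )
    txn.executemany(
        "INSERT INTO device_lists_outbound_last_success"
        " (destination, user_id, stream_id) VALUES (?, ?, ?)",
        [(destination, u, s) for u, s, exists in rows if not exists],
    )
    txn.execute(
        "DELETE FROM device_lists_outbound_pokes"
        " WHERE destination = ? AND stream_id <= ?",
        (destination, stream_id),
    )
```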
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index e00f31da2b..2cebb203c6 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -185,8 +185,8 @@ class EndToEndKeyStore(SQLBaseStore):
for algorithm, key_id, json_bytes in new_keys
],
)
- txn.call_after(
- self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
)
yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
@@ -237,24 +237,29 @@ class EndToEndKeyStore(SQLBaseStore):
)
for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id))
- txn.call_after(
- self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
)
return result
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
- @defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id):
- yield self._simple_delete(
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="delete_e2e_device_keys_by_device"
- )
- yield self._simple_delete(
- table="e2e_one_time_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="delete_e2e_one_time_keys_by_device"
+ def delete_e2e_keys_by_device_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+ )
+ return self.runInteraction(
+ "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- self.count_e2e_one_time_keys.invalidate((user_id, device_id,))
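The switch from `txn.call_after(cache.invalidate, ...)` to `_invalidate_cache_and_stream` matters on multi-process deployments: the former only pokes this process's in-memory cache after commit, while the latter also records the invalidation so replication can push it out to workers. A simplified sketch of the distinction (the stream-writing machinery is elided behind a hypothetical `write_invalidation_row`):

```python
def invalidate_local_only(txn, cache, key):
    # Old behaviour: invalidate this process's cache once the txn commits.
    txn.call_after(cache.invalidate, key)

def invalidate_and_stream(txn, cache, key, write_invalidation_row):
    # New behaviour: the same local invalidation, plus a row in the cache
    # invalidation stream that worker processes consume.
    txn.call_after(cache.invalidate, key)
    write_invalidation_row(txn, key)
```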
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 519059c306..e8133de2fa 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -37,25 +37,55 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively.
"""
+ EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
def __init__(self, hs):
super(EventFederationStore, self).__init__(hs)
+ self.register_background_update_handler(
+ self.EVENT_AUTH_STATE_ONLY,
+ self._background_delete_non_state_event_auth,
+ )
+
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
- def get_auth_chain(self, event_ids):
- return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
+ def get_auth_chain(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
+
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
+
+ Returns:
+ list of events
+ """
+ return self.get_auth_chain_ids(
+ event_ids, include_given=include_given,
+ ).addCallback(self._get_events)
+
+ def get_auth_chain_ids(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
+
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
- def get_auth_chain_ids(self, event_ids):
+ Returns:
+ list of event_ids
+ """
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
- event_ids
+ event_ids, include_given
)
- def _get_auth_chain_ids_txn(self, txn, event_ids):
- results = set()
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ if include_given:
+ results = set(event_ids)
+ else:
+ results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
@@ -504,3 +534,52 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+
+ @defer.inlineCallbacks
+ def _background_delete_non_state_event_auth(self, progress, batch_size):
+ def delete_event_auth(txn):
+ target_min_stream_id = progress.get("target_min_stream_id_inclusive")
+ max_stream_id = progress.get("max_stream_id_exclusive")
+
+ if not target_min_stream_id or not max_stream_id:
+ txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ target_min_stream_id = rows[0][0]
+
+ txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ max_stream_id = rows[0][0]
+
+ min_stream_id = max_stream_id - batch_size
+
+ sql = """
+ DELETE FROM event_auth
+ WHERE event_id IN (
+ SELECT event_id FROM events
+ LEFT JOIN state_events USING (room_id, event_id)
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND state_key IS null
+ )
+ """
+
+ txn.execute(sql, (min_stream_id, max_stream_id,))
+
+ new_progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_AUTH_STATE_ONLY, new_progress
+ )
+
+ return min_stream_id >= target_min_stream_id
+
+ result = yield self.runInteraction(
+ self.EVENT_AUTH_STATE_ONLY, delete_event_auth
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+
+ defer.returnValue(batch_size)
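This handler follows the standard resumable background-update shape: snapshot the full range once, walk it downwards in `batch_size` chunks, persist progress after each round, and let the caller end the update once the bottom is reached. The skeleton, abstracted with hypothetical helpers:

```python
def run_one_round(progress, batch_size, get_bounds, process, save_progress):
    target_min = progress.get("target_min_stream_id_inclusive")
    max_id = progress.get("max_stream_id_exclusive")
    if target_min is None or max_id is None:
        # First round: snapshot the range that needs covering.
        target_min, max_id = get_bounds()

    min_id = max_id - batch_size
    process(min_id, max_id)  # handle [min_id, max_id) this round

    save_progress({
        "target_min_stream_id_inclusive": target_min,
        "max_stream_id_exclusive": min_id,
    })
    # Truthy while there is still range left to process.
    return min_id >= target_min
```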
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 73283eb4c7..c80d181fc7 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -648,9 +648,10 @@ class EventsStore(SQLBaseStore):
list of the event ids which are the forward extremities.
"""
- self._update_current_state_txn(txn, current_state_for_room)
-
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+
+ self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -713,7 +714,7 @@ class EventsStore(SQLBaseStore):
backfilled=backfilled,
)
- def _update_current_state_txn(self, txn, state_delta_by_room):
+ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert, _ = current_state_tuple
txn.executemany(
@@ -735,6 +736,29 @@ class EventsStore(SQLBaseStore):
],
)
+ state_deltas = {key: None for key in to_delete}
+ state_deltas.update(to_insert)
+
+ self._simple_insert_many_txn(
+ txn,
+ table="current_state_delta_stream",
+ values=[
+ {
+ "stream_id": max_stream_order,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": ev_id,
+ "prev_event_id": to_delete.get(key, None),
+ }
+ for key, ev_id in state_deltas.iteritems()
+ ]
+ )
+
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ room_id, max_stream_order,
+ )
+
# Invalidate the various caches
# Figure out the changes of membership to invalidate the
@@ -743,11 +767,7 @@ class EventsStore(SQLBaseStore):
# and which we have added, then we invlidate the caches for all
# those users.
members_changed = set(
- state_key for ev_type, state_key in to_delete.iterkeys()
- if ev_type == EventTypes.Member
- )
- members_changed.update(
- state_key for ev_type, state_key in to_insert.iterkeys()
+ state_key for ev_type, state_key in state_deltas
if ev_type == EventTypes.Member
)
@@ -1120,6 +1140,7 @@ class EventsStore(SQLBaseStore):
}
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events
+ if event.is_state()
],
)
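The `state_deltas` dict built above folds the deletions and insertions into a single mapping whose values line up with the new table's nullable columns. A worked example with made-up event IDs:

```python
to_delete = {("m.room.member", "@alice:hs.example"): "$old_member_event"}
to_insert = {
    ("m.room.member", "@alice:hs.example"): "$new_member_event",
    ("m.room.topic", ""): "$topic_event",
}

state_deltas = {key: None for key in to_delete}
state_deltas.update(to_insert)

rows = [
    {"type": k[0], "state_key": k[1], "event_id": ev_id,
     "prev_event_id": to_delete.get(k)}
    for k, ev_id in state_deltas.items()
]
# Replaced key -> event_id=$new_member_event, prev_event_id=$old_member_event
# Added key    -> event_id=$topic_event, prev_event_id=None
# A key only in to_delete would yield event_id=None (a pure removal).
```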
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6e623843d5..eaba699e29 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 41
+SCHEMA_VERSION = 42
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 0a819d32c5..8758b1c0c7 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -49,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore):
- @cachedInlineCallbacks()
+ @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
@@ -73,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules)
- @cachedInlineCallbacks()
+ @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index efb90c3c91..f42b8014c7 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -45,7 +45,9 @@ class ReceiptsStore(SQLBaseStore):
return
# Returns an ObservableDeferred
- res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
+ res = self.get_users_with_read_receipts_in_room.cache.get(
+ room_id, None, update_metrics=False,
+ )
if res:
if isinstance(res, defer.Deferred) and res.called:
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0829ae5bee..8656455f6e 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from collections import namedtuple
from ._base import SQLBaseStore
+from synapse.util.async import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii
@@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore):
context=context,
)
- def get_joined_users_from_state(self, room_id, state_group, state_ids):
+ def get_joined_users_from_state(self, room_id, state_entry):
+ state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_users_from_context(
- room_id, state_group, state_ids,
+ room_id, state_group, state_entry.state, context=state_entry,
)
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
@@ -534,7 +536,8 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(False)
- def get_joined_hosts(self, room_id, state_group, state_ids):
+ def get_joined_hosts(self, room_id, state_entry):
+ state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -543,33 +546,20 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_hosts(
- room_id, state_group, state_ids
+ room_id, state_group, state_entry.state, state_entry=state_entry,
)
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
- def _get_joined_hosts(self, room_id, state_group, current_state_ids):
+ def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
- joined_hosts = set()
- for etype, state_key in current_state_ids:
- if etype == EventTypes.Member:
- try:
- host = get_domain_from_id(state_key)
- except:
- logger.warn("state_key not user_id: %s", state_key)
- continue
-
- if host in joined_hosts:
- continue
-
- event_id = current_state_ids[(etype, state_key)]
- event = yield self.get_event(event_id, allow_none=True)
- if event and event.content["membership"] == Membership.JOIN:
- joined_hosts.add(intern_string(host))
+ cache = self._get_joined_hosts_cache(room_id)
+ joined_hosts = yield cache.get_destinations(state_entry)
defer.returnValue(joined_hosts)
@@ -647,3 +637,75 @@ class RoomMemberStore(SQLBaseStore):
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
defer.returnValue(result)
+
+ @cached(max_entries=10000, iterable=True)
+ def _get_joined_hosts_cache(self, room_id):
+ return _JoinedHostsCache(self, room_id)
+
+
+class _JoinedHostsCache(object):
+ """Cache for joined hosts in a room that is optimised to handle updates
+ via state deltas.
+ """
+
+ def __init__(self, store, room_id):
+ self.store = store
+ self.room_id = room_id
+
+ self.hosts_to_joined_users = {}
+
+ self.state_group = object()
+
+ self.linearizer = Linearizer("_JoinedHostsCache")
+
+ self._len = 0
+
+ @defer.inlineCallbacks
+ def get_destinations(self, state_entry):
+ """Get set of destinations for a state entry
+
+ Args:
+ state_entry(synapse.state._StateCacheEntry)
+ """
+ if state_entry.state_group == self.state_group:
+ defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+ with (yield self.linearizer.queue(())):
+ if state_entry.state_group == self.state_group:
+ pass
+ elif state_entry.prev_group == self.state_group:
+ for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+ if typ != EventTypes.Member:
+ continue
+
+ host = intern_string(get_domain_from_id(state_key))
+ user_id = state_key
+ known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+ event = yield self.store.get_event(event_id)
+ if event.membership == Membership.JOIN:
+ known_joins.add(user_id)
+ else:
+ known_joins.discard(user_id)
+
+ if not known_joins:
+ self.hosts_to_joined_users.pop(host, None)
+ else:
+ joined_users = yield self.store.get_joined_users_from_state(
+ self.room_id, state_entry,
+ )
+
+ self.hosts_to_joined_users = {}
+ for user_id in joined_users:
+ host = intern_string(get_domain_from_id(user_id))
+ self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+ if state_entry.state_group:
+ self.state_group = state_entry.state_group
+ else:
+ self.state_group = object()
+ self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+ defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+ def __len__(self):
+ return self._len
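The interesting part of `_JoinedHostsCache` is the `prev_group` branch: when the new state group descends directly from the cached one, only the membership deltas are applied instead of rebuilding the whole host map. A self-contained toy of that bookkeeping (with a simplified stand-in for `get_domain_from_id`):

```python
def get_domain(user_id):
    return user_id.split(":", 1)[1]  # simplified get_domain_from_id

def apply_membership_delta(hosts_to_joined_users, user_id, joined):
    host = get_domain(user_id)
    known = hosts_to_joined_users.setdefault(host, set())
    if joined:
        known.add(user_id)
    else:
        known.discard(user_id)
    if not known:
        # No known joined users left on that host: drop the host entirely.
        hosts_to_joined_users.pop(host, None)

hosts = {}
apply_membership_delta(hosts, "@alice:remote.example", True)
assert hosts == {"remote.example": {"@alice:remote.example"}}
apply_membership_delta(hosts, "@alice:remote.example", False)
assert hosts == {}
```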
diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql
new file mode 100644
index 0000000000..d28851aff8
--- /dev/null
+++ b/synapse/storage/schema/delta/42/current_state_delta.sql
@@ -0,0 +1,26 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE current_state_delta_stream (
+ stream_id BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT, -- Is null if the key was removed
+ prev_event_id TEXT -- Is null if the key was added
+);
+
+CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id);
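The `stream_id` index supports the consumer pattern used by `get_current_state_deltas` in `user_directory.py` below: batches are bounded by row count but never split a single `stream_id`, since one event can produce several delta rows. A sketch of that bound calculation (assuming the usual iterable txn cursor):

```python
def next_batch_bound(txn, prev_stream_id, max_rows=100):
    txn.execute(
        "SELECT stream_id, count(*) FROM current_state_delta_stream"
        " WHERE stream_id > ? GROUP BY stream_id ORDER BY stream_id ASC"
        " LIMIT 100",
        (prev_stream_id,),
    )
    total = 0
    bound = prev_stream_id
    for stream_id, count in txn:
        total += count
        bound = stream_id  # always include all rows for a stream_id
        if total > max_rows:
            break
    return bound
```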
diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql
new file mode 100644
index 0000000000..9ab8c14fa3
--- /dev/null
+++ b/synapse/storage/schema/delta/42/device_list_last_id.sql
@@ -0,0 +1,33 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Table of last stream_id that we sent to destination for user_id. This is
+-- used to fill out the `prev_id` fields of outbound device list updates.
+CREATE TABLE device_lists_outbound_last_success (
+ destination TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+INSERT INTO device_lists_outbound_last_success
+ SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_pokes
+ WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values
+ GROUP BY destination, user_id;
+
+CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success(
+ destination, user_id, stream_id
+);
diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql
new file mode 100644
index 0000000000..b8821ac759
--- /dev/null
+++ b/synapse/storage/schema/delta/42/event_auth_state_only.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('event_auth_state_only', '{}');
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py
new file mode 100644
index 0000000000..ea6a18196d
--- /dev/null
+++ b/synapse/storage/schema/delta/42/user_dir.py
@@ -0,0 +1,84 @@
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+
+logger = logging.getLogger(__name__)
+
+
+BOTH_TABLES = """
+CREATE TABLE user_directory_stream_pos (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_id BIGINT,
+ CHECK (Lock='X')
+);
+
+INSERT INTO user_directory_stream_pos (stream_id) VALUES (null);
+
+CREATE TABLE user_directory (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL, -- A room_id that we know the user is joined to
+ display_name TEXT,
+ avatar_url TEXT
+);
+
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
+
+CREATE TABLE users_in_public_room (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL -- A room_id that we know is public
+);
+
+CREATE INDEX users_in_public_room_room_idx ON users_in_public_room(room_id);
+CREATE UNIQUE INDEX users_in_public_room_user_idx ON users_in_public_room(user_id);
+"""
+
+
+POSTGRES_TABLE = """
+CREATE TABLE user_directory_search (
+ user_id TEXT NOT NULL,
+ vector tsvector
+);
+
+CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector);
+CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id);
+"""
+
+
+SQLITE_TABLE = """
+CREATE VIRTUAL TABLE user_directory_search
+ USING fts4 ( user_id, value );
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ for statement in get_statements(BOTH_TABLES.splitlines()):
+ cur.execute(statement)
+
+ if isinstance(database_engine, PostgresEngine):
+ for statement in get_statements(POSTGRES_TABLE.splitlines()):
+ cur.execute(statement)
+ elif isinstance(database_engine, Sqlite3Engine):
+ for statement in get_statements(SQLITE_TABLE.splitlines()):
+ cur.execute(statement)
+ else:
+ raise Exception("Unrecognized database engine")
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 85acf2ad1e..c3eecbe824 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -98,6 +98,45 @@ class StateStore(SQLBaseStore):
_get_current_state_ids_txn,
)
+ def get_state_group_delta(self, state_group):
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ (prev_group, delta_ids), where both may be None.
+ """
+ def _get_state_group_delta_txn(txn):
+ prev_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ if not prev_group:
+ return None, None
+
+ delta_ids = self._simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcols=("type", "state_key", "event_id",)
+ )
+
+ return prev_group, {
+ (row["type"], row["state_key"]): row["event_id"]
+ for row in delta_ids
+ }
+ return self.runInteraction(
+ "get_state_group_delta",
+ _get_state_group_delta_txn,
+ )
+
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
if not event_ids:
@@ -563,20 +602,22 @@ class StateStore(SQLBaseStore):
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
+
for typ, state_key in types:
+ key = (typ, state_key)
if state_key is None:
type_to_key[typ] = None
- missing_types.add((typ, state_key))
+ missing_types.add(key)
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
- if (typ, state_key) not in state_dict_ids:
- missing_types.add((typ, state_key))
+ if key not in state_dict_ids and key not in known_absent:
+ missing_types.add(key)
sentinel = object()
@@ -590,7 +631,7 @@ class StateStore(SQLBaseStore):
return True
return False
- got_all = not (missing_types or types is None)
+ got_all = is_all or not missing_types
return {
k: v for k, v in state_dict_ids.iteritems()
@@ -607,7 +648,7 @@ class StateStore(SQLBaseStore):
Args:
group: The state group to lookup
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, _, state_dict_ids = self._state_group_cache.get(group)
return state_dict_ids, is_all
@@ -624,7 +665,7 @@ class StateStore(SQLBaseStore):
missing_groups = []
if types is not None:
for group in set(groups):
- state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
+ state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types
)
results[group] = state_dict_ids
@@ -653,19 +694,7 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched
# from the database.
for group, group_state_dict in group_to_state_dict.iteritems():
- if types:
- # We delibrately put key -> None mappings into the cache to
- # cache absence of the key, on the assumption that if we've
- # explicitly asked for some types then we will probably ask
- # for them again.
- state_dict = {
- (intern_string(etype), intern_string(state_key)): None
- for (etype, state_key) in types
- }
- state_dict.update(results[group])
- results[group] = state_dict
- else:
- state_dict = results[group]
+ state_dict = results[group]
state_dict.update(
((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
@@ -677,17 +706,9 @@ class StateStore(SQLBaseStore):
key=group,
value=state_dict,
full=(types is None),
+ known_absent=types,
)
- # Remove all the entries with None values. The None values were just
- # used for bookkeeping in the cache.
- for group, state_dict in results.iteritems():
- results[group] = {
- key: event_id
- for key, event_id in state_dict.iteritems()
- if event_id
- }
-
defer.returnValue(results)
def get_next_state_group(self):
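The third tuple element threaded through these call sites replaces the old trick of storing key -> None entries: the cache now tracks separately which keys were queried and found absent. A toy model of the resulting lookup logic (not the real DictionaryCache):

```python
class CachedGroupState(object):
    def __init__(self, state_dict, known_absent, is_all):
        self.state_dict = state_dict           # (type, state_key) -> event_id
        self.known_absent = set(known_absent)  # queried before, not present
        self.is_all = is_all                   # True if state_dict is complete

    def lookup(self, key):
        if key in self.state_dict:
            return "hit", self.state_dict[key]
        if self.is_all or key in self.known_absent:
            return "known-absent", None  # no database query required
        return "miss", None              # must fetch from the database
```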
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
new file mode 100644
index 0000000000..6a4bf63f0d
--- /dev/null
+++ b/synapse/storage/user_directory.py
@@ -0,0 +1,461 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import get_domain_from_id, get_localpart_from_id
+
+import re
+
+
+class UserDirectoryStore(SQLBaseStore):
+
+ @cachedInlineCallbacks(cache_context=True)
+ def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
+ """Check if the room is either world_readable or publically joinable
+ """
+ current_state_ids = yield self.get_current_state_ids(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+
+ join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
+ if join_rules_id:
+ join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
+ if join_rule_ev:
+ if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
+ defer.returnValue(True)
+
+ hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist_vis_id:
+ hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
+ if hist_vis_ev:
+ if hist_vis_ev.content.get("history_visibility") == "world_readable":
+ defer.returnValue(True)
+
+ defer.returnValue(False)
+
+ @defer.inlineCallbacks
+ def add_users_to_public_room(self, room_id, user_ids):
+ """Add user to the list of users in public rooms
+
+ Args:
+            room_id (str): A room_id that all the given users are in, and
+                which is world_readable or publicly joinable
+ user_ids (list(str)): Users to add
+ """
+ yield self._simple_insert_many(
+ table="users_in_pubic_room",
+ values=[
+ {
+ "user_id": user_id,
+ "room_id": room_id,
+ }
+ for user_id in user_ids
+ ],
+ desc="add_users_to_public_room"
+ )
+ for user_id in user_ids:
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def add_profiles_to_user_dir(self, room_id, users_with_profile):
+ """Add profiles to the user directory
+
+ Args:
+ room_id (str): A room_id that all users are joined to
+            users_with_profile (dict): Users to add to the directory, as a
+                mapping of user_id -> ProfileInfo
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+            # We weight the localpart most highly, then the display name and
+            # finally the server name
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ args = (
+ (
+ user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
+ profile.display_name,
+ )
+ for user_id, profile in users_with_profile.iteritems()
+ )
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ INSERT INTO user_directory_search(user_id, value)
+ VALUES (?,?)
+ """
+ args = (
+ (
+ user_id,
+ "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
+ )
+ for user_id, p in users_with_profile.iteritems()
+ )
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ def _add_profiles_to_user_dir_txn(txn):
+ txn.executemany(sql, args)
+ self._simple_insert_many_txn(
+ txn,
+ table="user_directory",
+ values=[
+ {
+ "user_id": user_id,
+ "room_id": room_id,
+ "display_name": profile.display_name,
+ "avatar_url": profile.avatar_url,
+ }
+ for user_id, profile in users_with_profile.iteritems()
+ ]
+ )
+ for user_id in users_with_profile:
+ txn.call_after(
+ self.get_user_in_directory.invalidate, (user_id,)
+ )
+
+ return self.runInteraction(
+ "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
+ )
+
+ @defer.inlineCallbacks
+ def update_user_in_user_dir(self, user_id, room_id):
+ yield self._simple_update_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ updatevalues={"room_id": room_id},
+ desc="update_user_in_user_dir",
+ )
+ self.get_user_in_directory.invalidate((user_id,))
+
+ def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+ def _update_profile_in_user_dir_txn(txn):
+ self._simple_update_one_txn(
+ txn,
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ updatevalues={"display_name": display_name, "avatar_url": avatar_url},
+ )
+
+ if isinstance(self.database_engine, PostgresEngine):
+                # We weight the localpart most highly, then the display name
+                # and finally the server name
+ sql = """
+ UPDATE user_directory_search
+ SET vector = setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ WHERE user_id = ?
+ """
+ args = (
+ get_localpart_from_id(user_id), get_domain_from_id(user_id),
+ display_name,
+ user_id,
+ )
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ UPDATE user_directory_search
+ set value = ?
+ WHERE user_id = ?
+ """
+ args = (
+ "%s %s" % (user_id, display_name,) if display_name else user_id,
+ user_id,
+ )
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ txn.execute(sql, args)
+
+ txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+ return self.runInteraction(
+ "update_profile_in_user_dir", _update_profile_in_user_dir_txn
+ )
+
+ @defer.inlineCallbacks
+ def update_user_in_public_user_list(self, user_id, room_id):
+ yield self._simple_update_one(
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ updatevalues={"room_id": room_id},
+ desc="update_user_in_public_user_list",
+ )
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def remove_from_user_dir(self, user_id):
+ def _remove_from_user_dir_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="user_directory_search",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ )
+ txn.call_after(
+ self.get_user_in_directory.invalidate, (user_id,)
+ )
+ txn.call_after(
+ self.get_user_in_public_room.invalidate, (user_id,)
+ )
+ return self.runInteraction(
+ "remove_from_user_dir", _remove_from_user_dir_txn,
+ )
+
+ @defer.inlineCallbacks
+ def remove_from_user_in_public_room(self, user_id):
+ yield self._simple_delete(
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ desc="remove_from_user_in_public_room",
+ )
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def get_users_in_public_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory becuase they're
+ in the given room_id
+ """
+ return self._simple_select_onecol(
+ table="users_in_pubic_room",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_public_due_to_room",
+ )
+
+ def get_users_in_dir_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory becuase they're
+ in the given room_id
+ """
+ return self._simple_select_onecol(
+ table="user_directory",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ def get_all_rooms(self):
+ """Get all room_ids we've ever known about
+ """
+ return self._simple_select_onecol(
+ table="current_state_events",
+ keyvalues={},
+ retcol="DISTINCT room_id",
+ desc="get_all_rooms",
+ )
+
+ def delete_all_from_user_dir(self):
+ """Delete the entire user directory
+ """
+ def _delete_all_from_user_dir_txn(txn):
+ txn.execute("DELETE FROM user_directory")
+ txn.execute("DELETE FROM user_directory_search")
+ txn.execute("DELETE FROM users_in_pubic_room")
+ txn.call_after(self.get_user_in_directory.invalidate_all)
+ txn.call_after(self.get_user_in_public_room.invalidate_all)
+ return self.runInteraction(
+ "delete_all_from_user_dir", _delete_all_from_user_dir_txn
+ )
+
+ @cached()
+ def get_user_in_directory(self, user_id):
+ return self._simple_select_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ retcols=("room_id", "display_name", "avatar_url",),
+ allow_none=True,
+ desc="get_user_in_directory",
+ )
+
+ @cached()
+ def get_user_in_public_room(self, user_id):
+ return self._simple_select_one(
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ retcols=("room_id",),
+ allow_none=True,
+ desc="get_user_in_public_room",
+ )
+
+ def get_user_directory_stream_pos(self):
+ return self._simple_select_one_onecol(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ retcol="stream_id",
+ desc="get_user_directory_stream_pos",
+ )
+
+ def update_user_directory_stream_pos(self, stream_id):
+ return self._simple_update_one(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_user_directory_stream_pos",
+ )
+
+ def get_current_state_deltas(self, prev_stream_id):
+ prev_stream_id = int(prev_stream_id)
+ if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+ return []
+
+ def get_current_state_deltas_txn(txn):
+ # First we calculate the max stream id that will give us less than
+ # N results.
+            # We arbitrarily limit to 100 stream_id entries to ensure we
+            # don't select too many.
+ sql = """
+ SELECT stream_id, count(*)
+ FROM current_state_delta_stream
+ WHERE stream_id > ?
+ GROUP BY stream_id
+ ORDER BY stream_id ASC
+ LIMIT 100
+ """
+ txn.execute(sql, (prev_stream_id,))
+
+ total = 0
+ max_stream_id = prev_stream_id
+ for max_stream_id, count in txn:
+ total += count
+ if total > 100:
+                    # We arbitrarily limit to 100 entries to ensure we
+                    # don't select too many.
+ break
+
+ # Now actually get the deltas
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ txn.execute(sql, (prev_stream_id, max_stream_id,))
+ return self.cursor_to_dict(txn)
+
+ return self.runInteraction(
+ "get_current_state_deltas", get_current_state_deltas_txn
+ )
+
+ def get_max_stream_id_in_current_state_deltas(self):
+ return self._simple_select_one_onecol(
+ table="current_state_delta_stream",
+ keyvalues={},
+ retcol="COALESCE(MAX(stream_id), -1)",
+ desc="get_max_stream_id_in_current_state_deltas",
+ )
+
+ @defer.inlineCallbacks
+ def search_user_dir(self, search_term, limit):
+ """Searches for users in directory
+
+ Returns:
+ dict of the form::
+
+ {
+ "limited": <bool>, # whether there were more results or not
+ "results": [ # Ordered by best match first
+ {
+ "user_id": <user_id>,
+ "display_name": <display_name>,
+ "avatar_url": <avatar_url>
+ }
+ ]
+ }
+ """
+
+ search_query = _parse_query(self.database_engine, search_term)
+
+ if isinstance(self.database_engine, PostgresEngine):
+            # We order by rank and then by whether they have profile info
+ sql = """
+ SELECT user_id, display_name, avatar_url
+ FROM user_directory_search
+ INNER JOIN user_directory USING (user_id)
+            INNER JOIN users_in_public_room USING (user_id)
+ WHERE vector @@ to_tsquery('english', ?)
+ ORDER BY
+ ts_rank_cd(vector, to_tsquery('english', ?), 1) DESC,
+ display_name IS NULL,
+ avatar_url IS NULL
+ LIMIT ?
+ """
+ args = (search_query, search_query, limit + 1,)
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ SELECT user_id, display_name, avatar_url
+ FROM user_directory_search
+ INNER JOIN user_directory USING (user_id)
+            INNER JOIN users_in_public_room USING (user_id)
+ WHERE value MATCH ?
+ ORDER BY
+ rank(matchinfo(user_directory_search)) DESC,
+ display_name IS NULL,
+ avatar_url IS NULL
+ LIMIT ?
+ """
+ args = (search_query, limit + 1)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ results = yield self._execute(
+ "search_user_dir", self.cursor_to_dict, sql, *args
+ )
+
+ limited = len(results) > limit
+
+ defer.returnValue({
+ "limited": limited,
+ "results": results,
+ })
+
+
+def _parse_query(database_engine, search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to the database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+
+    We specifically add both a prefix and a non-prefix matching term so that
+ exact matches get ranked higher.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+ if isinstance(database_engine, PostgresEngine):
+ return " & ".join("(%s:* | %s)" % (result, result,) for result in results)
+ elif isinstance(database_engine, Sqlite3Engine):
+ return " & ".join("(%s* | %s)" % (result, result,) for result in results)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
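For a concrete sense of the output (PostgreSQL flavour, where `&` and `|` are the tsquery operators):

```python
import re

# _parse_query(postgres_engine, "alice smith") would produce:
#
#     "(alice:* | alice) & (smith:* | smith)"
#
# i.e. each word matches as a prefix or exactly; the non-prefix alternative
# is what lets exact matches rank higher under ts_rank_cd.
assert re.findall(r"([\w\-]+)", u"alice smith!", re.UNICODE) == ["alice", "smith"]
```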