diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 2970df138b..b92472df33 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -49,6 +49,7 @@ from .tags import TagsStore
from .account_data import AccountDataStore
from .openid import OpenIdStore
from .client_ips import ClientIpStore
+from .user_directory import UserDirectoryStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from .engines import PostgresEngine
@@ -86,6 +87,7 @@ class DataStore(RoomMemberStore, RoomStore,
ClientIpStore,
DeviceStore,
DeviceInboxStore,
+ UserDirectoryStore,
):
def __init__(self, db_conn, hs):
@@ -221,11 +223,24 @@ class DataStore(RoomMemberStore, RoomStore,
"DeviceListFederationStreamChangeCache", device_list_max,
)
+ curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+ db_conn, "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=events_max, # As we share the stream id with events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
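+ # This cache lets callers cheaply ask "has the current state of this room
+ # changed since stream id X?" without querying current_state_delta_stream.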
+
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
- after_callbacks=[]
+ after_callbacks=[],
+ final_callbacks=[],
)
self._find_stream_orderings_for_times_txn(cur)
cur.close()
@@ -289,16 +304,6 @@ class DataStore(RoomMemberStore, RoomStore,
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
- def get_user_ip_and_agents(self, user):
- return self._simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user.to_string()},
- retcols=[
- "access_token", "ip", "user_agent", "last_seen"
- ],
- desc="get_user_ip_and_agents",
- )
-
def get_users(self):
"""Function to reterive a list of users in users table.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 58b73af7d2..6f54036d67 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -16,6 +16,7 @@ import logging
from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
from synapse.storage.engines import PostgresEngine
@@ -27,10 +28,6 @@ from twisted.internet import defer
import sys
import time
import threading
-import os
-
-
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
logger = logging.getLogger(__name__)
@@ -52,13 +49,17 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
+ __slots__ = [
+ "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+ ]
- def __init__(self, txn, name, database_engine, after_callbacks):
+ def __init__(self, txn, name, database_engine, after_callbacks,
+ final_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
+ object.__setattr__(self, "final_callbacks", final_callbacks)
def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the
@@ -67,6 +68,9 @@ class LoggingTransaction(object):
"""
self.after_callbacks.append((callback, args, kwargs))
+ def call_finally(self, callback, *args, **kwargs):
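+ """Call the given callback on the main twisted thread after the
+ transaction has finished, whether or not it succeeded (unlike
+ call_after, whose callbacks only run if the transaction commits).
+ """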
+ self.final_callbacks.append((callback, args, kwargs))
+
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -217,8 +221,8 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, logging_context,
- func, *args, **kwargs):
+ def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
+ logging_context, func, *args, **kwargs):
start = time.time() * 1000
txn_id = self._TXN_ID
@@ -237,7 +241,8 @@ class SQLBaseStore(object):
try:
txn = conn.cursor()
txn = LoggingTransaction(
- txn, name, self.database_engine, after_callbacks
+ txn, name, self.database_engine, after_callbacks,
+ final_callbacks,
)
r = func(txn, *args, **kwargs)
conn.commit()
@@ -298,6 +303,7 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
after_callbacks = []
+ final_callbacks = []
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
@@ -309,7 +315,7 @@ class SQLBaseStore(object):
current_context.copy_to(context)
return self._new_transaction(
- conn, desc, after_callbacks, current_context,
+ conn, desc, after_callbacks, final_callbacks, current_context,
func, *args, **kwargs
)
@@ -318,9 +324,13 @@ class SQLBaseStore(object):
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
- finally:
+
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
+ finally:
+ for after_callback, after_args, after_kwargs in final_callbacks:
+ after_callback(*after_args, **after_kwargs)
+
defer.returnValue(result)
@defer.inlineCallbacks
@@ -425,6 +435,11 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
+ def _simple_insert_many(self, table, values, desc):
+ return self.runInteraction(
+ desc, self._simple_insert_many_txn, table, values
+ )
+
@staticmethod
def _simple_insert_many_txn(txn, table, values):
if not values:
@@ -936,7 +951,7 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
- txn.call_after(ctx.__exit__, None, None, None)
+ txn.call_finally(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 514570561f..c63935cb07 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
import simplejson as json
from twisted.internet import defer
@@ -26,6 +27,25 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
+def _make_exclusive_regex(services_cache):
+ # We precompile a regex constructed from all the regexes that the ASes
+ # have registered for exclusive users.
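+ # For example (illustrative only), exclusive patterns "@irc_.*:example\.org"
+ # and "@bot_.*:example\.org" would be combined into the single regex
+ # "(@irc_.*:example\.org)|(@bot_.*:example\.org)".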
+ exclusive_user_regexes = [
+ regex.pattern
+ for service in services_cache
+ for regex in service.get_exlusive_user_regexes()
+ ]
+ if exclusive_user_regexes:
+ exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
+ exclusive_user_regex = re.compile(exclusive_user_regex)
+ else:
+ # We handle this case specially, otherwise the constructed regex
+ # will always match
+ exclusive_user_regex = None
+
+ return exclusive_user_regex
+
+
class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
@@ -35,17 +55,18 @@ class ApplicationServiceStore(SQLBaseStore):
hs.hostname,
hs.config.app_service_config_files
)
+ self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
def get_app_services(self):
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id):
- """Check if the user is one associated with an app service
+ """Check if the user is one associated with an app service (exclusively)
"""
- for service in self.services_cache:
- if service.is_interested_in_user(user_id):
- return True
- return False
+ if self.exclusive_user_regex:
+ return bool(self.exclusive_user_regex.match(user_id))
+ else:
+ return False
def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 747d2df622..fc468ea185 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,11 +15,14 @@
import logging
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from ._base import Cache
from . import background_updates
+from synapse.util.caches import CACHE_SIZE_FACTOR
+
+
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -33,7 +36,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
- max_entries=5000,
+ max_entries=50000 * CACHE_SIZE_FACTOR,
)
super(ClientIpStore, self).__init__(hs)
@@ -45,7 +48,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"],
)
- @defer.inlineCallbacks
+ # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
+ self._batch_row_update = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
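+ # The batch is flushed every 5 seconds by the looping call above, and once
+ # more at shutdown, so pending last-seen updates are only held in memory
+ # briefly before being written to user_ips.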
+
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user.to_string(), access_token, ip)
@@ -57,34 +67,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
- defer.returnValue(None)
+ return
self.client_ip_last_seen.prefill(key, now)
- # It's safe not to lock here: a) no unique constraint,
- # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
- yield self._simple_upsert(
- "user_ips",
- keyvalues={
- "user_id": user.to_string(),
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": device_id,
- },
- values={
- "last_seen": now,
- },
- desc="insert_client_ip",
- lock=False,
+ self._batch_row_update[key] = (user_agent, device_id, now)
+
+ def _update_client_ips_batch(self):
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
+ return self.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
+ def _update_client_ips_batch_txn(self, txn, to_update):
+ self.database_engine.lock_table(txn, "user_ips")
+
+ for entry in to_update.iteritems():
+ (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
+
+ self._simple_upsert_txn(
+ txn,
+ table="user_ips",
+ keyvalues={
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": device_id,
+ },
+ values={
+ "last_seen": last_seen,
+ },
+ lock=False,
+ )
+
@defer.inlineCallbacks
- def get_last_client_ip_by_device(self, devices):
+ def get_last_client_ip_by_device(self, user_id, device_id):
"""For each device_id listed, give the user_ip it was last seen on
Args:
- devices (iterable[(str, str)]): list of (user_id, device_id) pairs
+ user_id (str)
+ device_id (str): If None, fetches all devices for the user
Returns:
defer.Deferred: resolves to a dict, where the keys
@@ -95,6 +119,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
+ user_id, device_id,
retcols=(
"user_id",
"access_token",
@@ -103,23 +128,34 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id",
"last_seen",
),
- devices=devices
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, did, last_seen = self._batch_row_update[key]
+ if not device_id or did == device_id:
+ ret[(user_id, device_id)] = {
+ "user_id": user_id,
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "device_id": did,
+ "last_seen": last_seen,
+ }
defer.returnValue(ret)
@classmethod
- def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+ def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
where_clauses = []
bindings = []
- for (user_id, device_id) in devices:
- if device_id is None:
- where_clauses.append("(user_id = ? AND device_id IS NULL)")
- bindings.extend((user_id, ))
- else:
- where_clauses.append("(user_id = ? AND device_id = ?)")
- bindings.extend((user_id, device_id))
+ if device_id is None:
+ where_clauses.append("user_id = ?")
+ bindings.extend((user_id, ))
+ else:
+ where_clauses.append("(user_id = ? AND device_id = ?)")
+ bindings.extend((user_id, device_id))
if not where_clauses:
return []
@@ -147,3 +183,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute(sql, bindings)
return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def get_user_ip_and_agents(self, user):
+ user_id = user.to_string()
+ results = {}
+
+ for key in self._batch_row_update:
+ uid, access_token, ip = key
+ if uid == user_id:
+ user_agent, _, last_seen = self._batch_row_update[key]
+ results[(access_token, ip)] = (user_agent, last_seen)
+
+ rows = yield self._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token", "ip", "user_agent", "last_seen"
+ ],
+ desc="get_user_ip_and_agents",
+ )
+
+ results.update(
+ ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
+ for row in rows
+ )
+ defer.returnValue(list(
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+ ))
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d9936c88bb..bb27fd1f70 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -368,7 +368,7 @@ class DeviceStore(SQLBaseStore):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
- FROM device_lists_outbound_pokes
+ FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
@@ -510,32 +510,43 @@ class DeviceStore(SQLBaseStore):
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
- # First we DELETE all rows such that only the latest row for each
- # (destination, user_id is left. We do this by selecting first and
- # deleting.
+ # We update the device_lists_outbound_last_success with the successfully
+ # poked users. We do the join to see which users need to be inserted and
+ # which updated.
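+ # The third selected column is a flag for whether a last_success row
+ # already exists for (destination, user_id): existing rows are UPDATEd
+ # below, missing ones are INSERTed.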
sql = """
- SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
- WHERE destination = ? AND stream_id <= ?
+ SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ FROM device_lists_outbound_pokes as o
+ LEFT JOIN device_lists_outbound_last_success as s
+ USING (destination, user_id)
+ WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
- HAVING count(*) > 1
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND user_id = ? AND stream_id < ?
+ UPDATE device_lists_outbound_last_success
+ SET stream_id = ?
+ WHERE destination = ? AND user_id = ?
"""
txn.executemany(
- sql, ((destination, row[0], row[1],) for row in rows)
+ sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
- # Mark everything that is left as sent
sql = """
- UPDATE device_lists_outbound_pokes SET sent = ?
+ INSERT INTO device_lists_outbound_last_success
+ (destination, user_id, stream_id) VALUES (?, ?, ?)
+ """
+ txn.executemany(
+ sql, ((destination, row[0], row[1],) for row in rows if not row[2])
+ )
+
+ # Delete all sent outbound pokes
+ sql = """
+ DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
- txn.execute(sql, (True, destination, stream_id,))
+ txn.execute(sql, (destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
@@ -670,6 +681,14 @@ class DeviceStore(SQLBaseStore):
)
)
+ # Since we've deleted unsent deltas, we need to remove the entry
+ # recording the last successful send so that the prev_ids are set correctly.
+ sql = """
+ DELETE FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
return self.runInteraction(
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 9caaf81f2c..79e7c540ad 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -170,3 +170,17 @@ class DirectoryStore(SQLBaseStore):
"room_alias",
desc="get_aliases_for_room",
)
+
+ def update_aliases_for_room(self, old_room_id, new_room_id, creator):
+ def _update_aliases_for_room_txn(txn):
+ sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
+ txn.execute(sql, (new_room_id, creator, old_room_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (old_room_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (new_room_id,)
+ )
+ return self.runInteraction(
+ "_update_aliases_for_room_txn", _update_aliases_for_room_txn
+ )
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index e00f31da2b..2cebb203c6 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -185,8 +185,8 @@ class EndToEndKeyStore(SQLBaseStore):
for algorithm, key_id, json_bytes in new_keys
],
)
- txn.call_after(
- self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
)
yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
@@ -237,24 +237,29 @@ class EndToEndKeyStore(SQLBaseStore):
)
for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id))
- txn.call_after(
- self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
)
return result
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
- @defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id):
- yield self._simple_delete(
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="delete_e2e_device_keys_by_device"
- )
- yield self._simple_delete(
- table="e2e_one_time_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="delete_e2e_one_time_keys_by_device"
+ def delete_e2e_keys_by_device_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+ )
+ return self.runInteraction(
+ "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- self.count_e2e_one_time_keys.invalidate((user_id, device_id,))
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 519059c306..e8133de2fa 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -37,25 +37,55 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively.
"""
+ EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
def __init__(self, hs):
super(EventFederationStore, self).__init__(hs)
+ self.register_background_update_handler(
+ self.EVENT_AUTH_STATE_ONLY,
+ self._background_delete_non_state_event_auth,
+ )
+
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
- def get_auth_chain(self, event_ids):
- return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
+ def get_auth_chain(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
+
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
+
+ Returns:
+ list of events
+ """
+ return self.get_auth_chain_ids(
+ event_ids, include_given=include_given,
+ ).addCallback(self._get_events)
+
+ def get_auth_chain_ids(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
+
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
- def get_auth_chain_ids(self, event_ids):
+ Returns:
+ list of event_ids
+ """
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
- event_ids
+ event_ids, include_given
)
- def _get_auth_chain_ids_txn(self, txn, event_ids):
- results = set()
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ if include_given:
+ results = set(event_ids)
+ else:
+ results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
@@ -504,3 +534,52 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+
+ @defer.inlineCallbacks
+ def _background_delete_non_state_event_auth(self, progress, batch_size):
+ def delete_event_auth(txn):
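+ # Work backwards through the events table in batches of `batch_size`
+ # stream orderings, deleting event_auth rows that belong to non-state
+ # events (those whose join against state_events leaves state_key NULL).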
+ target_min_stream_id = progress.get("target_min_stream_id_inclusive")
+ max_stream_id = progress.get("max_stream_id_exclusive")
+
+ if not target_min_stream_id or not max_stream_id:
+ txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ target_min_stream_id = rows[0][0]
+
+ txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ max_stream_id = rows[0][0]
+
+ min_stream_id = max_stream_id - batch_size
+
+ sql = """
+ DELETE FROM event_auth
+ WHERE event_id IN (
+ SELECT event_id FROM events
+ LEFT JOIN state_events USING (room_id, event_id)
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND state_key IS null
+ )
+ """
+
+ txn.execute(sql, (min_stream_id, max_stream_id,))
+
+ new_progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_AUTH_STATE_ONLY, new_progress
+ )
+
+ return min_stream_id >= target_min_stream_id
+
+ result = yield self.runInteraction(
+ self.EVENT_AUTH_STATE_ONLY, delete_event_auth
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+
+ defer.returnValue(batch_size)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index f29d71589d..7002b3752e 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -403,6 +403,11 @@ class EventsStore(SQLBaseStore):
(room_id, ), new_state
)
+ for room_id, latest_event_ids in new_forward_extremeties.iteritems():
+ self.get_latest_event_ids_in_room.prefill(
+ (room_id,), list(latest_event_ids)
+ )
+
@defer.inlineCallbacks
def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
"""Calculates the new forward extremeties for a room given events to
@@ -647,9 +652,10 @@ class EventsStore(SQLBaseStore):
list of the event ids which are the forward extremities.
"""
- self._update_current_state_txn(txn, current_state_for_room)
-
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+
+ self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -712,7 +718,7 @@ class EventsStore(SQLBaseStore):
backfilled=backfilled,
)
- def _update_current_state_txn(self, txn, state_delta_by_room):
+ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert, _ = current_state_tuple
txn.executemany(
@@ -734,6 +740,29 @@ class EventsStore(SQLBaseStore):
],
)
+ state_deltas = {key: None for key in to_delete}
+ state_deltas.update(to_insert)
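+ # Deleted keys map to None (stored as a NULL event_id in
+ # current_state_delta_stream); inserted or replaced keys map to their
+ # new event_id.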
+
+ self._simple_insert_many_txn(
+ txn,
+ table="current_state_delta_stream",
+ values=[
+ {
+ "stream_id": max_stream_order,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": ev_id,
+ "prev_event_id": to_delete.get(key, None),
+ }
+ for key, ev_id in state_deltas.iteritems()
+ ]
+ )
+
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ room_id, max_stream_order,
+ )
+
# Invalidate the various caches
# Figure out the changes of membership to invalidate the
@@ -742,11 +771,7 @@ class EventsStore(SQLBaseStore):
# and which we have added, then we invalidate the caches for all
# those users.
members_changed = set(
- state_key for ev_type, state_key in to_delete.iterkeys()
- if ev_type == EventTypes.Member
- )
- members_changed.update(
- state_key for ev_type, state_key in to_insert.iterkeys()
+ state_key for ev_type, state_key in state_deltas
if ev_type == EventTypes.Member
)
@@ -755,6 +780,11 @@ class EventsStore(SQLBaseStore):
txn, self.get_rooms_for_user, (member,)
)
+ for host in set(get_domain_from_id(u) for u in members_changed):
+ self._invalidate_cache_and_stream(
+ txn, self.is_host_joined, (room_id, host)
+ )
+
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
)
@@ -1119,6 +1149,7 @@ class EventsStore(SQLBaseStore):
}
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events
+ if event.is_state()
],
)
@@ -1418,7 +1449,7 @@ class EventsStore(SQLBaseStore):
]
rows = self._new_transaction(
- conn, "do_fetch", [], None, self._fetch_event_rows, event_ids
+ conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
)
row_dict = {
@@ -2243,6 +2274,24 @@ class EventsStore(SQLBaseStore):
defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
+ def get_max_current_state_delta_stream_id(self):
+ return self._stream_id_gen.get_current_token()
+
+ def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+ def get_all_updated_current_state_deltas_txn(txn):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit))
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_current_state_deltas",
+ get_all_updated_current_state_deltas_txn,
+ )
+
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index a2ccc66ea7..78b1e30945 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -19,6 +19,7 @@ from ._base import SQLBaseStore
from synapse.api.errors import SynapseError, Codes
from synapse.util.caches.descriptors import cachedInlineCallbacks
+from canonicaljson import encode_canonical_json
import simplejson as json
@@ -46,12 +47,21 @@ class FilteringStore(SQLBaseStore):
defer.returnValue(json.loads(str(def_json).decode("utf-8")))
def add_user_filter(self, user_localpart, user_filter):
- def_json = json.dumps(user_filter).encode("utf-8")
+ def_json = encode_canonical_json(user_filter)
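+ # Canonical JSON means identical filters always serialise identically,
+ # so we can de-duplicate by looking for an existing row before
+ # allocating a new filter_id.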
# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
def _do_txn(txn):
sql = (
+ "SELECT filter_id FROM user_filters "
+ "WHERE user_id = ? AND filter_json = ?"
+ )
+ txn.execute(sql, (user_localpart, def_json))
+ filter_id_response = txn.fetchone()
+ if filter_id_response is not None:
+ return filter_id_response[0]
+
+ sql = (
"SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?"
)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 4c0f82353d..82bb61b811 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -30,13 +30,16 @@ class MediaRepositoryStore(SQLBaseStore):
return self._simple_select_one(
"local_media_repository",
{"media_id": media_id},
- ("media_type", "media_length", "upload_name", "created_ts"),
+ (
+ "media_type", "media_length", "upload_name", "created_ts",
+ "quarantined_by", "url_cache",
+ ),
allow_none=True,
desc="get_local_media",
)
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
- media_length, user_id):
+ media_length, user_id, url_cache=None):
return self._simple_insert(
"local_media_repository",
{
@@ -46,6 +49,7 @@ class MediaRepositoryStore(SQLBaseStore):
"upload_name": upload_name,
"media_length": media_length,
"user_id": user_id.to_string(),
+ "url_cache": url_cache,
},
desc="store_local_media",
)
@@ -138,7 +142,7 @@ class MediaRepositoryStore(SQLBaseStore):
{"media_origin": origin, "media_id": media_id},
(
"media_type", "media_length", "upload_name", "created_ts",
- "filesystem_id",
+ "filesystem_id", "quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6e623843d5..72b670b83b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 41
+SCHEMA_VERSION = 43
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 0a819d32c5..8758b1c0c7 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -49,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
class PushRuleStore(SQLBaseStore):
- @cachedInlineCallbacks()
+ @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
@@ -73,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rules)
- @cachedInlineCallbacks()
+ @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index efb90c3c91..f42b8014c7 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -45,7 +45,9 @@ class ReceiptsStore(SQLBaseStore):
return
# Returns an ObservableDeferred
- res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
+ res = self.get_users_with_read_receipts_in_room.cache.get(
+ room_id, None, update_metrics=False,
+ )
if res:
if isinstance(res, defer.Deferred) and res.called:
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 5d543652bb..23688430b7 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -24,6 +24,7 @@ from .engines import PostgresEngine, Sqlite3Engine
import collections
import logging
import ujson as json
+import re
logger = logging.getLogger(__name__)
@@ -507,3 +508,98 @@ class RoomStore(SQLBaseStore):
))
else:
defer.returnValue(None)
+
+ @cached(max_entries=10000)
+ def is_room_blocked(self, room_id):
+ return self._simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="1",
+ allow_none=True,
+ desc="is_room_blocked",
+ )
+
+ @defer.inlineCallbacks
+ def block_room(self, room_id, user_id):
+ yield self._simple_insert(
+ table="blocked_rooms",
+ values={
+ "room_id": room_id,
+ "user_id": user_id,
+ },
+ desc="block_room",
+ )
+ self.is_room_blocked.invalidate((room_id,))
+
+ def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ """For a room, loop through all events with media and quarantine
+ the associated media
+ """
+ def _get_media_ids_in_room(txn):
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
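+ # e.g. (illustrative) "mxc://example.org/abcdefgh" gives hostname
+ # "example.org" and media_id "abcdefgh"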
+
+ next_token = self.get_current_events_token() + 1
+
+ total_media_quarantined = 0
+
+ while next_token:
+ sql = """
+ SELECT stream_ordering, content FROM events
+ WHERE room_id = ?
+ AND stream_ordering < ?
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql, (room_id, next_token, True, False, 100))
+
+ next_token = None
+ local_media_mxcs = []
+ remote_media_mxcs = []
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ content = json.loads(content_json)
+
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ # Now update all the tables to set the quarantined_by flag
+
+ txn.executemany("""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
+
+ txn.executemany(
+ """
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_media_mxcs
+ )
+ )
+
+ total_media_quarantined += len(local_media_mxcs)
+ total_media_quarantined += len(remote_media_mxcs)
+
+ return total_media_quarantined
+
+ return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0829ae5bee..457ca288d0 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from collections import namedtuple
from ._base import SQLBaseStore
+from synapse.util.async import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.stringutils import to_ascii
@@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore):
context=context,
)
- def get_joined_users_from_state(self, room_id, state_group, state_ids):
+ def get_joined_users_from_state(self, room_id, state_entry):
+ state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_users_from_context(
- room_id, state_group, state_ids,
+ room_id, state_group, state_entry.state, context=state_entry,
)
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
@@ -499,42 +501,40 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(users_in_room)
- def is_host_joined(self, room_id, host, state_group, state_ids):
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
+ @cachedInlineCallbacks(max_entries=10000)
+ def is_host_joined(self, room_id, host):
+ if '%' in host or '_' in host:
+ raise Exception("Invalid host name")
- return self._is_host_joined(
- room_id, host, state_group, state_ids
- )
+ sql = """
+ SELECT state_key FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE membership = 'join'
+ AND type = 'm.room.member'
+ AND c.room_id = ?
+ AND state_key LIKE ?
+ LIMIT 1
+ """
- @cachedInlineCallbacks(num_args=3)
- def _is_host_joined(self, room_id, host, state_group, current_state_ids):
- # We don't use `state_group`, its there so that we can cache based
- # on it. However, its important that its never None, since two current_state's
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
+ # We do need to be careful to ensure that host doesn't have any wild cards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
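+ # e.g. (illustrative) for host "example.org" this matches any state_key
+ # ending in ":example.org", such as "@alice:example.org".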
- for (etype, state_key), event_id in current_state_ids.items():
- if etype == EventTypes.Member:
- try:
- if get_domain_from_id(state_key) != host:
- continue
- except:
- logger.warn("state_key not user_id: %s", state_key)
- continue
+ rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
- event = yield self.get_event(event_id, allow_none=True)
- if event and event.content["membership"] == Membership.JOIN:
- defer.returnValue(True)
+ if not rows:
+ defer.returnValue(False)
- defer.returnValue(False)
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
- def get_joined_hosts(self, room_id, state_group, state_ids):
+ defer.returnValue(True)
+
+ def get_joined_hosts(self, room_id, state_entry):
+ state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -543,33 +543,20 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_hosts(
- room_id, state_group, state_ids
+ room_id, state_group, state_entry.state, state_entry=state_entry,
)
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
- def _get_joined_hosts(self, room_id, state_group, current_state_ids):
+ # @defer.inlineCallbacks
+ def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
- joined_hosts = set()
- for etype, state_key in current_state_ids:
- if etype == EventTypes.Member:
- try:
- host = get_domain_from_id(state_key)
- except:
- logger.warn("state_key not user_id: %s", state_key)
- continue
-
- if host in joined_hosts:
- continue
-
- event_id = current_state_ids[(etype, state_key)]
- event = yield self.get_event(event_id, allow_none=True)
- if event and event.content["membership"] == Membership.JOIN:
- joined_hosts.add(intern_string(host))
+ cache = self._get_joined_hosts_cache(room_id)
+ joined_hosts = yield cache.get_destinations(state_entry)
defer.returnValue(joined_hosts)
@@ -647,3 +634,75 @@ class RoomMemberStore(SQLBaseStore):
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
defer.returnValue(result)
+
+ @cached(max_entries=10000, iterable=True)
+ def _get_joined_hosts_cache(self, room_id):
+ return _JoinedHostsCache(self, room_id)
+
+
+class _JoinedHostsCache(object):
+ """Cache for joined hosts in a room that is optimised to handle updates
+ via state deltas.
+ """
+
+ def __init__(self, store, room_id):
+ self.store = store
+ self.room_id = room_id
+
+ self.hosts_to_joined_users = {}
+
+ self.state_group = object()
+
+ self.linearizer = Linearizer("_JoinedHostsCache")
+
+ self._len = 0
+
+ @defer.inlineCallbacks
+ def get_destinations(self, state_entry):
+ """Get set of destinations for a state entry
+
+ Args:
+ state_entry(synapse.state._StateCacheEntry)
+ """
+ if state_entry.state_group == self.state_group:
+ defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+ with (yield self.linearizer.queue(())):
+ if state_entry.state_group == self.state_group:
+ pass
+ elif state_entry.prev_group == self.state_group:
+ for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+ if typ != EventTypes.Member:
+ continue
+
+ host = intern_string(get_domain_from_id(state_key))
+ user_id = state_key
+ known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+ event = yield self.store.get_event(event_id)
+ if event.membership == Membership.JOIN:
+ known_joins.add(user_id)
+ else:
+ known_joins.discard(user_id)
+
+ if not known_joins:
+ self.hosts_to_joined_users.pop(host, None)
+ else:
+ joined_users = yield self.store.get_joined_users_from_state(
+ self.room_id, state_entry,
+ )
+
+ self.hosts_to_joined_users = {}
+ for user_id in joined_users:
+ host = intern_string(get_domain_from_id(user_id))
+ self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+ if state_entry.state_group:
+ self.state_group = state_entry.state_group
+ else:
+ self.state_group = object()
+ self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+ defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+ def __len__(self):
+ return self._len
diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql
new file mode 100644
index 0000000000..d28851aff8
--- /dev/null
+++ b/synapse/storage/schema/delta/42/current_state_delta.sql
@@ -0,0 +1,26 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE current_state_delta_stream (
+ stream_id BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT, -- Is null if the key was removed
+ prev_event_id TEXT -- Is null if the key was added
+);
+
+CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id);
diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql
new file mode 100644
index 0000000000..9ab8c14fa3
--- /dev/null
+++ b/synapse/storage/schema/delta/42/device_list_last_id.sql
@@ -0,0 +1,33 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Table of last stream_id that we sent to destination for user_id. This is
+-- used to fill out the `prev_id` fields of outbound device list updates.
+CREATE TABLE device_lists_outbound_last_success (
+ destination TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+INSERT INTO device_lists_outbound_last_success
+ SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_pokes
+ WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values
+ GROUP BY destination, user_id;
+
+CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success(
+ destination, user_id, stream_id
+);
diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql
new file mode 100644
index 0000000000..b8821ac759
--- /dev/null
+++ b/synapse/storage/schema/delta/42/event_auth_state_only.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('event_auth_state_only', '{}');
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py
new file mode 100644
index 0000000000..ea6a18196d
--- /dev/null
+++ b/synapse/storage/schema/delta/42/user_dir.py
@@ -0,0 +1,84 @@
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+
+logger = logging.getLogger(__name__)
+
+
+BOTH_TABLES = """
+CREATE TABLE user_directory_stream_pos (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_id BIGINT,
+ CHECK (Lock='X')
+);
+
+INSERT INTO user_directory_stream_pos (stream_id) VALUES (null);
+
+CREATE TABLE user_directory (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL, -- A room_id that we know the user is joined to
+ display_name TEXT,
+ avatar_url TEXT
+);
+
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
+
+CREATE TABLE users_in_pubic_room (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL -- A room_id that we know is public
+);
+
+CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id);
+CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id);
+"""
+
+
+POSTGRES_TABLE = """
+CREATE TABLE user_directory_search (
+ user_id TEXT NOT NULL,
+ vector tsvector
+);
+
+CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector);
+CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id);
+"""
+
+
+SQLITE_TABLE = """
+CREATE VIRTUAL TABLE user_directory_search
+ USING fts4 ( user_id, value );
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ for statement in get_statements(BOTH_TABLES.splitlines()):
+ cur.execute(statement)
+
+ if isinstance(database_engine, PostgresEngine):
+ for statement in get_statements(POSTGRES_TABLE.splitlines()):
+ cur.execute(statement)
+ elif isinstance(database_engine, Sqlite3Engine):
+ for statement in get_statements(SQLITE_TABLE.splitlines()):
+ cur.execute(statement)
+ else:
+ raise Exception("Unrecognized database engine")
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/schema/delta/43/blocked_rooms.sql
new file mode 100644
index 0000000000..0e3cd143ff
--- /dev/null
+++ b/synapse/storage/schema/delta/43/blocked_rooms.sql
@@ -0,0 +1,21 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE blocked_rooms (
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL -- Admin who blocked the room
+);
+
+CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/schema/delta/43/quarantine_media.sql
new file mode 100644
index 0000000000..630907ec4f
--- /dev/null
+++ b/synapse/storage/schema/delta/43/quarantine_media.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT;
+ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT;
diff --git a/synapse/storage/schema/delta/43/url_cache.sql b/synapse/storage/schema/delta/43/url_cache.sql
new file mode 100644
index 0000000000..45ebe020da
--- /dev/null
+++ b/synapse/storage/schema/delta/43/url_cache.sql
@@ -0,0 +1,16 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT;
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql
new file mode 100644
index 0000000000..4501d90cbb
--- /dev/null
+++ b/synapse/storage/schema/delta/43/user_share.sql
@@ -0,0 +1,33 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Table keeping track of who shares a room with whom. We only keep track
+-- of this for local users, so `user_id` is local users only (but we do keep track
+-- of which remote users share a room)
+CREATE TABLE users_who_share_rooms (
+ user_id TEXT NOT NULL,
+ other_user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room
+);
+
+
+CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id);
+CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
+CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
+
+
+-- Make sure that we populate the table initially
+UPDATE user_directory_stream_pos SET stream_id = NULL;
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 85acf2ad1e..5673e4aa96 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -20,6 +20,7 @@ from synapse.util.stringutils import to_ascii
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
+from collections import namedtuple
import logging
@@ -29,6 +30,16 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
+class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+ """Return type of get_state_group_delta that implements __len__, which lets
+ us use the iterable flag when caching
+ """
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
class StateStore(SQLBaseStore):
""" Keeps track of the state at a given event.
@@ -98,6 +109,46 @@ class StateStore(SQLBaseStore):
_get_current_state_ids_txn,
)
+ @cached(max_entries=10000, iterable=True)
+ def get_state_group_delta(self, state_group):
+ """Given a state group, try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ (prev_group, delta_ids), where both may be None.
+ """
+ def _get_state_group_delta_txn(txn):
+ prev_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ if not prev_group:
+ return _GetStateGroupDelta(None, None)
+
+ delta_ids = self._simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={
+ "state_group": state_group,
+ },
+ retcols=("type", "state_key", "event_id",)
+ )
+
+ return _GetStateGroupDelta(prev_group, {
+ (row["type"], row["state_key"]): row["event_id"]
+ for row in delta_ids
+ })
+ return self.runInteraction(
+ "get_state_group_delta",
+ _get_state_group_delta_txn,
+ )
+
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
if not event_ids:
@@ -184,6 +235,19 @@ class StateStore(SQLBaseStore):
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
if context.prev_group:
+ is_in_db = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": context.prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (context.prev_group,)
+ )
+
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
@@ -251,6 +315,12 @@ class StateStore(SQLBaseStore):
],
)
+ for event_id, state_group_id in state_groups.iteritems():
+ txn.call_after(
+ self._get_state_group_for_event.prefill,
+ (event_id,), state_group_id
+ )
+
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
@@ -520,8 +590,8 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id])
- @cached(num_args=2, max_entries=50000)
- def _get_state_group_for_event(self, room_id, event_id):
+ @cached(max_entries=50000)
+ def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={
@@ -563,20 +633,22 @@ class StateStore(SQLBaseStore):
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
+
for typ, state_key in types:
+ key = (typ, state_key)
if state_key is None:
type_to_key[typ] = None
- missing_types.add((typ, state_key))
+ missing_types.add(key)
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
- if (typ, state_key) not in state_dict_ids:
- missing_types.add((typ, state_key))
+ if key not in state_dict_ids and key not in known_absent:
+ missing_types.add(key)
sentinel = object()
@@ -590,7 +662,7 @@ class StateStore(SQLBaseStore):
return True
return False
- got_all = not (missing_types or types is None)
+ got_all = is_all or not missing_types
return {
k: v for k, v in state_dict_ids.iteritems()
@@ -607,7 +679,7 @@ class StateStore(SQLBaseStore):
Args:
group: The state group to lookup
"""
- is_all, state_dict_ids = self._state_group_cache.get(group)
+ is_all, _, state_dict_ids = self._state_group_cache.get(group)
return state_dict_ids, is_all
@@ -624,7 +696,7 @@ class StateStore(SQLBaseStore):
missing_groups = []
if types is not None:
for group in set(groups):
- state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
+ state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types
)
results[group] = state_dict_ids
@@ -653,19 +725,7 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched
# from the database.
for group, group_state_dict in group_to_state_dict.iteritems():
- if types:
- # We delibrately put key -> None mappings into the cache to
- # cache absence of the key, on the assumption that if we've
- # explicitly asked for some types then we will probably ask
- # for them again.
- state_dict = {
- (intern_string(etype), intern_string(state_key)): None
- for (etype, state_key) in types
- }
- state_dict.update(results[group])
- results[group] = state_dict
- else:
- state_dict = results[group]
+ state_dict = results[group]
state_dict.update(
((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
@@ -677,17 +737,9 @@ class StateStore(SQLBaseStore):
key=group,
value=state_dict,
full=(types is None),
+ known_absent=types,
)
- # Remove all the entries with None values. The None values were just
- # used for bookkeeping in the cache.
- for group, state_dict in results.iteritems():
- results[group] = {
- key: event_id
- for key, event_id in state_dict.iteritems()
- if event_id
- }
-
defer.returnValue(results)
def get_next_state_group(self):
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
new file mode 100644
index 0000000000..2a4db3f03c
--- /dev/null
+++ b/synapse/storage/user_directory.py
@@ -0,0 +1,743 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import get_domain_from_id, get_localpart_from_id
+
+import re
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectoryStore(SQLBaseStore):
+ @cachedInlineCallbacks(cache_context=True)
+ def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
+ """Check if the room is either world_readable or publically joinable
+ """
+ current_state_ids = yield self.get_current_state_ids(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+
+ join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
+ if join_rules_id:
+ join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
+ if join_rule_ev:
+ if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
+ defer.returnValue(True)
+
+ hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
+ if hist_vis_id:
+ hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
+ if hist_vis_ev:
+ if hist_vis_ev.content.get("history_visibility") == "world_readable":
+ defer.returnValue(True)
+
+ defer.returnValue(False)
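The check above turns on just two pieces of current state: the m.room.join_rules event (join_rule == "public") and the m.room.history_visibility event (history_visibility == "world_readable"). A minimal sketch of the same predicate applied to bare content dicts, outside the store (the function and argument names here are illustrative only):

    def room_is_visible_in_directory(join_rules_content, history_visibility_content):
        # Mirrors the store method above, but takes the event contents directly.
        if join_rules_content and join_rules_content.get("join_rule") == "public":
            return True
        if (history_visibility_content and
                history_visibility_content.get("history_visibility") == "world_readable"):
            return True
        return False

    # e.g. a publicly joinable room with default ("shared") history visibility:
    assert room_is_visible_in_directory({"join_rule": "public"}, {"history_visibility": "shared"})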
+
+ @defer.inlineCallbacks
+ def add_users_to_public_room(self, room_id, user_ids):
+ """Add user to the list of users in public rooms
+
+ Args:
+ room_id (str): A room_id that all the users are in, and that is world_readable
+ or publicly joinable
+ user_ids (list(str)): Users to add
+ """
+ yield self._simple_insert_many(
+ table="users_in_pubic_room",
+ values=[
+ {
+ "user_id": user_id,
+ "room_id": room_id,
+ }
+ for user_id in user_ids
+ ],
+ desc="add_users_to_public_room"
+ )
+ for user_id in user_ids:
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def add_profiles_to_user_dir(self, room_id, users_with_profile):
+ """Add profiles to the user directory
+
+ Args:
+ room_id (str): A room_id that all users are joined to
+ users_with_profile (dict): Users to add to the directory, as a
+ mapping of user_id -> ProfileInfo
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ # We weight the localpart most highly, then the display name and finally
+ # the server name
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ args = (
+ (
+ user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
+ profile.display_name,
+ )
+ for user_id, profile in users_with_profile.iteritems()
+ )
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ INSERT INTO user_directory_search(user_id, value)
+ VALUES (?,?)
+ """
+ args = (
+ (
+ user_id,
+ "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
+ )
+ for user_id, p in users_with_profile.iteritems()
+ )
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ def _add_profiles_to_user_dir_txn(txn):
+ txn.executemany(sql, args)
+ self._simple_insert_many_txn(
+ txn,
+ table="user_directory",
+ values=[
+ {
+ "user_id": user_id,
+ "room_id": room_id,
+ "display_name": profile.display_name,
+ "avatar_url": profile.avatar_url,
+ }
+ for user_id, profile in users_with_profile.iteritems()
+ ]
+ )
+ for user_id in users_with_profile:
+ txn.call_after(
+ self.get_user_in_directory.invalidate, (user_id,)
+ )
+
+ return self.runInteraction(
+ "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
+ )
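For the PostgreSQL branch above, the four bind parameters per user line up with the weighted tsvector as follows; a hypothetical example, assuming the usual splitting done by get_localpart_from_id and get_domain_from_id:

    user_id = "@alice:example.com"      # user_id column, not part of the vector
    localpart = "alice"                 # get_localpart_from_id(user_id), weight 'A' (highest)
    domain = "example.com"              # get_domain_from_id(user_id), weight 'D' (lowest)
    display_name = "Alice Smith"        # profile.display_name, weight 'B', may be NULL
    row_args = (user_id, localpart, domain, display_name)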
+
+ @defer.inlineCallbacks
+ def update_user_in_user_dir(self, user_id, room_id):
+ yield self._simple_update_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ updatevalues={"room_id": room_id},
+ desc="update_user_in_user_dir",
+ )
+ self.get_user_in_directory.invalidate((user_id,))
+
+ def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id):
+ def _update_profile_in_user_dir_txn(txn):
+ new_entry = self._simple_upsert_txn(
+ txn,
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ insertion_values={"room_id": room_id},
+ values={"display_name": display_name, "avatar_url": avatar_url},
+ lock=False, # We're the only inserter
+ )
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We weight the localpart most highly, then the display name and finally
+ # the server name
+ if new_entry:
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ txn.execute(
+ sql,
+ (
+ user_id, get_localpart_from_id(user_id),
+ get_domain_from_id(user_id), display_name,
+ )
+ )
+ else:
+ sql = """
+ UPDATE user_directory_search
+ SET vector = setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (
+ get_localpart_from_id(user_id), get_domain_from_id(user_id),
+ display_name, user_id,
+ )
+ )
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ value = "%s %s" % (user_id, display_name,) if display_name else user_id
+ self._simple_upsert_txn(
+ txn,
+ table="user_directory_search",
+ keyvalues={"user_id": user_id},
+ values={"value": value},
+ lock=False, # We're the only inserter
+ )
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+ return self.runInteraction(
+ "update_profile_in_user_dir", _update_profile_in_user_dir_txn
+ )
+
+ @defer.inlineCallbacks
+ def update_user_in_public_user_list(self, user_id, room_id):
+ yield self._simple_update_one(
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ updatevalues={"room_id": room_id},
+ desc="update_user_in_public_user_list",
+ )
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def remove_from_user_dir(self, user_id):
+ def _remove_from_user_dir_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="user_directory_search",
+ keyvalues={"user_id": user_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ )
+ txn.call_after(
+ self.get_user_in_directory.invalidate, (user_id,)
+ )
+ txn.call_after(
+ self.get_user_in_public_room.invalidate, (user_id,)
+ )
+ return self.runInteraction(
+ "remove_from_user_dir", _remove_from_user_dir_txn,
+ )
+
+ @defer.inlineCallbacks
+ def remove_from_user_in_public_room(self, user_id):
+ yield self._simple_delete(
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ desc="remove_from_user_in_public_room",
+ )
+ self.get_user_in_public_room.invalidate((user_id,))
+
+ def get_users_in_public_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory becuase they're
+ in the given room_id
+ """
+ return self._simple_select_onecol(
+ table="users_in_pubic_room",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_public_due_to_room",
+ )
+
+ @defer.inlineCallbacks
+ def get_users_in_dir_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory becuase they're
+ in the given room_id
+ """
+ user_ids_dir = yield self._simple_select_onecol(
+ table="user_directory",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_pub = yield self._simple_select_onecol(
+ table="users_in_pubic_room",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_share = yield self._simple_select_onecol(
+ table="users_who_share_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids = set(user_ids_dir)
+ user_ids.update(user_ids_pub)
+ user_ids.update(user_ids_share)
+
+ defer.returnValue(user_ids)
+
+ @defer.inlineCallbacks
+ def get_all_rooms(self):
+ """Get all room_ids we've ever known about, in ascending order of "size"
+ """
+ sql = """
+ SELECT room_id FROM current_state_events
+ GROUP BY room_id
+ ORDER BY count(*) ASC
+ """
+ rows = yield self._execute("get_all_rooms", None, sql)
+ defer.returnValue([room_id for room_id, in rows])
+
+ def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
+ """Insert entries into the users_who_share_rooms table. The first
+ user should be a local user.
+
+ Args:
+ room_id (str)
+ share_private (bool): Is the room private
+ user_id_tuples([(str, str)]): iterable of 2-tuples of user IDs.
+ """
+ def _add_users_who_share_room_txn(txn):
+ self._simple_insert_many_txn(
+ txn,
+ table="users_who_share_rooms",
+ values=[
+ {
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ "room_id": room_id,
+ "share_private": share_private,
+ }
+ for user_id, other_user_id in user_id_tuples
+ ],
+ )
+ for user_id, other_user_id in user_id_tuples:
+ txn.call_after(
+ self.get_users_who_share_room_from_dir.invalidate,
+ (user_id,),
+ )
+ txn.call_after(
+ self.get_if_users_share_a_room.invalidate,
+ (user_id, other_user_id),
+ )
+ return self.runInteraction(
+ "add_users_who_share_room", _add_users_who_share_room_txn
+ )
+
+ def update_users_who_share_room(self, room_id, share_private, user_id_sets):
+ """Updates entries in the users_who_share_rooms table. The first
+ user should be a local user.
+
+ Args:
+ room_id (str)
+ share_private (bool): Is the room private
+ user_id_sets([(str, str)]): iterable of 2-tuples of user IDs.
+ """
+ def _update_users_who_share_room_txn(txn):
+ sql = """
+ UPDATE users_who_share_rooms
+ SET room_id = ?, share_private = ?
+ WHERE user_id = ? AND other_user_id = ?
+ """
+ txn.executemany(
+ sql,
+ (
+ (room_id, share_private, uid, oid)
+ for uid, oid in user_id_sets
+ )
+ )
+ for user_id, other_user_id in user_id_sets:
+ txn.call_after(
+ self.get_users_who_share_room_from_dir.invalidate,
+ (user_id,),
+ )
+ txn.call_after(
+ self.get_if_users_share_a_room.invalidate,
+ (user_id, other_user_id),
+ )
+ return self.runInteraction(
+ "update_users_who_share_room", _update_users_who_share_room_txn
+ )
+
+ def remove_user_who_share_room(self, user_id, other_user_id):
+ """Deletes entries in the users_who_share_rooms table. The first
+ user should be a local user.
+
+ Args:
+ user_id (str): Must be a local user_id
+ other_user_id (str)
+ """
+ def _remove_user_who_share_room_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ },
+ )
+ txn.call_after(
+ self.get_users_who_share_room_from_dir.invalidate,
+ (user_id,),
+ )
+ txn.call_after(
+ self.get_if_users_share_a_room.invalidate,
+ (user_id, other_user_id),
+ )
+
+ return self.runInteraction(
+ "remove_user_who_share_room", _remove_user_who_share_room_txn
+ )
+
+ @cached(max_entries=500000)
+ def get_if_users_share_a_room(self, user_id, other_user_id):
+ """Gets if users share a room.
+
+ Args:
+ user_id (str): Must be a local user_id
+ other_user_id (str)
+
+ Returns:
+ bool|None: None if they don't share a room, otherwise whether they
+ share a private room or not.
+ """
+ return self._simple_select_one_onecol(
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ },
+ retcol="share_private",
+ allow_none=True,
+ desc="get_if_users_share_a_room",
+ )
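The tri-state return value is easy to misread; a small illustrative helper (not part of the store) spelling out the three cases:

    def describe_share(share_private):
        if share_private is None:
            return "the users do not share any room"
        if share_private:
            return "the users share a private room"
        return "the users share a room, but not a private one"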
+
+ @cachedInlineCallbacks(max_entries=500000, iterable=True)
+ def get_users_who_share_room_from_dir(self, user_id):
+ """Returns the set of users who share a room with `user_id`
+
+ Args:
+ user_id(str): Must be a local user
+
+ Returns:
+ dict: user_id -> share_private mapping
+ """
+ rows = yield self._simple_select_list(
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("other_user_id", "share_private",),
+ desc="get_users_who_share_room_with_user",
+ )
+
+ defer.returnValue({
+ row["other_user_id"]: row["share_private"]
+ for row in rows
+ })
+
+ def get_users_in_share_dir_with_room_id(self, user_id, room_id):
+ """Get all user tuples that are in the users_who_share_rooms due to the
+ given room_id.
+
+ Returns:
+ [(user_id, other_user_id)]: where one of the two will match the given
+ user_id.
+ """
+ sql = """
+ SELECT user_id, other_user_id FROM users_who_share_rooms
+ WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
+ """
+ return self._execute(
+ "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_in_common_for_users(self, user_id, other_user_id):
+ """Given two user_ids find out the list of rooms they share.
+ """
+ sql = """
+ SELECT room_id FROM (
+ SELECT c.room_id FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key = ?
+ ) AS f1 INNER JOIN (
+ SELECT c.room_id FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key = ?
+ ) f2 USING (room_id)
+ """
+
+ rows = yield self._execute(
+ "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
+ )
+
+ defer.returnValue([room_id for room_id, in rows])
+
+ def delete_all_from_user_dir(self):
+ """Delete the entire user directory
+ """
+ def _delete_all_from_user_dir_txn(txn):
+ txn.execute("DELETE FROM user_directory")
+ txn.execute("DELETE FROM user_directory_search")
+ txn.execute("DELETE FROM users_in_pubic_room")
+ txn.execute("DELETE FROM users_who_share_rooms")
+ txn.call_after(self.get_user_in_directory.invalidate_all)
+ txn.call_after(self.get_user_in_public_room.invalidate_all)
+ txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
+ txn.call_after(self.get_if_users_share_a_room.invalidate_all)
+ return self.runInteraction(
+ "delete_all_from_user_dir", _delete_all_from_user_dir_txn
+ )
+
+ @cached()
+ def get_user_in_directory(self, user_id):
+ return self._simple_select_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ retcols=("room_id", "display_name", "avatar_url",),
+ allow_none=True,
+ desc="get_user_in_directory",
+ )
+
+ @cached()
+ def get_user_in_public_room(self, user_id):
+ return self._simple_select_one(
+ table="users_in_pubic_room",
+ keyvalues={"user_id": user_id},
+ retcols=("room_id",),
+ allow_none=True,
+ desc="get_user_in_public_room",
+ )
+
+ def get_user_directory_stream_pos(self):
+ return self._simple_select_one_onecol(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ retcol="stream_id",
+ desc="get_user_directory_stream_pos",
+ )
+
+ def update_user_directory_stream_pos(self, stream_id):
+ return self._simple_update_one(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_user_directory_stream_pos",
+ )
+
+ def get_current_state_deltas(self, prev_stream_id):
+ prev_stream_id = int(prev_stream_id)
+ if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+ return []
+
+ def get_current_state_deltas_txn(txn):
+ # First we calculate the max stream id that will give us less than
+ # N results.
+ # We arbitrarily limit to 100 stream_id entries to ensure we don't
+ # select too many.
+ sql = """
+ SELECT stream_id, count(*)
+ FROM current_state_delta_stream
+ WHERE stream_id > ?
+ GROUP BY stream_id
+ ORDER BY stream_id ASC
+ LIMIT 100
+ """
+ txn.execute(sql, (prev_stream_id,))
+
+ total = 0
+ max_stream_id = prev_stream_id
+ for max_stream_id, count in txn:
+ total += count
+ if total > 100:
+ # We arbitrarily limit to 100 entries to ensure we don't
+ # select too many.
+ break
+
+ # Now actually get the deltas
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ """
+ txn.execute(sql, (prev_stream_id, max_stream_id,))
+ return self.cursor_to_dict(txn)
+
+ return self.runInteraction(
+ "get_current_state_deltas", get_current_state_deltas_txn
+ )
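The first query above only decides how far to advance the cursor; a standalone sketch of that capping loop, showing that the deltas for a single stream_id are never split across batches (the sample numbers are made up):

    def cap_stream_id(stream_id_counts, prev_stream_id, max_rows=100):
        # stream_id_counts: (stream_id, count) pairs in ascending stream_id order.
        total = 0
        max_stream_id = prev_stream_id
        for max_stream_id, count in stream_id_counts:
            total += count
            if total > max_rows:
                break
        return max_stream_id

    # 60 deltas at stream 7 and 60 at stream 8: the batch runs up to and including
    # stream 8, because the limit is only checked after a whole stream_id is counted.
    assert cap_stream_id([(7, 60), (8, 60)], prev_stream_id=6) == 8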
+
+ def get_max_stream_id_in_current_state_deltas(self):
+ return self._simple_select_one_onecol(
+ table="current_state_delta_stream",
+ keyvalues={},
+ retcol="COALESCE(MAX(stream_id), -1)",
+ desc="get_max_stream_id_in_current_state_deltas",
+ )
+
+ @defer.inlineCallbacks
+ def search_user_dir(self, user_id, search_term, limit):
+ """Searches for users in directory
+
+ Returns:
+ dict of the form::
+
+ {
+ "limited": <bool>, # whether there were more results or not
+ "results": [ # Ordered by best match first
+ {
+ "user_id": <user_id>,
+ "display_name": <display_name>,
+ "avatar_url": <avatar_url>
+ }
+ ]
+ }
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
+
+ # We order by rank and then by whether they have profile info.
+ # The ranking algorithm is hand tweaked for "best" results. Broadly
+ # the idea is that we give a higher weight to exact matches.
+ # The array of numbers gives the weights for the various parts of the
+ # search: (domain, _, display name, localpart)
+ sql = """
+ SELECT d.user_id, display_name, avatar_url
+ FROM user_directory_search
+ INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users_in_pubic_room AS p USING (user_id)
+ LEFT JOIN (
+ SELECT other_user_id AS user_id FROM users_who_share_rooms
+ WHERE user_id = ? AND share_private
+ ) AS s USING (user_id)
+ WHERE
+ (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+ AND vector @@ to_tsquery('english', ?)
+ ORDER BY
+ (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
+ * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
+ * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
+ * (
+ 3 * ts_rank_cd(
+ '{0.1, 0.1, 0.9, 1.0}',
+ vector,
+ to_tsquery('english', ?),
+ 8
+ )
+ + ts_rank_cd(
+ '{0.1, 0.1, 0.9, 1.0}',
+ vector,
+ to_tsquery('english', ?),
+ 8
+ )
+ )
+ DESC,
+ display_name IS NULL,
+ avatar_url IS NULL
+ LIMIT ?
+ """
+ args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ search_query = _parse_query_sqlite(search_term)
+
+ sql = """
+ SELECT d.user_id, display_name, avatar_url
+ FROM user_directory_search
+ INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users_in_pubic_room AS p USING (user_id)
+ LEFT JOIN (
+ SELECT other_user_id AS user_id FROM users_who_share_rooms
+ WHERE user_id = ? AND share_private
+ ) AS s USING (user_id)
+ WHERE
+ (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+ AND value MATCH ?
+ ORDER BY
+ rank(matchinfo(user_directory_search)) DESC,
+ display_name IS NULL,
+ avatar_url IS NULL
+ LIMIT ?
+ """
+ args = (user_id, search_query, limit + 1)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+ results = yield self._execute(
+ "search_user_dir", self.cursor_to_dict, sql, *args
+ )
+
+ limited = len(results) > limit
+
+ defer.returnValue({
+ "limited": limited,
+ "results": results,
+ })
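A hypothetical caller (handler-level code, not part of this change), assuming a configured store and the defer/logger imports at the top of this module, just to make the return shape concrete:

    @defer.inlineCallbacks
    def log_search(store, requester_user_id, term):
        result = yield store.search_user_dir(requester_user_id, term, limit=10)
        # `limited` is True when an extra (limit + 1)th row came back, i.e. more matches exist.
        for entry in result["results"]:
            logger.info("match: %s (%s)", entry["user_id"], entry["display_name"])
        defer.returnValue(result["limited"])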
+
+
+def _parse_query_sqlite(search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to the database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+
+ We specifically add both a prefix and a non-prefix matching term so that
+ exact matches get ranked higher.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+ return " & ".join("(%s* | %s)" % (result, result,) for result in results)
+
+
+def _parse_query_postgres(search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to the database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+ both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
+ exact = " & ".join("%s" % (result,) for result in results)
+ prefix = " & ".join("%s:*" % (result,) for result in results)
+
+ return both, exact, prefix
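For the same sample term, tracing the code above gives the three query forms (shown as literals; not a separate implementation):

    # _parse_query_postgres("Alice Smith") returns:
    both = "(Alice:* | Alice) & (Smith:* | Smith)"   # used in the WHERE-clause @@ match
    exact = "Alice & Smith"                          # ranked with weight 3 in the ORDER BY
    prefix = "Alice:* & Smith:*"                     # ranked with weight 1 in the ORDER BY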