diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 4406e58273..0ac854aee2 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -87,12 +87,21 @@ class Databases(object):
logger.info("Database %r prepared", db_name)
+ # Closing the context manager doesn't close the connection.
+    # psycopg2 will close the connection when the object gets GCed, but *only*
+ # if the PID is the same as when the connection was opened [1], and
+ # it may not be if we fork in the meantime.
+ #
+ # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378
+
+ db_conn.close()
+
# Sanity check that we have actually configured all the required stores.
if not main:
raise Exception("No 'main' data store configured")
if not state:
- raise Exception("No 'main' data store configured")
+ raise Exception("No 'state' data store configured")
# We use local variables here to ensure that the databases do not have
# optional types.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 17fa470919..0934ae276c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -498,7 +498,7 @@ class DataStore(
)
def get_users_paginate(
- self, start, limit, name=None, guests=True, deactivated=False
+ self, start, limit, user_id=None, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
@@ -507,7 +507,8 @@ class DataStore(
Args:
start (int): start number to begin the query from
limit (int): number of rows to retrieve
- name (string): filter for user names
+            user_id (string): search for user_id. Ignored if name is set
+ name (string): search for local part of user_id or display name
             guests (bool): whether to include guest users
deactivated (bool): whether to include deactivated users
Returns:
@@ -516,11 +517,14 @@ class DataStore(
def get_users_paginate_txn(txn):
filters = []
- args = []
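+            # the server_name is bound to the placeholder in the profiles
+            # JOIN in sql_base below, so it must precede any filter arguments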
+ args = [self.hs.config.server_name]
if name:
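+                # match the search term against either the full user id or
+                # the profile displayname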
+ filters.append("(name LIKE ? OR displayname LIKE ?)")
+ args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ elif user_id:
filters.append("name LIKE ?")
- args.append("%" + name + "%")
+ args.extend(["%" + user_id + "%"])
if not guests:
filters.append("is_guest = 0")
@@ -530,20 +534,23 @@ class DataStore(
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
- sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
- txn.execute(sql, args)
- count = txn.fetchone()[0]
-
- args = [self.hs.config.server_name] + args + [limit, start]
- sql = """
- SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+ sql_base = """
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
- ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
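+            # run the count query first, sharing the FROM/JOIN/WHERE clauses
+            # in sql_base with the paginated query below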
+ sql = "SELECT COUNT(*) as total_users " + sql_base
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = (
+ "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ + sql_base
+ + " ORDER BY u.name LIMIT ? OFFSET ?"
+ )
+ args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
return users, count
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..04042a2c98 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
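+        # get_next() is now a coroutine: awaiting it returns the context
+        # manager that supplies the stream id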
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5cf1a88399..77723f7d4d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -16,13 +16,12 @@
import logging
import re
-from canonicaljson import json
-
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@@ -169,7 +168,7 @@ class ApplicationServiceTransactionWorkerStore(
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
- A Deferred which resolves when the state was set successfully.
+ An Awaitable which resolves when the state was set successfully.
"""
return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
@@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore(
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
- event_ids = json.dumps([e.event_id for e in events])
+ event_ids = json_encoder.encode([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
- def get_cache_stream_token(self, instance_name):
+ def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
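+        """Get the current position of the given writer in the cache stream."""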
if self._cache_id_gen:
- return self._cache_id_gen.get_current_token(instance_name)
+ return self._cache_id_gen.get_current_token_for_writer(instance_name)
else:
return 0
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..bb85637a95 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2b33060480..03b45dbc4d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
             The new stream ID.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ with await self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
- inlineCallbacks=True,
)
- def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+ rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -1147,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
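+        # allocate one stream id per device being updated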
+ with await self._device_list_id_gen.get_next_mult(
+ len(device_ids)
+ ) as stream_ids:
await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
@@ -1160,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with self._device_list_id_gen.get_next_mult(
+ with await self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..385868bdab 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
"""Set a user's cross-signing key.
Args:
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
key (dict): the key data
+            stream_id (int): the stream id to use when storing the key
"""
# the 'key' dict will look something like:
# {
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
# and finally, store the key itself
- with self._cross_signing_id_gen.get_next() as stream_id:
- self.db_pool.simple_insert_txn(
- txn,
- "e2e_cross_signing_keys",
- values={
- "user_id": user_id,
- "keytype": key_type,
- "keydata": json_encoder.encode(key),
- "stream_id": stream_id,
- },
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ "e2e_cross_signing_keys",
+ values={
+ "user_id": user_id,
+ "keytype": key_type,
+ "keydata": json_encoder.encode(key),
+ "stream_id": stream_id,
+ },
+ )
self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
- def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ async def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
Args:
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
- return self.db_pool.runInteraction(
- "add_e2e_cross_signing_key",
- self._set_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- key,
- )
+
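+        # the stream id is allocated here rather than inside the txn
+        # function, since the (now async) id generator cannot be awaited
+        # within a database transaction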
+ with await self._cross_signing_id_gen.get_next() as stream_id:
+ return await self.db_pool.runInteraction(
+ "add_e2e_cross_signing_key",
+ self._set_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ key,
+ stream_id,
+ )
def store_e2e_cross_signing_signatures(self, user_id, signatures):
"""Stores cross-signing signatures.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 484875f989..e6a97b018c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Set, Tuple
from synapse.api.errors import StoreError
+from synapse.events import EventBase
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
@@ -30,57 +32,51 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- def get_auth_chain(self, event_ids, include_given=False):
+ async def get_auth_chain(
+ self, event_ids: Collection[str], include_given: bool = False
+ ) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
Returns:
list of events
"""
- return self.get_auth_chain_ids(
+ event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
- ).addCallback(self.get_events_as_list)
-
- def get_auth_chain_ids(
- self,
- event_ids: List[str],
- include_given: bool = False,
- ignore_events: Optional[Set[str]] = None,
- ):
+ )
+ return await self.get_events_as_list(event_ids)
+
+ async def get_auth_chain_ids(
+ self, event_ids: Collection[str], include_given: bool = False,
+ ) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids: state events
include_given: include the given events in result
- ignore_events: Set of events to exclude from the returned auth
- chain. This is useful if the caller will just discard the
- given events anyway, and saves us from figuring out their auth
- chains if not required.
Returns:
list of event_ids
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids,
include_given,
- ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
- if ignore_events is None:
- ignore_events = set()
-
+ def _get_auth_chain_ids_txn(
+ self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
if include_given:
results = set(event_ids)
else:
results = set()
- base_sql = "SELECT auth_id FROM event_auth WHERE "
+ base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
@@ -92,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(base_sql + clause, args)
new_front.update(r[0] for r in txn)
- new_front -= ignore_events
new_front -= results
front = new_front
@@ -257,11 +252,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
- def get_oldest_events_in_room(self, room_id):
- return self.db_pool.runInteraction(
- "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
- )
-
def get_oldest_events_with_depth_in_room(self, room_id):
return self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
@@ -303,14 +293,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
- def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self.db_pool.simple_select_onecol_txn(
- txn,
- table="event_backward_extremities",
- keyvalues={"room_id": room_id},
- retcol="event_id",
- )
-
def get_prev_events_for_room(self, room_id: str):
"""
Gets a subset of the current forward extremities in the given room.
@@ -472,7 +454,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
- def get_backfill_events(self, room_id, event_list, limit):
+ async def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
@@ -482,17 +464,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_list (list)
limit (int)
"""
- return (
- self.db_pool.runInteraction(
- "get_backfill_events",
- self._get_backfill_events,
- room_id,
- event_list,
- limit,
- )
- .addCallback(self.get_events_as_list)
- .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+ event_ids = await self.db_pool.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events,
+ room_id,
+ event_list,
+ limit,
)
+ events = await self.get_events_as_list(event_ids)
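+        # return the events deepest-first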
+ return sorted(events, key=lambda e: -e.depth)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -553,8 +533,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
- events = await self.get_events_as_list(ids)
- return events
+ return await self.get_events_as_list(ids)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7c246d3e4c..e8834b2162 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3
self._rotate_count = 10000
- @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
- def get_unread_event_push_actions_by_room_for_user(
+ @cached(num_args=3, tree=True, max_entries=5000)
+ async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
- ret = yield self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
- return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1a68bf32cb..6313b41eef 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,13 +17,11 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
import attr
from prometheus_client import Counter
-from twisted.internet import defer
-
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
@@ -113,15 +111,14 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
- @defer.inlineCallbacks
- def _persist_events_and_state_updates(
+ async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
- ):
+ ) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -136,7 +133,7 @@ class PersistEventsStore:
backfilled
Returns:
- Deferred: resolves when the events have been persisted
+ Resolves when the events have been persisted
"""
# We want to calculate the stream orderings as late as possible, as
@@ -156,11 +153,11 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = await self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
@@ -168,7 +165,7 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
@@ -206,16 +203,15 @@ class PersistEventsStore:
(room_id,), list(latest_event_ids)
)
- @defer.inlineCallbacks
- def _get_events_which_are_prevs(self, event_ids):
+ async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
Args:
- event_ids (Iterable[str]): event ids to filter
+ event_ids: event ids to filter
Returns:
- Deferred[List[str]]: filtered event ids
+ Filtered event ids
"""
results = []
@@ -240,14 +236,13 @@ class PersistEventsStore:
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
return results
- @defer.inlineCallbacks
- def _get_prevs_before_rejected(self, event_ids):
+ async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
"""Get soft-failed ancestors to remove from the extremities.
Given a set of events, find all those that have been soft-failed or
@@ -259,11 +254,11 @@ class PersistEventsStore:
are separated by soft failed events.
Args:
- event_ids (Iterable[str]): Events to find prev events for. Note
- that these must have already been persisted.
+ event_ids: Events to find prev events for. Note that these must have
+ already been persisted.
Returns:
- Deferred[set[str]]
+ The previous events.
"""
# The set of event_ids to return. This includes all soft-failed events
@@ -304,7 +299,7 @@ class PersistEventsStore:
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 35a0e09e3c..e53c6373a8 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
@@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT have_censored",
)
- @defer.inlineCallbacks
- def _background_reindex_fields_sender(self, progress, batch_size):
+ async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
- @defer.inlineCallbacks
- def _background_reindex_origin_server_ts(self, progress, batch_size):
+ async def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows_to_update)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
- @defer.inlineCallbacks
- def _cleanup_extremities_bg_update(self, progress, batch_size):
+ async def _cleanup_extremities_bg_update(self, progress, batch_size):
"""Background update to clean out extremities that should have been
deleted previously.
@@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set)
- num_handled = yield self.db_pool.runInteraction(
+ num_handled = await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
return num_handled
- @defer.inlineCallbacks
- def _redactions_received_ts(self, progress, batch_size):
+ async def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
last_event_id = progress.get("last_event_id", "")
@@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows)
- count = yield self.db_pool.runInteraction(
+ count = await self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
- yield self.db_pool.updates._end_background_update("redactions_received_ts")
+ await self.db_pool.updates._end_background_update("redactions_received_ts")
return count
- @defer.inlineCallbacks
- def _event_fix_redactions_bytes(self, progress, batch_size):
+ async def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
@@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts")
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
- yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
+ await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1
- @defer.inlineCallbacks
- def _event_store_labels(self, progress, batch_size):
+ async def _event_store_labels(self, progress, batch_size):
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
@@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return nbrows
- num_rows = yield self.db_pool.runInteraction(
+ num_rows = await self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
- yield self.db_pool.updates._end_background_update("event_store_labels")
+ await self.db_pool.updates._end_background_update("event_store_labels")
return num_rows
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 755b7a2a85..e1241a724b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
from constantly import NamedConstant, Names
+from typing_extensions import Literal
from twisted.internet import defer
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersions,
)
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -137,44 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
- def get_received_ts_by_stream_pos(self, stream_ordering):
- """Given a stream ordering get an approximate timestamp of when it
- happened.
-
- This is done by simply taking the received ts of the first event that
- has a stream ordering greater than or equal to the given stream pos.
- If none exists returns the current time, on the assumption that it must
- have happened recently.
-
- Args:
- stream_ordering (int)
-
- Returns:
- Deferred[int]
- """
-
- def _get_approximate_received_ts_txn(txn):
- sql = """
- SELECT received_ts FROM events
- WHERE stream_ordering >= ?
- LIMIT 1
- """
-
- txn.execute(sql, (stream_ordering,))
- row = txn.fetchone()
- if row and row[0]:
- ts = row[0]
- else:
- ts = self.clock.time_msec()
-
- return ts
+ # Inform mypy that if allow_none is False (the default) then get_event
+ # always returns an EventBase.
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[False] = False,
+ check_room_id: Optional[str] = None,
+ ) -> EventBase:
+ ...
- return self.db_pool.runInteraction(
- "get_approximate_received_ts", _get_approximate_received_ts_txn
- )
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[True] = False,
+ check_room_id: Optional[str] = None,
+ ) -> Optional[EventBase]:
+ ...
- @defer.inlineCallbacks
- def get_event(
+ async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -182,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
- ):
+ ) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
@@ -207,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- Deferred[EventBase|None]
+ The event, or None if the event was not found.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -230,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
return event
- @defer.inlineCallbacks
- def get_events(
+ async def get_events(
self,
- event_ids: List[str],
+ event_ids: Iterable[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> Dict[str, EventBase]:
"""Get events from the database
Args:
@@ -256,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
             omits rejected events from the response.
Returns:
- Deferred : Dict from event_id to event.
+ A mapping from event_id to event.
"""
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -267,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
- @defer.inlineCallbacks
- def get_events_as_list(
+ async def get_events_as_list(
self,
- event_ids: List[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> List[EventBase]:
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
@@ -295,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred[list[EventBase]]: List of events fetched from the database. The
- events are in the same order as `event_ids` arg.
+ List of events fetched from the database. The events are in the same
+ order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
@@ -306,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = yield self._get_events_from_cache_or_db(
+ event_entry_map = await self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -341,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -407,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
if get_prev_content:
if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
+ prev = await self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
@@ -419,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
return events
- @defer.inlineCallbacks
- def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -435,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
@@ -453,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
- missing_events = yield self._get_events_from_db(
+ missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@@ -561,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
- @defer.inlineCallbacks
- def _get_events_from_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
@@ -576,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
"""
@@ -584,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
events_to_fetch = event_ids
while events_to_fetch:
- row_map = yield self._enqueue_events(events_to_fetch)
+ row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
@@ -610,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
- d = db_to_json(row["json"])
- internal_metadata = db_to_json(row["internal_metadata"])
+ # If the event or metadata cannot be parsed, log the error and act
+ # as if the event is unknown.
+ try:
+ d = db_to_json(row["json"])
+ except ValueError:
+ logger.error("Unable to parse json from event: %s", event_id)
+ continue
+ try:
+ internal_metadata = db_to_json(row["internal_metadata"])
+ except ValueError:
+ logger.error(
+ "Unable to parse internal_metadata from event: %s", event_id
+ )
+ continue
format_version = row["format_version"]
if format_version is None:
@@ -622,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row["room_version_id"]
if not room_version_id:
- # this should only happen for out-of-band membership events
- if not internal_metadata.get("out_of_band_membership"):
- logger.warning(
- "Room %s for event %s is unknown", d["room_id"], event_id
+ # this should only happen for out-of-band membership events which
+ # arrived before #6983 landed. For all other events, we should have
+ # an entry in the 'rooms' table.
+ #
+ # However, the 'out_of_band_membership' flag is unreliable for older
+ # invites, so just accept it for all membership events.
+ #
+ if d["type"] != EventTypes.Member:
+ raise Exception(
+ "Room %s for event %s is unknown" % (d["room_id"], event_id)
)
- continue
- # take a wild stab at the room version based on the event format
+ # so, assuming this is an out-of-band-invite that arrived before #6983
+ # landed, we know that the room version must be v5 or earlier (because
+ # v6 hadn't been invented at that point, so invites from such rooms
+ # would have been rejected.)
+ #
+ # The main reason we need to know the room version here (other than
+ # choosing the right python Event class) is in case the event later has
+ # to be redacted - and all the room versions up to v5 used the same
+ # redaction algorithm.
+ #
+ # So, the following approximations should be adequate.
+
if format_version == EventFormatVersions.V1:
+ # if it's event format v1 then it must be room v1 or v2
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
+ # if it's event format v2 then it must be room v3
room_version = RoomVersions.V3
else:
+ # if it's event format v3 then it must be room v4 or v5
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
@@ -686,8 +703,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- @defer.inlineCallbacks
- def _enqueue_events(self, events):
+ async def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -696,7 +712,7 @@ class EventsWorkerStore(SQLBaseStore):
events (Iterable[str]): events to be fetched.
Returns:
- Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+ Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
"""
@@ -719,7 +735,7 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
- row_map = yield events_d
+ row_map = await events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
@@ -878,12 +894,11 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- @defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
@@ -894,15 +909,14 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
+ async def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
- Deferred[set[str]]: The events we have already seen.
+ set[str]: The events we have already seen.
"""
results = set()
@@ -918,41 +932,11 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
- def _get_total_state_event_counts_txn(self, txn, room_id):
- """
- See get_total_state_event_counts.
- """
- # We join against the events table as that has an index on room_id
- sql = """
- SELECT COUNT(*) FROM state_events
- INNER JOIN events USING (room_id, event_id)
- WHERE room_id=?
- """
- txn.execute(sql, (room_id,))
- row = txn.fetchone()
- return row[0] if row else 0
-
- def get_total_state_event_counts(self, room_id):
- """
- Gets the total number of state events in a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[int]
- """
- return self.db_pool.runInteraction(
- "get_total_state_event_counts",
- self._get_total_state_event_counts_txn,
- room_id,
- )
-
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
@@ -978,8 +962,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- @defer.inlineCallbacks
- def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -990,9 +973,9 @@ class EventsWorkerStore(SQLBaseStore):
room_id (str)
Returns:
- Deferred[dict[str:int]] of complexity version to complexity.
+            dict[str, int] of complexity version to complexity.
"""
- state_events = yield self.get_current_state_event_counts(room_id)
+ state_events = await self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
@@ -1222,97 +1205,6 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- @cached(num_args=5, max_entries=10)
- def get_all_new_events(
- self,
- last_backfill_id,
- last_forward_id,
- current_backfill_id,
- current_forward_id,
- limit,
- ):
- """Get all the new events that have arrived at the server either as
- new events or as backfilled events"""
- have_backfill_events = last_backfill_id != current_backfill_id
- have_forward_events = last_forward_id != current_forward_id
-
- if not have_backfill_events and not have_forward_events:
- return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
- def get_all_new_events_txn(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- if have_forward_events:
- txn.execute(sql, (last_forward_id, current_forward_id, limit))
- new_forward_events = txn.fetchall()
-
- if len(new_forward_events) == limit:
- upper_bound = new_forward_events[-1][0]
- else:
- upper_bound = current_forward_id
-
- sql = (
- "SELECT event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_forward_id, upper_bound))
- forward_ex_outliers = txn.fetchall()
- else:
- new_forward_events = []
- forward_ex_outliers = []
-
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
- if have_backfill_events:
- txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
- new_backfill_events = txn.fetchall()
-
- if len(new_backfill_events) == limit:
- upper_bound = new_backfill_events[-1][0]
- else:
- upper_bound = current_backfill_id
-
- sql = (
- "SELECT -event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_backfill_id, -upper_bound))
- backward_ex_outliers = txn.fetchall()
- else:
- new_backfill_events = []
- backward_ex_outliers = []
-
- return AllNewEventsResult(
- new_forward_events,
- new_backfill_events,
- forward_ex_outliers,
- backward_ex_outliers,
- )
-
- return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
-
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
@@ -1320,9 +1212,9 @@ class EventsWorkerStore(SQLBaseStore):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
- @cachedInlineCallbacks(max_entries=5000)
- def get_event_ordering(self, event_id):
- res = yield self.db_pool.simple_select_one(
+ @cached(max_entries=5000)
+ async def get_event_ordering(self, event_id):
+ res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -1357,14 +1249,3 @@ class EventsWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
-
-
-AllNewEventsResult = namedtuple(
- "AllNewEventsResult",
- [
- "new_forward_events",
- "new_backfill_events",
- "forward_ex_outliers",
- "backward_ex_outliers",
- ],
-)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 380db3a3f3..a488e0924b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -341,14 +341,15 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
- def is_user_in_group(self, user_id, group_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
+ result = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
allow_none=True,
desc="is_user_in_group",
- ).addCallback(lambda r: bool(r))
+ )
+ return bool(result)
def is_user_admin_in_group(self, group_id, user_id):
return self.db_pool.simple_select_one_onecol(
@@ -1181,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with self._group_updates_id_gen.get_next() as next_id:
+ with await self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..fadcad51e7 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
import itertools
import logging
+from typing import Iterable, Tuple
from signedjson.key import decode_verify_key_bytes
@@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore):
return self.db_pool.runInteraction("get_server_verify_keys", _txn)
- def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ async def store_server_verify_keys(
+ self,
+ from_server: str,
+ ts_added_ms: int,
+ verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+ ) -> None:
"""Stores NACL verification keys for remote servers.
Args:
- from_server (str): Where the verification keys were looked up
- ts_added_ms (int): The time to record that the key was added
- verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ from_server: Where the verification keys were looked up
+ ts_added_ms: The time to record that the key was added
+ verify_keys:
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
@@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
- def _invalidate(res):
- f = self._get_server_verify_key.invalidate
- for i in invalidations:
- f((i,))
- return res
-
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"store_server_verify_keys",
self.db_pool.simple_upsert_many_txn,
table="server_signature_keys",
@@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
- ).addCallback(_invalidate)
+ )
+
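+        # now that the upsert has committed, invalidate the affected cache
+        # entries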
+ invalidate = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ invalidate((i,))
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 59ba12820a..c9f655dfb7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -15,15 +15,15 @@
from typing import List, Tuple
+from synapse.api.presence import UserPresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = await self._presence_id_gen.get_next_mult(
len(presence_states)
)
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_presence_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
)
- def get_presence_for_users(self, user_ids):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_presence_for_users(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
-
- def allow_presence_visible(self, observed_localpart, observer_userid):
- return self.db_pool.simple_insert(
- table="presence_allow_inbound",
- values={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="allow_presence_visible",
- or_ignore=True,
- )
-
- def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self.db_pool.simple_delete_one(
- table="presence_allow_inbound",
- keyvalues={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="disallow_presence_visible",
- )
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..086cfbeed4 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,9 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main.roommember import ProfileInfo
+from synapse.types import UserID
+from synapse.util.caches.descriptors import cached
+
+BATCH_SIZE = 100
class ProfileWorkerStore(SQLBaseStore):
@@ -38,6 +45,7 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
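+    # cached: invalidated by set_profile_displayname below when the profile
+    # changes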
+ @cached(max_entries=5000)
def get_profile_displayname(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -46,6 +54,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_displayname",
)
+ @cached(max_entries=5000)
def get_profile_avatar_url(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -54,6 +63,56 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return self.db_pool.runInteraction(
+ "get_latest_profile_replication_batch_number", f
+ )
+
+ def get_profile_batch(self, batchnum):
+ return self.db_pool.simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ def assign_profile_batch(self):
+ def f(txn):
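+            # assign the next batch number (one more than the current max)
+            # to up to BATCH_SIZE profiles which have not yet been batched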
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return self.db_pool.runInteraction("assign_profile_batch", f)
+
+ def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return self.db_pool.runInteraction("get_replication_hosts", f)
+
+ def update_replication_batch_for_host(self, host, last_synced_batch):
+ return self.db_pool.simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
def get_from_remote_profile_cache(self, user_id):
return self.db_pool.simple_select_one(
table="remote_profile_cache",
@@ -68,24 +127,83 @@ class ProfileWorkerStore(SQLBaseStore):
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self.db_pool.simple_update_one(
+ def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
+ # Invalidate the read cache for this user
+ self.get_profile_displayname.invalidate((user_localpart,))
+
+ return self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.db_pool.simple_update_one(
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
+ # Invalidate the read cache for this user
+ self.get_profile_avatar_url.invalidate((user_localpart,))
+
+ return self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ def set_profiles_active(
+ self, users: List[UserID], active: bool, hide: bool, batchnum: int,
+ ):
+ """Given a set of users, set active and hidden flags on them.
+
+ Args:
+ users: A list of UserIDs
+ active: Whether to set the users to active or inactive
+            hide: Whether to hide the users (withhold from replication). If
+ False and active is False, users will have their profiles
+ erased
+ batchnum: The batch number, used for profile replication
+
+ Returns:
+ Deferred
+ """
+ # Convert list of localparts to list of tuples containing localparts
+ user_localparts = [(user.localpart,) for user in users]
+
+ # Generate list of value tuples for each user
+ value_names = ("active", "batch")
+ values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple]
+
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile information
+ value_names += ("avatar_url", "displayname")
+ values = [v + (None, None) for v in values]
+
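+        # Upsert rather than update, so that users who have never set a
+        # displayname or avatar still get a profile row with the new flags.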
+ return self.db_pool.runInteraction(
+ "set_profiles_active",
+ self.db_pool.simple_upsert_many_txn,
+ table="profiles",
+ key_names=("user_id",),
+ key_values=user_localparts,
+ value_names=value_names,
+ value_values=values,
)
class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
+
+ super(ProfileStore, self).__init__(database, db_conn, hs)
+
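+        # Build the unique index on profile_replication_status(host) as a
+        # background update (the matching delta is
+        # 55/profile_replication_status_index.sql).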
+ self.db_pool.updates.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
@@ -104,7 +222,7 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.db_pool.simple_update(
+ return self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
-            updatevalues={
+            values={
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 6562db5c2b..2fb5b02d7d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -30,9 +30,9 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- ) # type: Union[ChainedIdGenerator, SlavedIdTracker]
+ self._push_rules_stream_id_gen = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ ) # type: Union[StreamIdGenerator, SlavedIdTracker]
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
@@ -115,9 +115,9 @@ class PushRulesWorkerStore(
"""
raise NotImplementedError()
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_for_user(self, user_id):
- rows = yield self.db_pool.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_for_user(self, user_id):
+ rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -133,17 +133,15 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
- enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+ enabled_map = await self.get_push_rules_enabled_for_user(user_id)
use_new_defaults = user_id in self._users_new_default_push_rules
- rules = _load_rules(rows, enabled_map, use_new_defaults)
-
- return rules
+ return _load_rules(rows, enabled_map, use_new_defaults)
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_enabled_for_user(self, user_id):
- results = yield self.db_pool.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_enabled_for_user(self, user_id):
+ results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -170,18 +168,15 @@ class PushRulesWorkerStore(
)
@cachedList(
- cached_method_name="get_push_rules_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
)
- def bulk_get_push_rules(self, user_ids):
+ async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -194,7 +189,7 @@ class PushRulesWorkerStore(
for row in rows:
results.setdefault(row["user_name"], []).append(row)
- enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+ enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules
@@ -205,14 +200,15 @@ class PushRulesWorkerStore(
return results
- @defer.inlineCallbacks
- def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+ async def copy_push_rule_from_room_to_room(
+ self, new_room_id: str, user_id: str, rule: dict
+ ) -> None:
"""Copy a single push rule from one room to another for a specific user.
Args:
- new_room_id (str): ID of the new room.
- user_id (str): ID of user the push rule belongs to.
- rule (Dict): A push rule.
+ new_room_id: ID of the new room.
+            user_id: ID of the user the push rule belongs to.
+ rule: A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
@@ -224,7 +220,7 @@ class PushRulesWorkerStore(
condition["pattern"] = new_room_id
# Add the rule for the new room
- yield self.add_push_rule(
+ await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
@@ -232,20 +228,19 @@ class PushRulesWorkerStore(
actions=rule["actions"],
)
- @defer.inlineCallbacks
- def copy_push_rules_from_room_to_room_for_user(
- self, old_room_id, new_room_id, user_id
- ):
+ async def copy_push_rules_from_room_to_room_for_user(
+ self, old_room_id: str, new_room_id: str, user_id: str
+ ) -> None:
"""Copy all of the push rules from one room to another for a specific
user.
Args:
- old_room_id (str): ID of the old room.
- new_room_id (str): ID of the new room.
- user_id (str): ID of user to copy push rules for.
+ old_room_id: ID of the old room.
+ new_room_id: ID of the new room.
+ user_id: ID of user to copy push rules for.
"""
# Retrieve push rules for this user
- user_push_rules = yield self.get_push_rules_for_user(user_id)
+ user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
@@ -254,21 +249,20 @@ class PushRulesWorkerStore(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
- yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+ await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
- inlineCallbacks=True,
)
- def bulk_get_push_rules_enabled(self, user_ids):
+ async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -332,8 +326,7 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
- @defer.inlineCallbacks
- def add_push_rule(
+ async def add_push_rule(
self,
user_id,
rule_id,
@@ -342,13 +335,14 @@ class PushRuleStore(PushRulesWorkerStore):
actions,
before=None,
after=None,
- ):
+ ) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
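+            # StreamIdGenerator only hands out the push rules stream id, so
+            # the event stream ordering is read from the event stream's
+            # current token.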
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
if before or after:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@@ -362,7 +356,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
@@ -546,16 +540,15 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
- @defer.inlineCallbacks
- def delete_push_rule(self, user_id, rule_id):
+ async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
- user_id (str): The matrix ID of the push rule owner
- rule_id (str): The rule_id of the rule to be deleted
+ user_id: The matrix ID of the push rule owner
+ rule_id: The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
@@ -567,20 +560,21 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
)
- @defer.inlineCallbacks
- def set_push_rule_enabled(self, user_id, rule_id, enabled):
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@@ -611,8 +605,9 @@ class PushRuleStore(PushRulesWorkerStore):
op="ENABLE" if enabled else "DISABLE",
)
- @defer.inlineCallbacks
- def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+ async def set_push_rule_actions(
+ self, user_id, rule_id, actions, is_default_rule
+ ) -> None:
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -651,9 +646,10 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db_pool.runInteraction(
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
@@ -681,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_push_rules_stream_token(self):
- """Get the position of the push rules stream.
- Returns a pair of a stream id for the push_rules stream and the
- room stream ordering it corresponds to."""
- return self._push_rules_stream_id_gen.get_current_token()
-
def get_max_push_rules_stream_id(self):
- return self.get_push_rules_stream_token()[0]
+ return self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b5200fbe79..c388468273 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
logger = logging.getLogger(__name__)
@@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore):
Drops any rows whose data cannot be decoded
"""
for r in rows:
- dataJson = r["data"]
+ data_json = r["data"]
try:
- r["data"] = db_to_json(dataJson)
+ r["data"] = db_to_json(data_json)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
- dataJson,
+ data_json,
e.args[0],
)
continue
yield r
- @defer.inlineCallbacks
- def user_has_pusher(self, user_id):
- ret = yield self.db_pool.simple_select_one_onecol(
+ async def user_has_pusher(self, user_id):
+ ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({"user_name": user_id})
- @defer.inlineCallbacks
- def get_pushers_by(self, keyvalues):
- ret = yield self.db_pool.simple_select_list(
+ async def get_pushers_by(self, keyvalues):
+ ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
[
@@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- @defer.inlineCallbacks
- def get_all_pushers(self):
+ async def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
- return rows
+ return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -164,19 +158,16 @@ class PusherWorkerStore(SQLBaseStore):
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
- @cachedInlineCallbacks(num_args=1, max_entries=15000)
- def get_if_user_has_pusher(self, user_id):
+ @cached(num_args=1, max_entries=15000)
+ async def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
- cached_method_name="get_if_user_has_pusher",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
- def get_if_users_have_pushers(self, user_ids):
- rows = yield self.db_pool.simple_select_many_batch(
+ async def get_if_users_have_pushers(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -189,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore):
return result
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering(
+ async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
- ):
- yield self.db_pool.simple_update_one(
+ ) -> None:
+ await self.db_pool.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering_and_success(
- self, app_id, pushkey, user_id, last_stream_ordering, last_success
- ):
+ async def update_pusher_last_stream_ordering_and_success(
+ self,
+ app_id: str,
+ pushkey: str,
+ user_id: str,
+ last_stream_ordering: int,
+ last_success: int,
+ ) -> bool:
"""Update the last stream ordering position we've processed up to for
the given pusher.
Args:
- app_id (str)
- pushkey (str)
- last_stream_ordering (int)
- last_success (int)
+ app_id
+ pushkey
+ user_id
+ last_stream_ordering
+ last_success
Returns:
- Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+ True if the pusher still exists; False if it has been deleted.
"""
- updated = yield self.db_pool.simple_update(
+ updated = await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -228,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
- @defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self.db_pool.simple_update(
+ async def update_pusher_failing_since(
+ self, app_id, pushkey, user_id, failing_since
+ ) -> None:
+ await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
desc="update_pusher_failing_since",
)
- @defer.inlineCallbacks
- def get_throttle_params_by_room(self, pusher_id):
- res = yield self.db_pool.simple_select_list(
+ async def get_throttle_params_by_room(self, pusher_id):
+ res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -255,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
- @defer.inlineCallbacks
- def set_throttle_params(self, pusher_id, room_id, params):
+ async def set_throttle_params(self, pusher_id, room_id, params) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
- yield self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
@@ -272,8 +266,7 @@ class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_pusher(
+ async def add_pusher(
self,
user_id,
access_token,
@@ -287,11 +280,11 @@ class PusherStore(PusherWorkerStore):
data,
last_stream_ordering,
profile_tag="",
- ):
- with self._pushers_id_gen.get_next() as stream_id:
+ ) -> None:
+ with await self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
- yield self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -316,15 +309,16 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
            # invalidate, since the user might not have had a pusher before
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
(user_id,),
)
- @defer.inlineCallbacks
- def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
+ async def delete_pusher_by_app_id_pushkey_user_id(
+ self, app_id, pushkey, user_id
+ ) -> None:
def delete_pusher_txn(txn, stream_id):
self._invalidate_cache_and_stream(
txn, self.get_if_user_has_pusher, (user_id,)
@@ -350,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with self._pushers_id_gen.get_next() as stream_id:
- yield self.db_pool.runInteraction(
+ with await self._pushers_id_gen.get_next() as stream_id:
+ await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1920a8a152..6821476ee0 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
import abc
import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
from twisted.internet import defer
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
raise NotImplementedError()
- @cachedInlineCallbacks()
- def get_users_with_read_receipts_in_room(self, room_id):
- receipts = yield self.get_receipts_for_room(room_id, "m.read")
+ @cached()
+ async def get_users_with_read_receipts_in_room(self, room_id):
+ receipts = await self.get_receipts_for_room(room_id, "m.read")
return {r["user_id"] for r in receipts}
@cached(num_args=2)
@@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
- @cachedInlineCallbacks(num_args=2)
- def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self.db_pool.simple_select_list(
+ @cached(num_args=2)
+ async def get_receipts_for_user(self, user_id, receipt_type):
+ rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
- @defer.inlineCallbacks
- def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+ async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
@@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.db_pool.runInteraction(
+ rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
)
return {
@@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows
}
- @defer.inlineCallbacks
- def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def get_linearized_receipts_for_rooms(
+ self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
- room_ids (list): List of room_ids.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+            room_ids: List of room_ids.
+            to_key: Max stream id to fetch receipts up to.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- list: A list of receipts.
+ A list of receipts.
"""
room_ids = set(room_ids)
if from_key is not None:
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
- room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
- results = yield self._get_linearized_receipts_for_rooms(
+ results = await self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key
)
return [ev for res in results.values() for ev in res]
- def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ async def get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for a single room for sending to clients.
Args:
- room_ids (str): The room id.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+            room_id: The room id.
+            to_key: Max stream id to fetch receipts up to.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- Deferred[list]: A list of receipts.
+ A list of receipts.
"""
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
- defer.succeed([])
+ return []
- return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+ return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
- @cachedInlineCallbacks(num_args=3, tree=True)
- def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ @cached(num_args=3, tree=True)
+ async def _get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""See get_linearized_receipts_for_room
"""
@@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rows
- rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
+ rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -212,9 +216,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
- inlineCallbacks=True,
)
- def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}
@@ -243,7 +246,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
- txn_results = yield self.db_pool.runInteraction(
+ txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
@@ -346,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _invalidate_get_users_with_receipts_in_room(
- self, room_id, receipt_type, user_id
+ self, room_id: str, receipt_type: str, user_id: str
):
if receipt_type != "m.read":
return
@@ -472,15 +475,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
return rx_ts
- @defer.inlineCallbacks
- def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+ async def insert_receipt(
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: dict,
+ ) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
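+
+        Returns None if event_ids is empty; otherwise a pair of the receipt's
+        stream id and the current maximum persisted receipt stream id.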
"""
if not event_ids:
- return
+ return None
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
@@ -507,13 +516,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.db_pool.runInteraction(
+ linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
)
- stream_id_manager = self._receipts_id_gen.get_next()
- with stream_id_manager as stream_id:
- event_ts = yield self.db_pool.runInteraction(
+ with await self._receipts_id_gen.get_next() as stream_id:
+ event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -535,7 +543,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
now - event_ts,
)
- yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
+ await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 402ae25571..336b578e23 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,7 @@
import logging
import re
-from typing import Dict, List, Optional
-
-from twisted.internet.defer import Deferred
+from typing import Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -158,6 +156,37 @@ class RegistrationWorkerStore(SQLBaseStore):
"set_account_validity_for_user", set_account_validity_for_user_txn
)
+ async def get_expired_users(self):
+ """Get UserIDs of all expired users.
+
+ Users who are not active, or do not have profile information, are
+ excluded from the results.
+
+ Returns:
+            A list of UserIDs of expired users
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ # We need to use pattern matching as profiles.user_id is confusingly just the
+ # user's localpart, whereas account_validity.user_id is a full user ID
+ sql = """
+ SELECT av.user_id from account_validity AS av
+ LEFT JOIN profiles as p
+ ON av.user_id LIKE '%%' || p.user_id || ':%%'
+ WHERE expiration_ts_ms <= ?
+ AND p.active = 1
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+
+ return [UserID.from_string(row[0]) for row in rows]
+
+ res = await self.db_pool.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
+ )
+
+ return res
+
async def set_renewal_token_for_user(
self, user_id: str, renewal_token: str
) -> None:
@@ -264,6 +293,54 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
+ async def get_info_for_users(
+ self, user_ids: List[str],
+ ):
+ """Return the user info for a given set of users
+
+ Args:
+ user_ids: A list of users to return information about
+
+ Returns:
+            A dictionary mapping each user ID to
+ a dict with the following keys:
+ * expired - whether this is an expired user
+ * deactivated - whether this is a deactivated user
+ """
+ # Get information of all our local users
+ def _get_info_for_users_txn(txn):
+ rows = []
+
+ for user_id in user_ids:
+ sql = """
+ SELECT u.name, u.deactivated, av.expiration_ts_ms
+ FROM users as u
+ LEFT JOIN account_validity as av
+ ON av.user_id = u.name
+ WHERE u.name = ?
+ """
+
+ txn.execute(sql, (user_id,))
+ row = txn.fetchone()
+ if row:
+ rows.append(row)
+
+ return rows
+
+ info_rows = await self.db_pool.runInteraction(
+ "get_info_for_users", _get_info_for_users_txn
+ )
+
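+        # A user counts as expired once their account_validity expiration
+        # timestamp has passed; a NULL expiration (no account_validity row)
+        # means account validity doesn't apply to them.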
+ return {
+ user_id: {
+ "expired": (
+ expiration is not None and self.clock.time_msec() >= expiration
+ ),
+ "deactivated": deactivated == 1,
+ }
+ for user_id, deactivated, expiration in info_rows
+ }
+
async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
@@ -304,7 +381,7 @@ class RegistrationWorkerStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
@@ -563,7 +640,7 @@ class RegistrationWorkerStore(SQLBaseStore):
id_server (str)
Returns:
- Deferred
+ Awaitable
"""
        # We need to use an upsert, in case the user had already bound the
# threepid
@@ -891,6 +968,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
+ self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
if self._account_validity.enabled:
self._clock.call_later(
@@ -952,6 +1030,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname=None,
admin=False,
user_type=None,
+ shadow_banned=False,
):
"""Attempts to register an account.
@@ -968,6 +1047,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
+ shadow_banned (bool): Whether the user is shadow-banned,
+ i.e. they may be told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
@@ -986,6 +1067,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
)
def _register_user(
@@ -999,6 +1081,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
):
user_id_obj = UserID.from_string(user_id)
@@ -1028,6 +1111,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
else:
@@ -1042,6 +1126,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
@@ -1077,7 +1162,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
- ) -> Deferred:
+ ) -> Awaitable:
"""Record a mapping from an external user id to a mxid
Args:
@@ -1297,15 +1382,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
if not row:
- raise ThreepidValidationError(400, "Unknown session_id")
+ if self._ignore_unknown_session_error:
+ # If we need to inhibit the error caused by an incorrect session ID,
+ # use None as placeholder values for the client secret and the
+ # validation timestamp.
+ # It shouldn't be an issue because they're both only checked after
+ # the token check, which should fail. And if it doesn't for some
+ # reason, the next check is on the client secret, which is NOT NULL,
+ # so we don't have to worry about the client secret matching by
+ # accident.
+ row = {"client_secret": None, "validated_at": None}
+ else:
+ raise ThreepidValidationError(400, "Unknown session_id")
+
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
- if retrieved_client_secret != client_secret:
- raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id"
- )
-
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_token",
@@ -1321,6 +1413,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expires = row["expires"]
next_link = row["next_link"]
+ if retrieved_client_secret != client_secret:
+ raise ThreepidValidationError(
+ 400, "This client_secret does not match the provided session_id"
+ )
+
# If the session is already validated, no need to revalidate
if validated_at:
return next_link
@@ -1345,43 +1442,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"validate_threepid_session_txn", validate_threepid_session_txn
)
- def upsert_threepid_validation_session(
- self,
- medium,
- address,
- client_secret,
- send_attempt,
- session_id,
- validated_at=None,
- ):
- """Upsert a threepid validation session
- Args:
- medium (str): The medium of the 3PID
- address (str): The address of the 3PID
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- send_attempt (int): The latest send_attempt on this session
- session_id (str): The id of this validation session
- validated_at (int|None): The unix timestamp in milliseconds of
- when the session was marked as valid
- """
- insertion_values = {
- "medium": medium,
- "address": address,
- "client_secret": client_secret,
- }
-
- if validated_at:
- insertion_values["validated_at"] = validated_at
-
- return self.db_pool.simple_upsert(
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- values={"last_send_attempt": send_attempt},
- insertion_values=insertion_values,
- desc="upsert_threepid_validation_session",
- )
-
def start_or_continue_validation_session(
self,
medium,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f4008e6221..99a8a9fab0 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -21,8 +21,6 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from canonicaljson import json
-
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
@@ -30,15 +28,12 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchStore
from synapse.types import ThirdPartyInstanceID
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-OpsLevel = collections.namedtuple(
- "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)
@@ -344,6 +339,23 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ async def is_room_published(self, room_id: str) -> bool:
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id
+ Returns:
+ Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = await self.get_room(room_id)
+ if not room_info:
+ return False
+
+ # Check the is_public value
+ return room_info.get("is_public", False)
+
async def get_rooms_paginate(
self,
start: int,
@@ -552,6 +564,11 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
"""
+        # If the room retention feature is disabled, return a policy with neither
+        # a minimum nor a maximum, so that we don't incorrectly filter out events
+        # when sending them to the client.
+ if not self.config.retention_enabled:
+ return {"min_lifetime": None, "max_lifetime": None}
def get_retention_policy_for_room_txn(txn):
txn.execute(
@@ -1134,7 +1151,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1201,7 +1218,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1281,7 +1298,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
+ with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
@@ -1314,7 +1331,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"event_id": event_id,
"user_id": user_id,
"reason": reason,
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
},
desc="add_event_report",
)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index b2fcfc9bfe..161edbeccb 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -17,8 +17,6 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
- @defer.inlineCallbacks
- def _count_known_servers(self):
+ async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
- count = yield self.db_pool.runInteraction("get_known_servers", _transact)
+ count = await self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
- list_name="event_ids",
- inlineCallbacks=True,
+ cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
- def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
- Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+ dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -772,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
- def get_membership_from_event_ids(
+ async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
- return self.db_pool.simple_select_many_batch(
+ return await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
diff --git a/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
diff --git a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..96051ac179
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * really ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..7542ab8cbd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host);
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+ session_id TEXT NOT NULL,
+ ip TEXT NOT NULL,
+ user_agent TEXT NOT NULL,
+ UNIQUE (session_id, ip, user_agent),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
new file mode 100644
index 0000000000..260b009b48
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A shadow-banned user may be told that their requests succeeded when they were
+-- actually ignored.
+ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
new file mode 100644
index 0000000000..15421b99ac
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This table is no longer used.
+DROP TABLE IF EXISTS presence_allow_inbound;
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 96e0378e50..991233a9bc 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
- inlineCallbacks=True,
)
- def _get_state_group_for_events(self, event_ids):
+ async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index aaf225894e..497f607703 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,15 +39,17 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
-from typing import Optional
+from typing import Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -68,8 +70,12 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
- direction, column_names, from_token, to_token, engine
-):
+ direction: str,
+ column_names: Tuple[str, str],
+ from_token: Optional[Tuple[int, int]],
+ to_token: Optional[Tuple[int, int]],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -90,21 +96,19 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
- direction (str): Whether we're paginating backwards("b") or
- forwards ("f").
- column_names (tuple[str, str]): The column names to bound. Must *not*
- be user defined as these get inserted directly into the SQL
- statement without escapes.
- from_token (tuple[int, int]|None): The start point for the pagination.
- This is an exclusive minimum bound if direction is "f", and an
- inclusive maximum bound if direction is "b".
- to_token (tuple[int, int]|None): The endpoint point for the pagination.
- This is an inclusive maximum bound if direction is "f", and an
- exclusive minimum bound if direction is "b".
+        direction: Whether we're paginating backwards ("b") or forwards ("f").
+ column_names: The column names to bound. Must *not* be user defined as
+ these get inserted directly into the SQL statement without escapes.
+ from_token: The start point for the pagination. This is an exclusive
+ minimum bound if direction is "f", and an inclusive maximum bound if
+ direction is "b".
+        to_token: The end point for the pagination. This is an inclusive
+ maximum bound if direction is "f", and an exclusive minimum bound if
+ direction is "b".
engine: The database engine to generate the clauses for
Returns:
- str: The sql expression
+        The SQL expression
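+
+    For example, direction "f" with column_names ("topological_ordering",
+    "stream_ordering") and from_token (5, 3) yields a lower bound
+    equivalent to (5, 3) < (topological_ordering, stream_ordering).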
"""
assert direction in ("b", "f")
@@ -132,7 +136,12 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
-def _make_generic_sql_bound(bound, column_names, values, engine):
+def _make_generic_sql_bound(
+ bound: str,
+ column_names: Tuple[str, str],
+ values: Tuple[Optional[int], int],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
@@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
out manually.
Args:
- bound (str): The comparison operator to use. One of ">", "<", ">=",
+ bound: The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
- names (tuple[str, str]): The column names. Must *not* be user defined
+ column_names: The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
- values (tuple[int|None, int]): The values to bound the columns by. If
+ values: The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
- str
+ The SQL statement
"""
assert bound in (">", "<", ">=", "<=")
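As the docstring above notes, engines without native row comparisons need `(1, 2) < (col1, col2)` spelled out manually. A minimal Python illustration of that expansion (names are illustrative, not from the source):

    def row_lt(v1: int, v2: int, col1: int, col2: int) -> bool:
        # Manual expansion of the row comparison (v1, v2) < (col1, col2):
        # true if the first column is strictly smaller, or tied and the
        # second column is strictly smaller.
        return v1 < col1 or (v1 == col1 and v2 < col2)

    assert row_lt(1, 2, 1, 3)      # tie on the first column, 2 < 3
    assert not row_lt(2, 0, 1, 9)  # 2 > 1, so the second column is ignored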
@@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
)
-def filter_to_clause(event_filter):
+def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_room_min_stream_ordering(self):
raise NotImplementedError()
- @defer.inlineCallbacks
- def get_room_events_stream_for_rooms(
- self, room_ids, from_key, to_key, limit=0, order="DESC"
- ):
+ async def get_room_events_stream_for_rooms(
+ self,
+ room_ids: Iterable[str],
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Dict[str, Tuple[List[EventBase], str]]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_ids
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[dict[str,tuple[list[FrozenEvent], str]]]
- A map from room id to a tuple containing:
- - list of recent events in the room
- - stream ordering key for the start of the chunk of events returned.
+ A map from room id to a tuple containing:
+ - list of recent events in the room
+ - stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
- room_ids = yield self._events_stream_cache.get_entities_changed(
- room_ids, from_id
- )
+ room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
if not room_ids:
return {}
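A hedged sketch of the per-room result map described in the docstring above (room IDs and tokens are illustrative; rooms with no changes since `from_key` are simply absent from the result):

    async def example(store):
        results = await store.get_room_events_stream_for_rooms(
            ["!a:example.com", "!b:example.com"], from_key="s100", to_key="s200"
        )
        events, start_key = results["!a:example.com"]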
@@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
- res = yield make_deferred_yieldable(
+ res = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -361,28 +371,30 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if self._events_stream_cache.has_entity_changed(room_id, from_key)
}
- @defer.inlineCallbacks
- def get_room_events_stream_for_room(
- self, room_id, from_key, to_key, limit=0, order="DESC"
- ):
-
+ async def get_room_events_stream_for_room(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Tuple[List[EventBase], str]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_id
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
- events (in ascending order) and the token from the start of
- the chunk of events returned.
+ The list of events (in ascending order) and the token from the start
+ of the chunk of events returned.
"""
if from_key == to_key:
return [], from_key
@@ -390,9 +402,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
- has_changed = yield self._events_stream_cache.has_entity_changed(
- room_id, from_id
- )
+ has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
if not has_changed:
return [], from_key
@@ -410,9 +420,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
- rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
+ rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -430,8 +440,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
- @defer.inlineCallbacks
- def get_membership_changes_for_user(self, user_id, from_key, to_key):
+ async def get_membership_changes_for_user(self, user_id, from_key, to_key):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -460,9 +469,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
- rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
+ rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -470,27 +479,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
- @defer.inlineCallbacks
- def get_recent_events_for_room(self, room_id, limit, end_token):
+ async def get_recent_events_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[EventBase], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
- events and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of events and a token pointing to the start of the returned
+ events. The events returned are in ascending order.
"""
- rows, token = yield self.get_recent_event_ids_for_room(
+ rows, token = await self.get_recent_event_ids_for_room(
room_id, limit, end_token
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -498,20 +506,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
- @defer.inlineCallbacks
- def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+ async def get_recent_event_ids_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
- _EventDictReturn and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of _EventDictReturn and a token pointing to the start of the
+ returned events. The events returned are in ascending order.
"""
# Allow a zero limit here, and no-op.
if limit == 0:
@@ -519,7 +526,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
- rows, token = yield self.db_pool.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -532,12 +539,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token
- def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+ def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
"""Gets details of the first event in a room at or before a stream ordering
Args:
- room_id (str):
- stream_ordering (int):
+ room_id:
+ stream_ordering:
Returns:
Deferred[(int, int, str)]:
@@ -574,55 +581,67 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return "t%d-%d" % (topo, token)
- def get_stream_token_for_event(self, event_id):
- """The stream token for an event
+ async def get_stream_id_for_event(self, event_id: str) -> int:
+ """The stream ID for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream ID for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "s%d" stream token.
+ A stream ID.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
- ).addCallback(lambda row: "s%d" % (row,))
+ )
- def get_topological_token_for_event(self, event_id):
+ async def get_stream_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "t%d-%d" topological token.
+ A "s%d" stream token.
"""
- return self.db_pool.simple_select_one(
+ stream_id = await self.get_stream_id_for_event(event_id)
+ return "s%d" % (stream_id,)
+
+ async def get_topological_token_for_event(self, event_id: str) -> str:
+ """The stream token for an event
+ Args:
+ event_id: The id of the event to look up a stream token for.
+ Raises:
+ StoreError if the event wasn't in the database.
+ Returns:
+ A "t%d-%d" topological token.
+ """
+ row = await self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
- ).addCallback(
- lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
+ return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
- def get_max_topological_token(self, room_id, stream_key):
+ async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
"""Get the max topological token in a room before the given stream
ordering.
Args:
- room_id (str)
- stream_key (int)
+ room_id
+ stream_key
Returns:
- Deferred[int]
+ The maximum topological token.
"""
sql = (
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
- return self.db_pool.execute(
+ row = await self.db_pool.execute(
"get_max_topological_token", None, sql, room_id, stream_key
- ).addCallback(lambda r: r[0][0] if r else 0)
+ )
+ return row[0][0] if row else 0
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
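The helpers above pin down the two token formats: plain stream tokens ("s%d") and topological tokens ("t%d-%d"). A short illustration via RoomStreamToken (the values are illustrative):

    from synapse.types import RoomStreamToken

    assert str(RoomStreamToken(5, 1234)) == "t5-1234"  # topological token
    assert RoomStreamToken.parse("t5-1234").stream == 1234
    assert RoomStreamToken.parse_stream_token("s1234").stream == 1234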
@@ -634,16 +653,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows[0][0] if rows else 0
@staticmethod
- def _set_before_and_after(events, rows, topo_order=True):
+ def _set_before_and_after(
+ events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
+ ):
"""Inserts ordering information to events' internal metadata from
the DB rows.
Args:
- events (list[FrozenEvent])
- rows (list[_EventDictReturn])
- topo_order (bool): Whether the events were ordered topologically
- or by stream ordering. If true then all rows should have a non
- null topological_ordering.
+ events
+ rows
+ topo_order: Whether the events were ordered topologically or by stream
+ ordering. If true then all rows should have a non-null
+ topological_ordering.
"""
for event, row in zip(events, rows):
stream = row.stream_ordering
@@ -656,25 +677,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (int(topo) if topo else 0, int(stream))
- @defer.inlineCallbacks
- def get_events_around(
- self, room_id, event_id, before_limit, after_limit, event_filter=None
- ):
+ async def get_events_around(
+ self,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter] = None,
+ ) -> dict:
"""Retrieve events and pagination tokens around a given event in a
room.
-
- Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
-
- Returns:
- dict
"""
- results = yield self.db_pool.runInteraction(
+ results = await self.db_pool.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -684,11 +699,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events_before = yield self.get_events_as_list(
+ events_before = await self.get_events_as_list(
list(results["before"]["event_ids"]), get_prev_content=True
)
- events_after = yield self.get_events_as_list(
+ events_after = await self.get_events_as_list(
list(results["after"]["event_ids"]), get_prev_content=True
)
@@ -700,17 +715,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
def _get_events_around_txn(
- self, txn, room_id, event_id, before_limit, after_limit, event_filter
- ):
+ self,
+ txn,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter],
+ ) -> dict:
"""Retrieves event_ids and pagination tokens around a given event in a
room.
Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
+ room_id
+ event_id
+ before_limit
+ after_limit
+ event_filter
Returns:
dict
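The bare `dict` return type hides a fixed shape, visible in the literal at the end of this helper. A sketch of that shape (event IDs and tokens are illustrative):

    results = {
        "before": {"event_ids": ["$older:example.com"], "token": "t5-100"},
        "after": {"event_ids": ["$newer:example.com"], "token": "s102"},
    }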
@@ -758,22 +779,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
- @defer.inlineCallbacks
- def get_all_new_events_stream(self, from_id, current_id, limit):
+ async def get_all_new_events_stream(
+ self, from_id: int, current_id: int, limit: int
+ ) -> Tuple[int, List[EventBase]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
Args:
- from_id (int): the stream_ordering of the last event we processed
- current_id (int): the stream_ordering of the most recently processed event
- limit (int): the maximum number of events to return
+ from_id: the stream_ordering of the last event we processed
+ current_id: the stream_ordering of the most recently processed event
+ limit: the maximum number of events to return
Returns:
- Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
- `next_id` is the next value to pass as `from_id` (it will either be the
- stream_ordering of the last returned event, or, if fewer than `limit` events
- were found, `current_id`.
+ A tuple of (next_id, events), where `next_id` is the next value to
+ pass as `from_id` (it will either be the stream_ordering of the
+ last returned event, or, if fewer than `limit` events were found,
+ the `current_id`).
"""
def get_all_new_events_stream_txn(txn):
@@ -795,11 +817,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.db_pool.runInteraction(
+ upper_bound, event_ids = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = yield self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids)
return upper_bound, events
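Because `next_id` is designed to be fed straight back in as `from_id`, a catch-up consumer reduces to a simple loop. A hedged sketch (the checkpoint and handler are hypothetical):

    async def catch_up(store):
        from_id = load_checkpoint()  # hypothetical persisted position
        current_id = store.get_room_max_stream_ordering()
        while from_id < current_id:
            from_id, events = await store.get_all_new_events_stream(
                from_id, current_id, limit=100
            )
            for event in events:
                handle(event)  # hypothetical handler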
@@ -817,21 +839,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_federation_out_pos",
)
- async def update_federation_out_pos(self, typ, stream_id):
+ async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
if self._need_to_reset_federation_stream_positions:
await self.db_pool.runInteraction(
"_reset_federation_positions_txn", self._reset_federation_positions_txn
)
self._need_to_reset_federation_stream_positions = False
- return await self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
- def _reset_federation_positions_txn(self, txn):
+ def _reset_federation_positions_txn(self, txn) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@@ -892,39 +914,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
values={"stream_id": stream_id},
)
- def has_room_changed_since(self, room_id, stream_id):
+ def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(
self,
txn,
- room_id,
- from_token,
- to_token=None,
- direction="b",
- limit=-1,
- event_filter=None,
- ):
+ room_id: str,
+ from_token: RoomStreamToken,
+ to_token: Optional[RoomStreamToken] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Returns list of events before or after a given token.
Args:
txn
- room_id (str)
- from_token (RoomStreamToken): The token used to stream from
- to_token (RoomStreamToken|None): A token which if given limits the
- results to only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
+ room_id
+ from_token: The token used to stream from
+ to_token: A token which, if given, limits the results to only those before it
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_token`.
+ limit: The maximum number of events to return.
+ event_filter: If provided, filters the events to
those that match the filter.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
- as a list of _EventDictReturn and a token that points to the end
- of the result set. If no events are returned then the end of the
- stream has been reached (i.e. there are no events between
- `from_token` and `to_token`), or `limit` is zero.
+ A list of _EventDictReturn and a token that points to the end of the
+ result set. If no events are returned then the end of the stream has
+ been reached (i.e. there are no events between `from_token` and
+ `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -1008,35 +1028,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, str(next_token)
- @defer.inlineCallbacks
- def paginate_room_events(
- self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
- ):
+ async def paginate_room_events(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: Optional[str] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[EventBase], str]:
"""Returns list of events before or after a given token.
Args:
- room_id (str)
- from_key (str): The token used to stream from
- to_key (str|None): A token which if given limits the results to
- only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
- those that match the filter.
+ room_id
+ from_key: The token used to stream from
+ to_key: A token which, if given, limits the results to only those before it
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_key`.
+ limit: The maximum number of events to return.
+ event_filter: If provided, filters the events to those that match the filter.
Returns:
- tuple[list[FrozenEvent], str]: Returns the results as a list of
- events and a token that points to the end of the result set. If no
- events are returned then the end of the stream has been reached
- (i.e. there are no events between `from_key` and `to_key`).
+ The results as a list of events and a token that points to the end
+ of the result set. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between `from_key`
+ and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
- rows, token = yield self.db_pool.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
@@ -1047,7 +1070,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -1057,8 +1080,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
class StreamStore(StreamWorkerStore):
- def get_room_max_stream_ordering(self):
+ def get_room_max_stream_ordering(self) -> int:
return self._stream_id_gen.get_current_token()
- def get_room_min_stream_ordering(self):
+ def get_room_min_stream_ordering(self) -> int:
return self._backfill_id_gen.get_current_token()
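Putting the stream-store API together, paginating backwards from the live end of a room might look like the following hedged sketch (the room ID and consumer are illustrative):

    async def page_backwards(store):
        from_key = "s%d" % store.get_room_max_stream_ordering()
        while True:
            events, from_key = await store.paginate_room_events(
                "!room:example.com", from_key, direction="b", limit=20
            )
            if not events:
                break  # end of stream reached, per the docstring above
            render(events)  # hypothetical consumer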
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e4e0a0c433..0c34bbf21a 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,11 +17,10 @@
import logging
from typing import Dict, List, Tuple
-from canonicaljson import json
-
from synapse.storage._base import db_to_json
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
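The swap away from canonicaljson assumes a shared encoder instance in `synapse.util`. Its exact options are an assumption here, not confirmed by this diff, but something like:

    import json

    # Hypothetical shape of synapse.util.json_encoder: one reusable encoder
    # with compact separators and NaN/Infinity rejected.
    json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))

    print(json_encoder.encode({"order": 1}))  # '{"order":1}'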
@@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (user_id, room_id))
tags = []
for tag, content in txn:
- tags.append(json.dumps(tag) + ":" + content)
+ tags.append(json_encoder.encode(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, (user_id, room_id, tag_json)))
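The loop above builds a JSON object by string concatenation: each row contributes an encoded tag name plus its content string, which is already JSON. A standalone illustration (row values are illustrative):

    from synapse.util import json_encoder

    rows = [("m.favourite", '{"order":1}')]  # (tag, content) pairs
    tags = [json_encoder.encode(tag) + ":" + content for tag, content in rows]
    tag_json = "{" + ",".join(tags) + "}"
    # tag_json == '{"m.favourite":{"order":1}}'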
@@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
self.db_pool.simple_upsert_txn(
@@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..9eef8e57c5 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,15 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import attr
-from canonicaljson import json
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
-from synapse.util import stringutils as stringutils
+from synapse.util import json_encoder, stringutils
@attr.s
@@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore):
StoreError if a unique session ID cannot be generated.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore):
await self.db_pool.simple_upsert(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id, "stage_type": stage_type},
- values={"result": json.dumps(result)},
+ values={"result": json_encoder.encode(result)},
desc="mark_ui_auth_stage_complete",
)
except self.db_pool.engine.module.IntegrityError:
@@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore):
The dictionary from the client root level, not the 'auth' key.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
await self.db_pool.simple_update_one(
table="ui_auth_sessions",
@@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore):
value,
)
- def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+ def _set_ui_auth_session_data_txn(
+ self, txn: LoggingTransaction, session_id: str, key: str, value: Any
+ ):
# Get the current value.
result = self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
- )
+ ) # type: Dict[str, Any] # type: ignore
# Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"])
@@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
- updatevalues={"serverdict": json.dumps(serverdict)},
+ updatevalues={"serverdict": json_encoder.encode(serverdict)},
)
async def get_ui_auth_session_data(
@@ -258,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default)
+ async def add_user_agent_ip_to_ui_auth_session(
+ self, session_id: str, user_agent: str, ip: str,
+ ):
+ """Add the given user agent / IP to the tracking table
+ """
+ await self.db_pool.simple_upsert(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+ values={},
+ desc="add_user_agent_ip_to_ui_auth_session",
+ )
+
+ async def get_user_agents_ips_to_ui_auth_session(
+ self, session_id: str,
+ ) -> List[Tuple[str, str]]:
+ """Get the given user agents / IPs used during the ui auth process
+
+ Returns:
+ List of user_agent/ip pairs
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ )
+ return [(row["user_agent"], row["ip"]) for row in rows]
+
class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int):
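A hedged usage sketch for the two new tracking helpers (the session ID, user agent and IP are hypothetical):

    async def example(store):
        await store.add_user_agent_ip_to_ui_auth_session(
            "session_abc", user_agent="Mozilla/5.0", ip="192.168.1.10"
        )
        pairs = await store.get_user_agents_ips_to_ui_auth_session("session_abc")
        # e.g. [("Mozilla/5.0", "192.168.1.10")]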
@@ -275,12 +305,23 @@ class UIAuthStore(UIAuthWorkerStore):
expiration_time,
)
- def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+ def _delete_old_ui_auth_sessions_txn(
+ self, txn: LoggingTransaction, expiration_time: int
+ ):
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()]
+ # Delete the corresponding IP/user agents.
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_ips",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index ab6cb2c1f6..e3547e53b3 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -13,35 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import operator
-
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
class UserErasureWorkerStore(SQLBaseStore):
@cached()
- def is_user_erased(self, user_id):
+ async def is_user_erased(self, user_id: str) -> bool:
"""
Check if the given user id has requested erasure
Args:
- user_id (str): full user id to check
+ user_id: full user id to check
Returns:
- Deferred[bool]: True if the user has requested erasure
+ True if the user has requested erasure
"""
- return self.db_pool.simple_select_onecol(
+ result = await self.db_pool.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
desc="is_user_erased",
- ).addCallback(operator.truth)
+ )
+ return bool(result)
- @cachedList(
- cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
- )
- def are_users_erased(self, user_ids):
+ @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
+ async def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
@@ -49,14 +46,14 @@ class UserErasureWorkerStore(SQLBaseStore):
user_ids (iterable[str]): full user id to check
Returns:
- Deferred[dict[str, bool]]:
+ dict[str, bool]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -65,8 +62,7 @@ class UserErasureWorkerStore(SQLBaseStore):
)
erased_users = {row["user_id"] for row in rows}
- res = {u: u in erased_users for u in user_ids}
- return res
+ return {u: u in erased_users for u in user_ids}
class UserErasureStore(UserErasureWorkerStore):
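A hedged usage sketch for the batched erasure check (the user IDs are hypothetical):

    async def example(store):
        erased = await store.are_users_erased(
            ["@alice:example.com", "@bob:example.com"]
        )
        # e.g. {"@alice:example.com": True, "@bob:example.com": False}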