Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/client_ips.py | 140
-rw-r--r-- | synapse/storage/databases/main/event_federation.py | 2
-rw-r--r-- | synapse/storage/databases/main/events.py | 10
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 142
-rw-r--r-- | synapse/storage/databases/main/registration.py | 8
-rw-r--r-- | synapse/storage/databases/main/room.py | 8
-rw-r--r-- | synapse/storage/databases/main/room_batch.py | 13
-rw-r--r-- | synapse/storage/prepare_database.py | 2
-rw-r--r-- | synapse/storage/schema/__init__.py | 6
-rw-r--r-- | synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql | 19
10 files changed, 235 insertions, 115 deletions
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 6c1ef09049..b81d9218ce 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -13,14 +13,26 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+
+from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.types import UserID
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
 
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
+class DeviceLastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for a user and device combination"""
+
+    # These types must match the columns in the `devices` table
+    user_id: str
+    device_id: str
+
+    ip: Optional[str]
+    user_agent: Optional[str]
+    last_seen: Optional[int]
+
+
+class LastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for an access token and IP combination"""
+
+    # These types must match the columns in the `user_ips` table
+    access_token: str
+    ip: str
+
+    user_agent: str
+    last_seen: int
+
+
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             "devices_last_seen", self._devices_last_seen_update
         )
 
-    async def _remove_user_ip_nonunique(self, progress, batch_size):
-        def f(conn):
+    async def _remove_user_ip_nonunique(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def f(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
@@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         )
         return 1
 
-    async def _analyze_user_ip(self, progress, batch_size):
+    async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
         # Background update to analyze user_ips table before we run the
         # deduplication background update. The table may not have been analyzed
         # for ages due to the table locks.
         #
         # This will lock out the naive upserts to user_ips while it happens, but
         # the analyze should be quick (28GB table takes ~10s)
-        def user_ips_analyze(txn):
+        def user_ips_analyze(txn: LoggingTransaction) -> None:
             txn.execute("ANALYZE user_ips")
 
         await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
 
@@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
         return 1
 
-    async def _remove_user_ip_dupes(self, progress, batch_size):
+    async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
         # This works function works by scanning the user_ips table in batches
         # based on `last_seen`. For each row in a batch it searches the rest of
         # the table to see if there are any duplicates, if there are then they
         # are removed and replaced with a suitable row.
 
         # Fetch the start of the batch
-        begin_last_seen = progress.get("last_seen", 0)
+        begin_last_seen: int = progress.get("last_seen", 0)
 
-        def get_last_seen(txn):
+        def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
             txn.execute(
                 """
                 SELECT last_seen FROM user_ips
@@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 """,
                 (begin_last_seen, batch_size),
             )
-            row = txn.fetchone()
+            row = cast(Optional[Tuple[int]], txn.fetchone())
             if row:
                 return row[0]
             else:
@@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             end_last_seen,
         )
 
-        def remove(txn):
+        def remove(txn: LoggingTransaction) -> None:
             # This works by looking at all entries in the given time span, and
             # then for each (user_id, access_token, ip) tuple in that range
             # checking for any duplicates in the rest of the table (via a join).
@@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
             # Define the search space, which requires handling the last batch in
             # a different way
+            args: Tuple[int, ...]
             if last:
                 clause = "? <= last_seen"
                 args = (begin_last_seen,)
             else:
+                assert end_last_seen is not None
                 clause = "? <= last_seen AND last_seen < ?"
                 args = (begin_last_seen, end_last_seen)
 
@@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 ),
                 args,
             )
-            res = txn.fetchall()
+            res = cast(
+                List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
+            )
 
             # We've got some duplicates
             for i in res:
@@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
         return batch_size
 
-    async def _devices_last_seen_update(self, progress, batch_size):
+    async def _devices_last_seen_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to insert last seen info into devices table"""
 
-        last_user_id = progress.get("last_user_id", "")
-        last_device_id = progress.get("last_device_id", "")
+        last_user_id: str = progress.get("last_user_id", "")
+        last_device_id: str = progress.get("last_device_id", "")
 
-        def _devices_last_seen_update_txn(txn):
+        def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
             # This consists of two queries:
             #
             # 1. The sub-query searches for the next N devices and joins
@@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             #    we'll just end up updating the same device row multiple
             #    times, which is fine.
+            where_args: List[Union[str, int]]
             where_clause, where_args = make_tuple_comparison_clause(
                 [("user_id", last_user_id), ("device_id", last_device_id)],
             )
@@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             }
 
             txn.execute(sql, where_args + [batch_size])
-            rows = txn.fetchall()
+            rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
             if not rows:
                 return 0
 
@@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
 
     @wrap_as_background_process("prune_old_user_ips")
-    async def _prune_old_user_ips(self):
+    async def _prune_old_user_ips(self) -> None:
         """Removes entries in user IPs older than the configured period."""
 
         if self.user_ips_max_age is None:
@@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             )
         """
 
-        timestamp = self.clock.time_msec() - self.user_ips_max_age
+        timestamp = self._clock.time_msec() - self.user_ips_max_age
 
-        def _prune_old_user_ips_txn(txn):
+        def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
             txn.execute(sql, (timestamp,))
 
         await self.db_pool.runInteraction(
@@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on.
 
         The result might be slightly out of date as client IPs are inserted in batches.
@@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
         if device_id is not None:
             keyvalues["device_id"] = device_id
 
-        res = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues=keyvalues,
-            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        res = cast(
+            List[DeviceLastConnectionInfo],
+            await self.db_pool.simple_select_list(
+                table="devices",
+                keyvalues=keyvalues,
+                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+            ),
         )
 
         return {(d["user_id"], d["device_id"]): d for d in res}
 
 
-class ClientIpStore(ClientIpWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
 
-        self.client_ip_last_seen = LruCache(
+        # (user_id, access_token, ip,) -> last_seen
+        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
             cache_name="client_ip_last_seen", max_size=50000
         )
 
         super().__init__(database, db_conn, hs)
 
         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update = {}
+        self._batch_row_update: Dict[
+            Tuple[str, str, str], Tuple[str, Optional[str], int]
+        ] = {}
 
         self._client_ip_looper = self._clock.looping_call(
             self._update_client_ips_batch, 5 * 1000
@@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore):
         )
 
     async def insert_client_ip(
-        self, user_id, access_token, ip, user_agent, device_id, now=None
-    ):
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: Optional[str],
+        now: Optional[int] = None,
+    ) -> None:
         if not now:
             now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)
@@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore):
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
-    def _update_client_ips_batch_txn(self, txn, to_update):
+    def _update_client_ips_batch_txn(
+        self,
+        txn: LoggingTransaction,
+        to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
+    ) -> None:
         if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
         ):
@@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore):
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on
 
         Args:
@@ -561,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore):
 
     async def get_user_ip_and_agents(
         self, user: UserID, since_ts: int = 0
-    ) -> List[Dict[str, Union[str, int]]]:
+    ) -> List[LastConnectionInfo]:
         """
         Fetch IP/User Agent connection since a given timestamp.
         """
         user_id = user.to_string()
 
-        results = {}
+        results: Dict[Tuple[str, str], Tuple[str, int]] = {}
 
         for key in self._batch_row_update:
             (
@@ -579,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore):
             if last_seen >= since_ts:
                 results[(access_token, ip)] = (user_agent, last_seen)
 
-        def get_recent(txn):
+        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
             txn.execute(
                 """
                 SELECT access_token, ip, user_agent, last_seen FROM user_ips
@@ -589,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 """,
                 (since_ts, user_id),
             )
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, str, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
             desc="get_user_ip_and_agents", func=get_recent
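The client_ips.py changes are mostly typing: TypedDicts describe the row shapes handed back to callers, and typing.cast narrows the untyped result of fetchall(). Below is a self-contained sketch of those two idioms against the standard-library sqlite3 module; the table and column names are invented for illustration and are not the Synapse schema (the diff also imports TypedDict from typing_extensions to support older Pythons, whereas the sketch assumes Python 3.8+).

import sqlite3
from typing import List, Optional, Tuple, TypedDict, cast

class DeviceRow(TypedDict):
    # Illustrative row shape; not the real Synapse `devices` table.
    user_id: str
    device_id: str
    last_seen: Optional[int]

def fetch_device_rows(conn: sqlite3.Connection) -> List[DeviceRow]:
    cur = conn.execute("SELECT user_id, device_id, last_seen FROM example_devices")
    # fetchall() is effectively untyped; cast() records the expected row shape
    # for the type checker without changing runtime behaviour.
    rows = cast(List[Tuple[str, str, Optional[int]]], cur.fetchall())
    return [DeviceRow(user_id=u, device_id=d, last_seen=ls) for u, d, ls in rows]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE example_devices (user_id TEXT, device_id TEXT, last_seen INTEGER)")
conn.execute("INSERT INTO example_devices VALUES ('@alice:example.org', 'DEV1', 1635000000000)")
print(fetch_device_rows(conn))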
""" user_id = user.to_string() - results = {} + results: Dict[Tuple[str, str], Tuple[str, int]] = {} for key in self._batch_row_update: ( @@ -579,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore): if last_seen >= since_ts: results[(access_token, ip)] = (user_agent, last_seen) - def get_recent(txn): + def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: txn.execute( """ SELECT access_token, ip, user_agent, last_seen FROM user_ips @@ -589,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore): """, (since_ts, user_id), ) - return txn.fetchall() + return cast(List[Tuple[str, str, str, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( desc="get_user_ip_and_agents", func=get_recent diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 10184d6ae7..ba9f71a230 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -906,7 +906,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas desc="get_latest_event_ids_in_room", ) - async def get_min_depth(self, room_id: str) -> int: + async def get_min_depth(self, room_id: str) -> Optional[int]: """For the given room, get the minimum depth we have seen for it.""" return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 19f55c19c5..37439f8562 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2069,12 +2069,14 @@ class PersistEventsStore: state_groups[event.event_id] = context.state_group - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_upsert_many_txn( txn, table="event_to_state_groups", - values=[ - {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in state_groups.items() + key_names=["event_id"], + key_values=[[event_id] for event_id, _ in state_groups.items()], + value_names=["state_group"], + value_values=[ + [state_group_id] for _, state_group_id in state_groups.items() ], ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4a1a2f4a6a..ae37901be9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine +from synapse.storage.types import Connection from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id @@ -86,6 +87,47 @@ class _EventCacheEntry: redacted_event: Optional[EventBase] +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventRow: + """ + An event, as pulled from the database. + + Properties: + event_id: The event ID of the event. 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4a1a2f4a6a..ae37901be9 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
@@ -86,6 +87,47 @@ class _EventCacheEntry:
     redacted_event: Optional[EventBase]
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRow:
+    """
+    An event, as pulled from the database.
+
+    Properties:
+        event_id: The event ID of the event.
+
+        stream_ordering: stream ordering for this event
+
+        json: json-encoded event structure
+
+        internal_metadata: json-encoded internal metadata dict
+
+        format_version: The format of the event. Hopefully one of EventFormatVersions.
+            'None' means the event predates EventFormatVersions (so the event is format V1).
+
+        room_version_id: The version of the room which contains the event. Hopefully
+            one of RoomVersions.
+
+            Due to historical reasons, there may be a few events in the database which
+            do not have an associated room; in this case None will be returned here.
+
+        rejected_reason: if the event was rejected, the reason why.
+
+        redactions: a list of event-ids which (claim to) redact this event.
+
+        outlier: True if this event is an outlier.
+    """
+
+    event_id: str
+    stream_ordering: int
+    json: str
+    internal_metadata: str
+    format_version: Optional[int]
+    room_version_id: Optional[int]
+    rejected_reason: Optional[str]
+    redactions: List[str]
+    outlier: bool
+
+
 class EventRedactBehaviour(Names):
     """
     What to do when retrieving a redacted event from the database.
@@ -686,7 +728,7 @@ class EventsWorkerStore(SQLBaseStore):
             for e in state_to_include.values()
         ]
 
-    def _do_fetch(self, conn):
+    def _do_fetch(self, conn: Connection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
@@ -713,13 +755,15 @@ class EventsWorkerStore(SQLBaseStore):
 
             self._fetch_event_list(conn, event_list)
 
-    def _fetch_event_list(self, conn, event_list):
+    def _fetch_event_list(
+        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+    ) -> None:
         """Handle a load of requests from the _event_fetch_list queue
 
         Args:
-            conn (twisted.enterprise.adbapi.Connection): database connection
+            conn: database connection
 
-            event_list (list[Tuple[list[str], Deferred]]):
+            event_list:
                 The fetch requests. Each entry consists of a list of event
                 ids to be fetched, and a deferred to be completed once the
                 events have been fetched.
@@ -788,7 +832,7 @@ class EventsWorkerStore(SQLBaseStore):
                 row = row_map.get(event_id)
                 fetched_events[event_id] = row
                 if row:
-                    redaction_ids.update(row["redactions"])
+                    redaction_ids.update(row.redactions)
 
         events_to_fetch = redaction_ids.difference(fetched_events.keys())
         if events_to_fetch:
@@ -799,32 +843,32 @@ class EventsWorkerStore(SQLBaseStore):
         for event_id, row in fetched_events.items():
             if not row:
                 continue
-            assert row["event_id"] == event_id
+            assert row.event_id == event_id
 
-            rejected_reason = row["rejected_reason"]
+            rejected_reason = row.rejected_reason
 
             # If the event or metadata cannot be parsed, log the error and act
             # as if the event is unknown.
             try:
-                d = db_to_json(row["json"])
+                d = db_to_json(row.json)
             except ValueError:
                 logger.error("Unable to parse json from event: %s", event_id)
                 continue
             try:
-                internal_metadata = db_to_json(row["internal_metadata"])
+                internal_metadata = db_to_json(row.internal_metadata)
             except ValueError:
                 logger.error(
                     "Unable to parse internal_metadata from event: %s", event_id
                 )
                 continue
 
-            format_version = row["format_version"]
+            format_version = row.format_version
             if format_version is None:
                 # This means that we stored the event before we had the concept
                 # of a event format version, so it must be a V1 event.
                 format_version = EventFormatVersions.V1
 
-            room_version_id = row["room_version_id"]
+            room_version_id = row.room_version_id
 
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
@@ -889,8 +933,8 @@ class EventsWorkerStore(SQLBaseStore):
                 internal_metadata_dict=internal_metadata,
                 rejected_reason=rejected_reason,
             )
-            original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
-            original_ev.internal_metadata.outlier = row["outlier"]
+            original_ev.internal_metadata.stream_ordering = row.stream_ordering
+            original_ev.internal_metadata.outlier = row.outlier
 
             event_map[event_id] = original_ev
 
@@ -898,7 +942,7 @@ class EventsWorkerStore(SQLBaseStore):
         # the cache entries.
         result_map = {}
         for event_id, original_ev in event_map.items():
-            redactions = fetched_events[event_id]["redactions"]
+            redactions = fetched_events[event_id].redactions
             redacted_event = self._maybe_redact_event_row(
                 original_ev, redactions, event_map
             )
@@ -912,17 +956,17 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    async def _enqueue_events(self, events):
+    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
 
         Args:
-            events (Iterable[str]): events to be fetched.
+            events: events to be fetched.
 
         Returns:
-            Dict[str, Dict]: map from event id to row data from the database.
-            May contain events that weren't requested.
+            A map from event id to row data from the database. May contain events
+            that weren't requested.
         """
 
         events_d = defer.Deferred()
@@ -949,43 +993,19 @@ class EventsWorkerStore(SQLBaseStore):
 
         return row_map
 
-    def _fetch_event_rows(self, txn, event_ids):
+    def _fetch_event_rows(
+        self, txn: LoggingTransaction, event_ids: Iterable[str]
+    ) -> Dict[str, _EventRow]:
         """Fetch event rows from the database
 
         Events which are not found are omitted from the result.
 
-        The returned per-event dicts contain the following keys:
-
-          * event_id (str)
-
-          * stream_ordering (int): stream ordering for this event
-
-          * json (str): json-encoded event structure
-
-          * internal_metadata (str): json-encoded internal metadata dict
-
-          * format_version (int|None): The format of the event. Hopefully one
-            of EventFormatVersions. 'None' means the event predates
-            EventFormatVersions (so the event is format V1).
-
-          * room_version_id (str|None): The version of the room which contains the event.
-            Hopefully one of RoomVersions.
-
-            Due to historical reasons, there may be a few events in the database which
-            do not have an associated room; in this case None will be returned here.
-
-          * rejected_reason (str|None): if the event was rejected, the reason
-            why.
-
-          * redactions (List[str]): a list of event-ids which (claim to) redact
-            this event.
-
         Args:
-            txn (twisted.enterprise.adbapi.Connection):
-            event_ids (Iterable[str]): event IDs to fetch
+            txn: The database transaction.
+            event_ids: event IDs to fetch
 
         Returns:
-            Dict[str, Dict]: a map from event id to event info.
+            A map from event id to event info.
         """
         event_dict = {}
         for evs in batch_iter(event_ids, 200):
@@ -1013,17 +1033,17 @@ class EventsWorkerStore(SQLBaseStore):
 
             for row in txn:
                 event_id = row[0]
-                event_dict[event_id] = {
-                    "event_id": event_id,
-                    "stream_ordering": row[1],
-                    "internal_metadata": row[2],
-                    "json": row[3],
-                    "format_version": row[4],
-                    "room_version_id": row[5],
-                    "rejected_reason": row[6],
-                    "redactions": [],
-                    "outlier": row[7],
-                }
+                event_dict[event_id] = _EventRow(
+                    event_id=event_id,
+                    stream_ordering=row[1],
+                    internal_metadata=row[2],
+                    json=row[3],
+                    format_version=row[4],
+                    room_version_id=row[5],
+                    rejected_reason=row[6],
+                    redactions=[],
+                    outlier=row[7],
+                )
 
             # check for redactions
             redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
@@ -1035,7 +1055,7 @@ class EventsWorkerStore(SQLBaseStore):
             for (redacter, redacted) in txn:
                 d = event_dict.get(redacted)
                 if d:
-                    d["redactions"].append(redacter)
+                    d.redactions.append(redacter)
 
         return event_dict
 
if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.server.retention_default_min_lifetime + row["min_lifetime"] = self.config.retention.retention_default_min_lifetime if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.server.retention_default_max_lifetime + row["max_lifetime"] = self.config.retention.retention_default_max_lifetime return row diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py index 300a563c9e..dcbce8fdcf 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py @@ -36,3 +36,16 @@ class RoomBatchStore(SQLBaseStore): retcol="event_id", allow_none=True, ) + + async def store_state_group_id_for_event_id( + self, event_id: str, state_group_id: int + ) -> Optional[str]: + { + await self.db_pool.simple_upsert( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + values={"state_group": state_group_id, "event_id": event_id}, + # Unique constraint on event_id so we don't have to lock + lock=False, + ) + } diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 11ca47ea28..1629d2a53c 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -549,6 +549,8 @@ def _apply_module_schemas( database_engine: config: application config """ + # This is the old way for password_auth_provider modules to make changes + # to the database. This should instead be done using the module API for (mod, _config) in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 1aee741a8b..a1d2332326 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 64 # remember to update the list below when updating +SCHEMA_VERSION = 65 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -41,6 +41,10 @@ Changes in SCHEMA_VERSION = 63: Changes in SCHEMA_VERSION = 64: - MSC2716: Rename related tables and columns from "chunks" to "batches". + +Changes in SCHEMA_VERSION = 65: + - MSC2716: Remove unique event_id constraint from insertion_event_edges + because an insertion event can have multiple edges. """ diff --git a/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql b/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql new file mode 100644 index 0000000000..98b25daf45 --- /dev/null +++ b/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql @@ -0,0 +1,19 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 11ca47ea28..1629d2a53c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -549,6 +549,8 @@ def _apply_module_schemas(
         database_engine:
         config: application config
     """
+    # This is the old way for password_auth_provider modules to make changes
+    # to the database. This should instead be done using the module API
     for (mod, _config) in config.authproviders.password_providers:
         if not hasattr(mod, "get_db_schema_files"):
             continue
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 1aee741a8b..a1d2332326 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 64  # remember to update the list below when updating
+SCHEMA_VERSION = 65  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -41,6 +41,10 @@ Changes in SCHEMA_VERSION = 63:
 
 Changes in SCHEMA_VERSION = 64:
     - MSC2716: Rename related tables and columns from "chunks" to "batches".
+
+Changes in SCHEMA_VERSION = 65:
+    - MSC2716: Remove unique event_id constraint from insertion_event_edges
+      because an insertion event can have multiple edges.
 """
diff --git a/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql b/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql
new file mode 100644
index 0000000000..98b25daf45
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql
@@ -0,0 +1,19 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Recreate the insertion_event_edges event_id index without the unique constraint
+-- because an insertion event can have multiple edges.
+DROP INDEX insertion_event_edges_event_id;
+CREATE INDEX IF NOT EXISTS insertion_event_edges_event_id ON insertion_event_edges(event_id);
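Behaviourally, the delta above swaps a unique index for a non-unique one so a single insertion event can own several edge rows. Below is a minimal sqlite3 sketch that runs the same two statements against a cut-down stand-in table (the second column name is illustrative, not the full Synapse schema) and then stores two edges for one insertion event.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE insertion_event_edges (event_id TEXT, prev_event_id TEXT);
    CREATE UNIQUE INDEX insertion_event_edges_event_id ON insertion_event_edges(event_id);

    -- The same two statements as the delta: drop the unique index and
    -- recreate it without the uniqueness constraint.
    DROP INDEX insertion_event_edges_event_id;
    CREATE INDEX IF NOT EXISTS insertion_event_edges_event_id ON insertion_event_edges(event_id);
    """
)

# With the non-unique index, one insertion event may now appear in several edge rows.
conn.execute("INSERT INTO insertion_event_edges VALUES ('$insertion', '$prev_a')")
conn.execute("INSERT INTO insertion_event_edges VALUES ('$insertion', '$prev_b')")
print(
    conn.execute(
        "SELECT COUNT(*) FROM insertion_event_edges WHERE event_id = '$insertion'"
    ).fetchone()
)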