diff --git a/mypy.ini b/mypy.ini
index 3cb6cecd7e..df6fd00d5d 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -255,3 +255,7 @@ ignore_missing_imports = True
[mypy-ijson.*]
ignore_missing_imports = True
+
+
+[mypy-psycopg2.*]
+ignore_missing_imports = True
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index f5a8f90a0f..f5fb1a33de 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -19,6 +19,7 @@ from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
Collection,
@@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+if TYPE_CHECKING:
+ from psycopg2.extensions import ConnectionInfo
+
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
@@ -395,6 +399,7 @@ class DatabasePool:
hs,
database_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
+ db_conn: LoggingDatabaseConnection,
):
self.hs = hs
self._clock = hs.get_clock()
@@ -427,6 +432,13 @@ class DatabasePool:
if isinstance(self.engine, Sqlite3Engine):
self._unsafe_to_upsert_tables.add("user_directory_search")
+ # We store the connection info for later use when using postgres
+ # (primarily to allow things like the state auto compressor to connect
+ # to the DB).
+ self.postgres_connection_info: Optional["ConnectionInfo"] = None
+ if isinstance(self.engine, PostgresEngine):
+ self.postgres_connection_info = db_conn.info
+
if self.engine.can_native_upsert:
# Check ASAP (and then later, every 1s) to see if we have finished
# background updates of tables that aren't safe to update.
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 20b755056b..6f2d9a062e 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -61,7 +61,7 @@ class Databases:
databases=database_config.databases,
)
- database = DatabasePool(hs, database_config, engine)
+ database = DatabasePool(hs, database_config, engine, db_conn)
if "main" in database_config.databases:
logger.info(
|