diff options
Diffstat (limited to 'synapse/storage/databases/main/appservice.py')
-rw-r--r-- | synapse/storage/databases/main/appservice.py | 171 |
1 files changed, 128 insertions, 43 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 85f6b1e3fd..e550cbc866 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -15,18 +15,31 @@ # limitations under the License. import logging import re +from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple -from synapse.appservice import AppServiceTransaction +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + AppServiceTransaction, +) from synapse.config.appservice import load_appservices +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.types import Connection +from synapse.types import JsonDict from synapse.util import json_encoder +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) -def _make_exclusive_regex(services_cache): +def _make_exclusive_regex( + services_cache: List[ApplicationService], +) -> Optional[Pattern]: # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ @@ -36,17 +49,19 @@ def _make_exclusive_regex(services_cache): ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_regex = re.compile(exclusive_user_regex) + exclusive_user_pattern = re.compile( + exclusive_user_regex + ) # type: Optional[Pattern] else: # We handle this case specially otherwise the constructed regex # will always match - exclusive_user_regex = None + exclusive_user_pattern = None - return exclusive_user_regex + return exclusive_user_pattern class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) @@ -57,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): def get_app_services(self): return self.services_cache - def get_if_app_services_interested_in_user(self, user_id): + def get_if_app_services_interested_in_user(self, user_id: str) -> bool: """Check if the user is one associated with an app service (exclusively) """ if self.exclusive_user_regex: @@ -65,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): else: return False - def get_app_service_by_user_id(self, user_id): + def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]: """Retrieve an application service from their user ID. All application services have associated with them a particular user ID. @@ -74,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore): a user ID to an application service. Args: - user_id(str): The user ID to see if it is an application service. + user_id: The user ID to see if it is an application service. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.sender == user_id: return service return None - def get_app_service_by_token(self, token): + def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]: """Get the application service with the given appservice token. Args: - token (str): The application service token. + token: The application service token. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.token == token: return service return None - def get_app_service_by_id(self, as_id): + def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]: """Get the application service with the given appservice ID. Args: - as_id (str): The application service ID. + as_id: The application service ID. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.id == as_id: @@ -121,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - async def get_appservices_by_state(self, state): + async def get_appservices_by_state( + self, state: ApplicationServiceState + ) -> List[ApplicationService]: """Get a list of application services based on their state. Args: - state(ApplicationServiceState): The state to filter on. + state: The state to filter on. Returns: A list of ApplicationServices, which may be empty. """ @@ -142,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - async def get_appservice_state(self, service): + async def get_appservice_state( + self, service: ApplicationService + ) -> Optional[ApplicationServiceState]: """Get the application service state. Args: - service(ApplicationService): The service whose state to set. + service: The service whose state to set. Returns: - An ApplicationServiceState. + An ApplicationServiceState or none. """ result = await self.db_pool.simple_select_one( "application_services_state", @@ -161,26 +180,36 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - async def set_appservice_state(self, service, state) -> None: + async def set_appservice_state( + self, service: ApplicationService, state: ApplicationServiceState + ) -> None: """Set the application service state. Args: - service(ApplicationService): The service whose state to set. - state(ApplicationServiceState): The connectivity state to apply. + service: The service whose state to set. + state: The connectivity state to apply. """ await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) - async def create_appservice_txn(self, service, events): + async def create_appservice_txn( + self, + service: ApplicationService, + events: List[EventBase], + ephemeral: List[JsonDict], + ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service - with the given list of events. + with the given list of events. Ephemeral events are NOT persisted to the + database and are not resent if a transaction is retried. Args: - service(ApplicationService): The service who the transaction is for. - events(list<Event>): A list of events to put in the transaction. + service: The service who the transaction is for. + events: A list of persistent events to put in the transaction. + ephemeral: A list of ephemeral events to put in the transaction. + Returns: - AppServiceTransaction: A new transaction. + A new transaction. """ def _create_appservice_txn(txn): @@ -207,19 +236,22 @@ class ApplicationServiceTransactionWorkerStore( "VALUES(?,?,?)", (service.id, new_txn_id, event_ids), ) - return AppServiceTransaction(service=service, id=new_txn_id, events=events) + return AppServiceTransaction( + service=service, id=new_txn_id, events=events, ephemeral=ephemeral + ) return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn ) - async def complete_appservice_txn(self, txn_id, service) -> None: + async def complete_appservice_txn( + self, txn_id: int, service: ApplicationService + ) -> None: """Completes an application service transaction. Args: - txn_id(str): The transaction ID being completed. - service(ApplicationService): The application service which was sent - this transaction. + txn_id: The transaction ID being completed. + service: The application service which was sent this transaction. """ txn_id = int(txn_id) @@ -259,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - async def get_oldest_unsent_txn(self, service): - """Get the oldest transaction which has not been sent for this - service. + async def get_oldest_unsent_txn( + self, service: ApplicationService + ) -> Optional[AppServiceTransaction]: + """Get the oldest transaction which has not been sent for this service. Args: - service(ApplicationService): The app service to get the oldest txn. + service: The app service to get the oldest txn. Returns: An AppServiceTransaction or None. """ @@ -296,9 +329,11 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) + return AppServiceTransaction( + service=service, id=entry["txn_id"], events=events, ephemeral=[] + ) - def _get_last_txn(self, txn, service_id): + def _get_last_txn(self, txn, service_id: Optional[str]) -> int: txn.execute( "SELECT last_txn FROM application_services_state WHERE as_id=?", (service_id,), @@ -309,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - async def set_appservice_last_pos(self, pos) -> None: + async def set_appservice_last_pos(self, pos: int) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) @@ -319,8 +354,10 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - async def get_new_events_for_appservice(self, current_id, limit): - """Get all new evnets""" + async def get_new_events_for_appservice( + self, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: + """Get all new events for an appservice""" def get_new_events_for_appservice_txn(txn): sql = ( @@ -351,6 +388,54 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, events + async def get_type_stream_id_for_appservice( + self, service: ApplicationService, type: str + ) -> int: + if type not in ("read_receipt", "presence"): + raise ValueError( + "Expected type to be a valid application stream id type, got %s" + % (type,) + ) + + def get_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + # We do NOT want to escape `stream_id_type`. + "SELECT %s FROM application_services_state WHERE as_id=?" + % stream_id_type, + (service.id,), + ) + last_stream_id = txn.fetchone() + if last_stream_id is None or last_stream_id[0] is None: # no row exists + return 0 + else: + return int(last_stream_id[0]) + + return await self.db_pool.runInteraction( + "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn + ) + + async def set_type_stream_id_for_appservice( + self, service: ApplicationService, type: str, pos: Optional[int] + ) -> None: + if type not in ("read_receipt", "presence"): + raise ValueError( + "Expected type to be a valid application stream id type, got %s" + % (type,) + ) + + def set_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + "UPDATE application_services_state SET %s = ? WHERE as_id=?" + % stream_id_type, + (pos, service.id), + ) + + await self.db_pool.runInteraction( + "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn + ) + class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): # This is currently empty due to there not being any AS storage functions |