diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 85f6b1e3fd..e550cbc866 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,18 +15,31 @@
# limitations under the License.
import logging
import re
+from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
-from synapse.appservice import AppServiceTransaction
+from synapse.appservice import (
+ ApplicationService,
+ ApplicationServiceState,
+ AppServiceTransaction,
+)
from synapse.config.appservice import load_appservices
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
-def _make_exclusive_regex(services_cache):
+def _make_exclusive_regex(
+ services_cache: List[ApplicationService],
+) -> Optional[Pattern]:
# We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
@@ -36,17 +49,19 @@ def _make_exclusive_regex(services_cache):
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
- exclusive_user_regex = re.compile(exclusive_user_regex)
+ exclusive_user_pattern = re.compile(
+ exclusive_user_regex
+ ) # type: Optional[Pattern]
else:
# We handle this case specially otherwise the constructed regex
# will always match
- exclusive_user_regex = None
+ exclusive_user_pattern = None
- return exclusive_user_regex
+ return exclusive_user_pattern
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
@@ -57,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
def get_app_services(self):
return self.services_cache
- def get_if_app_services_interested_in_user(self, user_id):
+ def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
"""Check if the user is one associated with an app service (exclusively)
"""
if self.exclusive_user_regex:
@@ -65,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
else:
return False
- def get_app_service_by_user_id(self, user_id):
+ def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
"""Retrieve an application service from their user ID.
All application services have associated with them a particular user ID.
@@ -74,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
a user ID to an application service.
Args:
- user_id(str): The user ID to see if it is an application service.
+ user_id: The user ID to see if it is an application service.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.sender == user_id:
return service
return None
- def get_app_service_by_token(self, token):
+ def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice token.
Args:
- token (str): The application service token.
+ token: The application service token.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.token == token:
return service
return None
- def get_app_service_by_id(self, as_id):
+ def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice ID.
Args:
- as_id (str): The application service ID.
+ as_id: The application service ID.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.id == as_id:
@@ -121,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
- async def get_appservices_by_state(self, state):
+ async def get_appservices_by_state(
+ self, state: ApplicationServiceState
+ ) -> List[ApplicationService]:
"""Get a list of application services based on their state.
Args:
- state(ApplicationServiceState): The state to filter on.
+ state: The state to filter on.
Returns:
A list of ApplicationServices, which may be empty.
"""
@@ -142,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service)
return services
- async def get_appservice_state(self, service):
+ async def get_appservice_state(
+ self, service: ApplicationService
+ ) -> Optional[ApplicationServiceState]:
"""Get the application service state.
Args:
- service(ApplicationService): The service whose state to set.
+ service: The service whose state to set.
Returns:
- An ApplicationServiceState.
+ An ApplicationServiceState or none.
"""
result = await self.db_pool.simple_select_one(
"application_services_state",
@@ -161,26 +180,36 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state")
return None
- async def set_appservice_state(self, service, state) -> None:
+ async def set_appservice_state(
+ self, service: ApplicationService, state: ApplicationServiceState
+ ) -> None:
"""Set the application service state.
Args:
- service(ApplicationService): The service whose state to set.
- state(ApplicationServiceState): The connectivity state to apply.
+ service: The service whose state to set.
+ state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
)
- async def create_appservice_txn(self, service, events):
+ async def create_appservice_txn(
+ self,
+ service: ApplicationService,
+ events: List[EventBase],
+ ephemeral: List[JsonDict],
+ ) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
- with the given list of events.
+ with the given list of events. Ephemeral events are NOT persisted to the
+ database and are not resent if a transaction is retried.
Args:
- service(ApplicationService): The service who the transaction is for.
- events(list<Event>): A list of events to put in the transaction.
+ service: The service who the transaction is for.
+ events: A list of persistent events to put in the transaction.
+ ephemeral: A list of ephemeral events to put in the transaction.
+
Returns:
- AppServiceTransaction: A new transaction.
+ A new transaction.
"""
def _create_appservice_txn(txn):
@@ -207,19 +236,22 @@ class ApplicationServiceTransactionWorkerStore(
"VALUES(?,?,?)",
(service.id, new_txn_id, event_ids),
)
- return AppServiceTransaction(service=service, id=new_txn_id, events=events)
+ return AppServiceTransaction(
+ service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+ )
return await self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn
)
- async def complete_appservice_txn(self, txn_id, service) -> None:
+ async def complete_appservice_txn(
+ self, txn_id: int, service: ApplicationService
+ ) -> None:
"""Completes an application service transaction.
Args:
- txn_id(str): The transaction ID being completed.
- service(ApplicationService): The application service which was sent
- this transaction.
+ txn_id: The transaction ID being completed.
+ service: The application service which was sent this transaction.
"""
txn_id = int(txn_id)
@@ -259,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
"complete_appservice_txn", _complete_appservice_txn
)
- async def get_oldest_unsent_txn(self, service):
- """Get the oldest transaction which has not been sent for this
- service.
+ async def get_oldest_unsent_txn(
+ self, service: ApplicationService
+ ) -> Optional[AppServiceTransaction]:
+ """Get the oldest transaction which has not been sent for this service.
Args:
- service(ApplicationService): The app service to get the oldest txn.
+ service: The app service to get the oldest txn.
Returns:
An AppServiceTransaction or None.
"""
@@ -296,9 +329,11 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
- return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
+ return AppServiceTransaction(
+ service=service, id=entry["txn_id"], events=events, ephemeral=[]
+ )
- def _get_last_txn(self, txn, service_id):
+ def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,),
@@ -309,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
- async def set_appservice_last_pos(self, pos) -> None:
+ async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@@ -319,8 +354,10 @@ class ApplicationServiceTransactionWorkerStore(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
- async def get_new_events_for_appservice(self, current_id, limit):
- """Get all new evnets"""
+ async def get_new_events_for_appservice(
+ self, current_id: int, limit: int
+ ) -> Tuple[int, List[EventBase]]:
+ """Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
sql = (
@@ -351,6 +388,54 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, events
+ async def get_type_stream_id_for_appservice(
+ self, service: ApplicationService, type: str
+ ) -> int:
+ if type not in ("read_receipt", "presence"):
+ raise ValueError(
+ "Expected type to be a valid application stream id type, got %s"
+ % (type,)
+ )
+
+ def get_type_stream_id_for_appservice_txn(txn):
+ stream_id_type = "%s_stream_id" % type
+ txn.execute(
+ # We do NOT want to escape `stream_id_type`.
+ "SELECT %s FROM application_services_state WHERE as_id=?"
+ % stream_id_type,
+ (service.id,),
+ )
+ last_stream_id = txn.fetchone()
+ if last_stream_id is None or last_stream_id[0] is None: # no row exists
+ return 0
+ else:
+ return int(last_stream_id[0])
+
+ return await self.db_pool.runInteraction(
+ "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
+ )
+
+ async def set_type_stream_id_for_appservice(
+ self, service: ApplicationService, type: str, pos: Optional[int]
+ ) -> None:
+ if type not in ("read_receipt", "presence"):
+ raise ValueError(
+ "Expected type to be a valid application stream id type, got %s"
+ % (type,)
+ )
+
+ def set_type_stream_id_for_appservice_txn(txn):
+ stream_id_type = "%s_stream_id" % type
+ txn.execute(
+ "UPDATE application_services_state SET %s = ? WHERE as_id=?"
+ % stream_id_type,
+ (pos, service.id),
+ )
+
+ await self.db_pool.runInteraction(
+ "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
+ )
+
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
# This is currently empty due to there not being any AS storage functions
|