diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 1b13e84425..13ec1f71a6 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -14,24 +14,25 @@
# limitations under the License.
import logging
import re
-
-from six import string_types
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING
from synapse.api.constants import EventTypes
+from synapse.appservice.api import ApplicationServiceApi
from synapse.types import GroupID, get_domain_from_id
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
+
+if TYPE_CHECKING:
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
-class ApplicationServiceState(object):
+class ApplicationServiceState:
DOWN = "down"
UP = "up"
-class AppServiceTransaction(object):
+class AppServiceTransaction:
"""Represents an application service transaction."""
def __init__(self, service, id, events):
@@ -39,19 +40,19 @@ class AppServiceTransaction(object):
self.id = id
self.events = events
- def send(self, as_api):
+ async def send(self, as_api: ApplicationServiceApi) -> bool:
"""Sends this transaction using the provided AS API interface.
Args:
- as_api(ApplicationServiceApi): The API to use to send.
+ as_api: The API to use to send.
Returns:
- A Deferred which resolves to True if the transaction was sent.
+ True if the transaction was sent.
"""
- return as_api.push_bulk(
+ return await as_api.push_bulk(
service=self.service, events=self.events, txn_id=self.id
)
- def complete(self, store):
+ async def complete(self, store: "DataStore") -> None:
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
@@ -59,13 +60,11 @@ class AppServiceTransaction(object):
Args:
store: The database store to operate on.
- Returns:
- A Deferred which resolves to True if the transaction was completed.
"""
- return store.complete_appservice_txn(service=self.service, txn_id=self.id)
+ await store.complete_appservice_txn(service=self.service, txn_id=self.id)
-class ApplicationService(object):
+class ApplicationService:
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
@@ -156,7 +155,7 @@ class ApplicationService(object):
)
regex = regex_obj.get("regex")
- if isinstance(regex, string_types):
+ if isinstance(regex, str):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
@@ -174,8 +173,7 @@ class ApplicationService(object):
return regex_obj["exclusive"]
return False
- @defer.inlineCallbacks
- def _matches_user(self, event, store):
+ async def _matches_user(self, event, store):
if not event:
return False
@@ -190,12 +188,12 @@ class ApplicationService(object):
if not store:
return False
- does_match = yield self._matches_user_in_member_list(event.room_id, store)
+ does_match = await self._matches_user_in_member_list(event.room_id, store)
return does_match
- @cachedInlineCallbacks(num_args=1, cache_context=True)
- def _matches_user_in_member_list(self, room_id, store, cache_context):
- member_list = yield store.get_users_in_room(
+ @cached(num_args=1, cache_context=True)
+ async def _matches_user_in_member_list(self, room_id, store, cache_context):
+ member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
@@ -210,35 +208,33 @@ class ApplicationService(object):
return self.is_interested_in_room(event.room_id)
return False
- @defer.inlineCallbacks
- def _matches_aliases(self, event, store):
+ async def _matches_aliases(self, event, store):
if not store or not event:
return False
- alias_list = yield store.get_aliases_for_room(event.room_id)
+ alias_list = await store.get_aliases_for_room(event.room_id)
for alias in alias_list:
if self.is_interested_in_alias(alias):
return True
return False
- @defer.inlineCallbacks
- def is_interested(self, event, store=None):
+ async def is_interested(self, event, store=None) -> bool:
"""Check if this service is interested in this event.
Args:
event(Event): The event to check.
store(DataStore)
Returns:
- bool: True if this service would like to know about this event.
+ True if this service would like to know about this event.
"""
# Do cheap checks first
if self._matches_room_id(event):
return True
- if (yield self._matches_aliases(event, store)):
+ if await self._matches_aliases(event, store):
return True
- if (yield self._matches_user(event, store)):
+ if await self._matches_user(event, store):
return True
return False
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 57174da021..bb6fa8299a 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -13,20 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from six.moves import urllib
+import urllib
+from typing import TYPE_CHECKING, Optional
from prometheus_client import Counter
-from twisted.internet import defer
-
-from synapse.api.constants import ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
+if TYPE_CHECKING:
+ from synapse.appservice import ApplicationService
+
logger = logging.getLogger(__name__)
sent_transactions_counter = Counter(
@@ -94,14 +95,12 @@ class ApplicationServiceApi(SimpleHttpClient):
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
)
- @defer.inlineCallbacks
- def query_user(self, service, user_id):
+ async def query_user(self, service, user_id):
if service.url is None:
return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
- response = None
try:
- response = yield self.get_json(uri, {"access_token": service.hs_token})
+ response = await self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
return True
except CodeMessageException as e:
@@ -112,14 +111,12 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_user to %s threw exception %s", uri, ex)
return False
- @defer.inlineCallbacks
- def query_alias(self, service, alias):
+ async def query_alias(self, service, alias):
if service.url is None:
return False
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
- response = None
try:
- response = yield self.get_json(uri, {"access_token": service.hs_token})
+ response = await self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
return True
except CodeMessageException as e:
@@ -130,8 +127,7 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex)
return False
- @defer.inlineCallbacks
- def query_3pe(self, service, kind, protocol, fields):
+ async def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
@@ -148,7 +144,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol),
)
try:
- response = yield self.get_json(uri, fields)
+ response = await self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r", uri, response
@@ -169,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pe to %s threw exception %s", uri, ex)
return []
- def get_3pe_protocol(self, service, protocol):
+ async def get_3pe_protocol(
+ self, service: "ApplicationService", protocol: str
+ ) -> Optional[JsonDict]:
if service.url is None:
return {}
- @defer.inlineCallbacks
- def _get():
+ async def _get() -> Optional[JsonDict]:
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.parse.quote(protocol),
)
try:
- info = yield self.get_json(uri, {})
+ info = await self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
logger.warning(
@@ -202,14 +199,13 @@ class ApplicationServiceApi(SimpleHttpClient):
return None
key = (service.id, protocol)
- return self.protocol_meta_cache.wrap(key, _get)
+ return await self.protocol_meta_cache.wrap(key, _get)
- @defer.inlineCallbacks
- def push_bulk(self, service, events, txn_id=None):
+ async def push_bulk(self, service, events, txn_id=None):
if service.url is None:
return True
- events = self._serialize(events)
+ events = self._serialize(service, events)
if txn_id is None:
logger.warning(
@@ -220,7 +216,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try:
- yield self.put_json(
+ await self.put_json(
uri=uri,
json_body={"events": events},
args={"access_token": service.hs_token},
@@ -235,6 +231,18 @@ class ApplicationServiceApi(SimpleHttpClient):
failed_transactions_counter.labels(service.id).inc()
return False
- def _serialize(self, events):
+ def _serialize(self, service, events):
time_now = self.clock.time_msec()
- return [serialize_event(e, time_now, as_client_event=True) for e in events]
+ return [
+ serialize_event(
+ e,
+ time_now,
+ as_client_event=True,
+ is_invite=(
+ e.type == EventTypes.Member
+ and e.membership == "invite"
+ and service.is_interested_in_user(e.state_key)
+ ),
+ )
+ for e in events
+ ]
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 9998f822f1..8eb8c6f51c 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -50,8 +50,6 @@ components.
"""
import logging
-from twisted.internet import defer
-
from synapse.appservice import ApplicationServiceState
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -59,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
logger = logging.getLogger(__name__)
-class ApplicationServiceScheduler(object):
+class ApplicationServiceScheduler:
""" Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
case is a simple array.
@@ -73,12 +71,11 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
- @defer.inlineCallbacks
- def start(self):
+ async def start(self):
logger.info("Starting appservice scheduler")
# check for any DOWN ASes and start recoverers for them.
- services = yield self.store.get_appservices_by_state(
+ services = await self.store.get_appservices_by_state(
ApplicationServiceState.DOWN
)
@@ -89,7 +86,7 @@ class ApplicationServiceScheduler(object):
self.queuer.enqueue(service, event)
-class _ServiceQueuer(object):
+class _ServiceQueuer:
"""Queue of events waiting to be sent to appservices.
Groups events into transactions per-appservice, and sends them on to the
@@ -117,8 +114,7 @@ class _ServiceQueuer(object):
"as-sender-%s" % (service.id,), self._send_request, service
)
- @defer.inlineCallbacks
- def _send_request(self, service):
+ async def _send_request(self, service):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert service.id not in self.requests_in_flight
@@ -130,14 +126,14 @@ class _ServiceQueuer(object):
if not events:
return
try:
- yield self.txn_ctrl.send(service, events)
+ await self.txn_ctrl.send(service, events)
except Exception:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
-class _TransactionController(object):
+class _TransactionController:
"""Transaction manager.
Builds AppServiceTransactions and runs their lifecycle. Also starts a Recoverer
@@ -162,36 +158,33 @@ class _TransactionController(object):
# for UTs
self.RECOVERER_CLASS = _Recoverer
- @defer.inlineCallbacks
- def send(self, service, events):
+ async def send(self, service, events):
try:
- txn = yield self.store.create_appservice_txn(service=service, events=events)
- service_is_up = yield self._is_service_up(service)
+ txn = await self.store.create_appservice_txn(service=service, events=events)
+ service_is_up = await self._is_service_up(service)
if service_is_up:
- sent = yield txn.send(self.as_api)
+ sent = await txn.send(self.as_api)
if sent:
- yield txn.complete(self.store)
+ await txn.complete(self.store)
else:
run_in_background(self._on_txn_fail, service)
except Exception:
logger.exception("Error creating appservice transaction")
run_in_background(self._on_txn_fail, service)
- @defer.inlineCallbacks
- def on_recovered(self, recoverer):
+ async def on_recovered(self, recoverer):
logger.info(
"Successfully recovered application service AS ID %s", recoverer.service.id
)
self.recoverers.pop(recoverer.service.id)
logger.info("Remaining active recoverers: %s", len(self.recoverers))
- yield self.store.set_appservice_state(
+ await self.store.set_appservice_state(
recoverer.service, ApplicationServiceState.UP
)
- @defer.inlineCallbacks
- def _on_txn_fail(self, service):
+ async def _on_txn_fail(self, service):
try:
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
self.start_recoverer(service)
except Exception:
logger.exception("Error starting AS recoverer")
@@ -211,13 +204,12 @@ class _TransactionController(object):
recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers))
- @defer.inlineCallbacks
- def _is_service_up(self, service):
- state = yield self.store.get_appservice_state(service)
+ async def _is_service_up(self, service):
+ state = await self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None
-class _Recoverer(object):
+class _Recoverer:
"""Manages retries and backoff for a DOWN appservice.
We have one of these for each appservice which is currently considered DOWN.
@@ -254,25 +246,24 @@ class _Recoverer(object):
self.backoff_counter += 1
self.recover()
- @defer.inlineCallbacks
- def retry(self):
+ async def retry(self):
logger.info("Starting retries on %s", self.service.id)
try:
while True:
- txn = yield self.store.get_oldest_unsent_txn(self.service)
+ txn = await self.store.get_oldest_unsent_txn(self.service)
if not txn:
# nothing left: we're done!
- self.callback(self)
+ await self.callback(self)
return
logger.info(
"Retrying transaction %s for AS ID %s", txn.id, txn.service.id
)
- sent = yield txn.send(self.as_api)
+ sent = await txn.send(self.as_api)
if not sent:
break
- yield txn.complete(self.store)
+ await txn.complete(self.store)
# reset the backoff counter and then process the next transaction
self.backoff_counter = 1
|