summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11340.bugfix1
-rw-r--r--synapse/handlers/auth.py26
2 files changed, 14 insertions, 13 deletions
diff --git a/changelog.d/11340.bugfix b/changelog.d/11340.bugfix
new file mode 100644
index 0000000000..551817f42d
--- /dev/null
+++ b/changelog.d/11340.bugfix
@@ -0,0 +1 @@
+Fix a bug, introduced in Synapse 1.46.0, which caused the `check_3pid_auth` and `on_logged_out` callbacks in legacy password authentication provider modules to not be registered. Modules using the generic module API were not affected.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 60e59d11a0..b62e13b725 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1828,13 +1828,6 @@ def load_single_legacy_password_auth_provider(
         logger.error("Error while initializing %r: %s", module, e)
         raise
 
-    # The known hooks. If a module implements a method who's name appears in this set
-    # we'll want to register it
-    password_auth_provider_methods = {
-        "check_3pid_auth",
-        "on_logged_out",
-    }
-
     # All methods that the module provides should be async, but this wasn't enforced
     # in the old module system, so we wrap them if needed
     def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
@@ -1919,11 +1912,14 @@ def load_single_legacy_password_auth_provider(
 
         return run
 
-    # populate hooks with the implemented methods, wrapped with async_wrapper
-    hooks = {
-        hook: async_wrapper(getattr(provider, hook, None))
-        for hook in password_auth_provider_methods
-    }
+    # If the module has these methods implemented, then we pull them out
+    # and register them as hooks.
+    check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper(
+        getattr(provider, "check_3pid_auth", None)
+    )
+    on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper(
+        getattr(provider, "on_logged_out", None)
+    )
 
     supported_login_types = {}
     # call get_supported_login_types and add that to the dict
@@ -1950,7 +1946,11 @@ def load_single_legacy_password_auth_provider(
         # need to use a tuple here for ("password",) not a list since lists aren't hashable
         auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
 
-    api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
+    api.register_password_auth_provider_callbacks(
+        check_3pid_auth=check_3pid_auth_hook,
+        on_logged_out=on_logged_out_hook,
+        auth_checkers=auth_checkers,
+    )
 
 
 CHECK_3PID_AUTH_CALLBACK = Callable[