summary refs log tree commit diff
path: root/synapse/handlers/cas_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/cas_handler.py')
-rw-r--r--synapse/handlers/cas_handler.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 64aaa1335c..a4cc4b9a5a 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -12,12 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-import xml.etree.ElementTree as ET
+import urllib
 from typing import Dict, Optional, Tuple
-
-from six.moves import urllib
+from xml.etree import ElementTree as ET
 
 from twisted.web.client import PartialDownloadError
 
@@ -37,6 +35,7 @@ class CasHandler:
     """
 
     def __init__(self, hs):
+        self.hs = hs
         self._hostname = hs.hostname
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
@@ -106,7 +105,7 @@ class CasHandler:
         return user, displayname
 
     def _parse_cas_response(
-        self, cas_response_body: str
+        self, cas_response_body: bytes
     ) -> Tuple[str, Dict[str, Optional[str]]]:
         """
         Retrieve the user and other parameters from the CAS response.
@@ -212,8 +211,16 @@ class CasHandler:
 
         else:
             if not registered_user_id:
+                # Pull out the user-agent and IP from the request.
+                user_agent = request.requestHeaders.getRawHeaders(
+                    b"User-Agent", default=[b""]
+                )[0].decode("ascii", "surrogateescape")
+                ip_address = self.hs.get_ip_from_request(request)
+
                 registered_user_id = await self._registration_handler.register_user(
-                    localpart=localpart, default_display_name=user_display_name
+                    localpart=localpart,
+                    default_display_name=user_display_name,
+                    user_agent_ips=(user_agent, ip_address),
                 )
 
             await self._auth_handler.complete_sso_login(