summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/17254.bugfix1
-rw-r--r--synapse/storage/databases/main/user_directory.py66
-rw-r--r--tests/handlers/test_user_directory.py39
-rw-r--r--tests/storage/test_user_directory.py4
4 files changed, 104 insertions, 6 deletions
diff --git a/changelog.d/17254.bugfix b/changelog.d/17254.bugfix
new file mode 100644
index 0000000000..b0d61309e2
--- /dev/null
+++ b/changelog.d/17254.bugfix
@@ -0,0 +1 @@
+Fix searching for users with their exact localpart whose ID includes a hyphen.
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 0513e7dc06..6e18f714d7 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -1281,7 +1281,7 @@ def _parse_words_with_regex(search_term: str) -> List[str]:
     Break down search term into words, when we don't have ICU available.
     See: `_parse_words`
     """
-    return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+    return re.findall(r"([\w-]+)", search_term, re.UNICODE)
 
 
 def _parse_words_with_icu(search_term: str) -> List[str]:
@@ -1303,15 +1303,69 @@ def _parse_words_with_icu(search_term: str) -> List[str]:
         if j < 0:
             break
 
-        result = search_term[i:j]
+        # We want to make sure that we split on `@` and `:` specifically, as
+        # they occur in user IDs.
+        for result in re.split(r"[@:]+", search_term[i:j]):
+            results.append(result.strip())
+
+        i = j
+
+    # libicu will break up words that have punctuation in them, but to handle
+    # cases where user IDs have '-', '.' and '_' in them we want to *not* break
+    # those into words and instead allow the DB to tokenise them how it wants.
+    #
+    # In particular, user-71 in postgres gets tokenised to "user, -71", and this
+    # will not match a query for "user, 71".
+    new_results: List[str] = []
+    i = 0
+    while i < len(results):
+        curr = results[i]
+
+        prev = None
+        next = None
+        if i > 0:
+            prev = results[i - 1]
+        if i + 1 < len(results):
+            next = results[i + 1]
+
+        i += 1
 
         # libicu considers spaces and punctuation between words as words, but we don't
         # want to include those in results as they would result in syntax errors in SQL
         # queries (e.g. "foo bar" would result in the search query including "foo &  &
         # bar").
-        if len(re.findall(r"([\w\-]+)", result, re.UNICODE)):
-            results.append(result)
+        if not curr:
+            continue
+
+        if curr in ["-", ".", "_"]:
+            prefix = ""
+            suffix = ""
+
+            # Check if the next item is a word, and if so use it as the suffix.
+            # We check for if its a word as we don't want to concatenate
+            # multiple punctuation marks.
+            if next is not None and re.match(r"\w", next):
+                suffix = next
+                i += 1  # We're using next, so we skip it in the outer loop.
+            else:
+                # We want to avoid creating terms like "user-", as we should
+                # strip trailing punctuation.
+                continue
 
-        i = j
+            if prev and re.match(r"\w", prev) and new_results:
+                prefix = new_results[-1]
+                new_results.pop()
+
+            # We might not have a prefix here, but that's fine as we want to
+            # ensure that we don't strip preceding punctuation e.g. '-71'
+            # shouldn't be converted to '71'.
+
+            new_results.append(f"{prefix}{curr}{suffix}")
+            continue
+        elif not re.match(r"\w", curr):
+            # Ignore other punctuation
+            continue
+
+        new_results.append(curr)
 
-    return results
+    return new_results
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 77c6cac449..878d9683b6 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -1061,6 +1061,45 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             {alice: ProfileInfo(display_name=None, avatar_url=MXC_DUMMY)},
         )
 
+    def test_search_punctuation(self) -> None:
+        """Test that you can search for a user that includes punctuation"""
+
+        searching_user = self.register_user("searcher", "password")
+        searching_user_tok = self.login("searcher", "password")
+
+        room_id = self.helper.create_room_as(
+            searching_user,
+            room_version=RoomVersions.V1.identifier,
+            tok=searching_user_tok,
+        )
+
+        # We want to test searching for users of the form e.g. "user-1", with
+        # various punctuation. We also test both where the prefix is numeric and
+        # alphanumeric, as e.g. postgres tokenises "user-1" as "user" and "-1".
+        i = 1
+        for char in ["-", ".", "_"]:
+            for use_numeric in [False, True]:
+                if use_numeric:
+                    prefix1 = f"{i}"
+                    prefix2 = f"{i+1}"
+                else:
+                    prefix1 = f"a{i}"
+                    prefix2 = f"a{i+1}"
+
+                local_user_1 = self.register_user(f"user{char}{prefix1}", "password")
+                local_user_2 = self.register_user(f"user{char}{prefix2}", "password")
+
+                self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
+                self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
+
+                results = self.get_success(
+                    self.handler.search_users(searching_user, local_user_1, 20)
+                )["results"]
+                received_user_id_ordering = [result["user_id"] for result in results]
+                self.assertSequenceEqual(received_user_id_ordering[:1], [local_user_1])
+
+                i += 2
+
 
 class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 156a610faa..c26932069f 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -711,6 +711,10 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
             ),
         )
 
+        self.assertEqual(_parse_words_with_icu("user-1"), ["user-1"])
+        self.assertEqual(_parse_words_with_icu("user-ab"), ["user-ab"])
+        self.assertEqual(_parse_words_with_icu("user.--1"), ["user", "-1"])
+
     def test_regex_word_boundary_punctuation(self) -> None:
         """
         Tests the behaviour of punctuation with the non-ICU tokeniser