summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/user_directory.py66
1 files changed, 60 insertions, 6 deletions
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 0513e7dc06..6e18f714d7 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -1281,7 +1281,7 @@ def _parse_words_with_regex(search_term: str) -> List[str]:
     Break down search term into words, when we don't have ICU available.
     See: `_parse_words`
     """
-    return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+    return re.findall(r"([\w-]+)", search_term, re.UNICODE)
 
 
 def _parse_words_with_icu(search_term: str) -> List[str]:
@@ -1303,15 +1303,69 @@ def _parse_words_with_icu(search_term: str) -> List[str]:
         if j < 0:
             break
 
-        result = search_term[i:j]
+        # We want to make sure that we split on `@` and `:` specifically, as
+        # they occur in user IDs.
+        for result in re.split(r"[@:]+", search_term[i:j]):
+            results.append(result.strip())
+
+        i = j
+
+    # libicu will break up words that have punctuation in them, but to handle
+    # cases where user IDs have '-', '.' and '_' in them we want to *not* break
+    # those into words and instead allow the DB to tokenise them how it wants.
+    #
+    # In particular, user-71 in postgres gets tokenised to "user, -71", and this
+    # will not match a query for "user, 71".
+    new_results: List[str] = []
+    i = 0
+    while i < len(results):
+        curr = results[i]
+
+        prev = None
+        next = None
+        if i > 0:
+            prev = results[i - 1]
+        if i + 1 < len(results):
+            next = results[i + 1]
+
+        i += 1
 
         # libicu considers spaces and punctuation between words as words, but we don't
         # want to include those in results as they would result in syntax errors in SQL
         # queries (e.g. "foo bar" would result in the search query including "foo &  &
         # bar").
-        if len(re.findall(r"([\w\-]+)", result, re.UNICODE)):
-            results.append(result)
+        if not curr:
+            continue
+
+        if curr in ["-", ".", "_"]:
+            prefix = ""
+            suffix = ""
+
+            # Check if the next item is a word, and if so use it as the suffix.
+            # We check for if its a word as we don't want to concatenate
+            # multiple punctuation marks.
+            if next is not None and re.match(r"\w", next):
+                suffix = next
+                i += 1  # We're using next, so we skip it in the outer loop.
+            else:
+                # We want to avoid creating terms like "user-", as we should
+                # strip trailing punctuation.
+                continue
 
-        i = j
+            if prev and re.match(r"\w", prev) and new_results:
+                prefix = new_results[-1]
+                new_results.pop()
+
+            # We might not have a prefix here, but that's fine as we want to
+            # ensure that we don't strip preceding punctuation e.g. '-71'
+            # shouldn't be converted to '71'.
+
+            new_results.append(f"{prefix}{curr}{suffix}")
+            continue
+        elif not re.match(r"\w", curr):
+            # Ignore other punctuation
+            continue
+
+        new_results.append(curr)
 
-    return results
+    return new_results