summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/__init__.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 3a10c265c9..80c0304b19 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
 from synapse.api.constants import Direction
 from synapse.config.homeserver import HomeServerConfig
+from synapse.storage._base import make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -170,6 +171,7 @@ class DataStore(
         order_by: str = UserSortOrder.NAME.value,
         direction: Direction = Direction.FORWARDS,
         approved: bool = True,
+        not_user_types: Optional[List[str]] = None,
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
@@ -185,6 +187,7 @@ class DataStore(
             order_by: the sort order of the returned list
             direction: sort ascending or descending
             approved: whether to include approved users
+            not_user_types: list of user types to exclude
         Returns:
             A tuple of a list of mappings from user to information and a count of total users.
         """
@@ -222,6 +225,40 @@ class DataStore(
                 # be already existing users that we consider as already approved.
                 filters.append("approved IS FALSE")
 
+            if not_user_types:
+                if len(not_user_types) == 1 and not_user_types[0] == "":
+                    # Only exclude NULL type users
+                    filters.append("user_type IS NOT NULL")
+                else:
+                    not_user_types_has_empty = False
+                    not_user_types_without_empty = []
+
+                    for not_user_type in not_user_types:
+                        if not_user_type == "":
+                            not_user_types_has_empty = True
+                        else:
+                            not_user_types_without_empty.append(not_user_type)
+
+                    not_user_type_clause, not_user_type_args = make_in_list_sql_clause(
+                        self.database_engine,
+                        "u.user_type",
+                        not_user_types_without_empty,
+                    )
+
+                    if not_user_types_has_empty:
+                        # NULL values should be excluded.
+                        # They evaluate to false > nothing to do here.
+                        filters.append("NOT %s" % (not_user_type_clause))
+                    else:
+                        # NULL values should *not* be excluded.
+                        # Add a special predicate to the query.
+                        filters.append(
+                            "(NOT %s OR %s IS NULL)"
+                            % (not_user_type_clause, "u.user_type")
+                        )
+
+                    args.extend(not_user_type_args)
+
             where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
 
             sql_base = f"""