Source code for litestar_flags.decorators

"""Decorators for feature flag evaluation."""

from __future__ import annotations

import inspect
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, TypeVar

from litestar.exceptions import NotAuthorizedException

from litestar_flags.context import EvaluationContext
from litestar_flags.middleware import get_request_context

if TYPE_CHECKING:
    from litestar import Request

    from litestar_flags.client import FeatureFlagClient

# Public decorators exported by this module.
__all__ = [
    "feature_flag",
    "require_flag",
]

# Callable type variable: both decorators are typed as returning the same
# callable type they receive (``Callable[[F], F]``).
F = TypeVar("F", bound=Callable[..., Any])


def feature_flag(
    flag_key: str,
    *,
    default: bool = False,
    default_response: Any = None,
    context_key: str | None = None,
) -> Callable[[F], F]:
    """Conditionally execute route handlers based on a feature flag.

    When the flag is disabled, the handler returns `default_response` instead of
    executing the handler function.

    The flag client is taken from the ``feature_flags`` keyword argument, or else
    from ``request.app.state.feature_flags`` when a ``request`` keyword argument is
    present. The wrapper forwards its keyword arguments unchanged to the handler.
    So the handler must itself accept ``request`` (and/or ``feature_flags``) for a
    client to be found. When no client is available, ``default`` alone decides
    whether the handler runs.

    Both ``async def`` and plain ``def`` handlers are supported. A non-awaitable
    return value is passed through unchanged.

    Args:
        flag_key: The feature flag key to evaluate.
        default: Default value if flag is not found, or if no client is available.
        default_response: Response to return when flag is disabled.
        context_key: Optional request attribute to use as targeting key.

    Returns:
        Decorated function.

    Example:
        >>> @get("/new-feature")
        >>> @feature_flag("new_feature", default_response={"error": "Not available"})
        >>> async def new_feature_endpoint(request: Request) -> dict:
        ...     return {"message": "New feature!"}
    """

    def decorator(func: F) -> F:
        async def _invoke(*args: Any, **kwargs: Any) -> Any:
            # Await only when needed so both sync and async handlers work.
            result = func(*args, **kwargs)
            if inspect.isawaitable(result):
                result = await result
            return result

        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Find the request and client from kwargs
            request: Request[Any, Any, Any] | None = kwargs.get("request")
            client: FeatureFlagClient | None = kwargs.get("feature_flags")
            if client is None and request is not None:
                # The app state may have no client registered; attribute access on
                # a missing state key raises, so fall back to None (-> `default`).
                client = getattr(request.app.state, "feature_flags", None)

            if client is None:
                # No client available, use default
                enabled = default
            else:
                context = _build_context(request, context_key)
                enabled = await client.get_boolean_value(flag_key, default=default, context=context)

            if enabled:
                return await _invoke(*args, **kwargs)
            return default_response

        return wrapper  # type: ignore[return-value]

    return decorator
def require_flag(
    flag_key: str,
    *,
    default: bool = False,
    context_key: str | None = None,
    error_message: str | None = None,
) -> Callable[[F], F]:
    """Require a feature flag to be enabled for the decorated handler.

    When the flag is disabled, raises NotAuthorizedException. This is useful for
    protecting beta or premium features.

    The flag client is taken from the ``feature_flags`` keyword argument, or else
    from ``request.app.state.feature_flags`` when a ``request`` keyword argument is
    present. The wrapper forwards its keyword arguments unchanged to the handler.
    So the handler must itself accept ``request`` (and/or ``feature_flags``) for a
    client to be found. When no client is available, ``default`` alone decides
    whether access is granted.

    Both ``async def`` and plain ``def`` handlers are supported. A non-awaitable
    return value is passed through unchanged.

    Args:
        flag_key: The feature flag key to evaluate.
        default: Default value if flag is not found, or if no client is available.
        context_key: Optional request attribute to use as targeting key.
        error_message: Custom error message for the exception.

    Returns:
        Decorated function.

    Raises:
        NotAuthorizedException: When the flag is disabled.

    Example:
        >>> @get("/beta")
        >>> @require_flag("beta_access", error_message="Beta access required")
        >>> async def beta_endpoint(request: Request) -> dict:
        ...     return {"message": "Welcome to beta!"}
    """

    def decorator(func: F) -> F:
        async def _invoke(*args: Any, **kwargs: Any) -> Any:
            # Await only when needed so both sync and async handlers work.
            result = func(*args, **kwargs)
            if inspect.isawaitable(result):
                result = await result
            return result

        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Find the request and client from kwargs
            request: Request[Any, Any, Any] | None = kwargs.get("request")
            client: FeatureFlagClient | None = kwargs.get("feature_flags")
            if client is None and request is not None:
                # The app state may have no client registered; attribute access on
                # a missing state key raises, so fall back to None (-> `default`).
                client = getattr(request.app.state, "feature_flags", None)

            if client is None:
                # No client available, use default
                enabled = default
            else:
                context = _build_context(request, context_key)
                enabled = await client.get_boolean_value(flag_key, default=default, context=context)

            if not enabled:
                raise NotAuthorizedException(detail=error_message or f"Feature '{flag_key}' is not available")
            return await _invoke(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator
def _build_context(
    request: Request[Any, Any, Any] | None,
    context_key: str | None,
) -> EvaluationContext | None:
    """Build evaluation context from request.

    A context already extracted by the middleware (``get_request_context``) is
    preferred. In that case only its targeting key is overridden when
    ``context_key`` resolves to a value. Otherwise a basic context is built from
    ``context_key`` and the authenticated user's ``id``/``user_id``.

    Args:
        request: The current request.
        context_key: Optional attribute to use as targeting key.

    Returns:
        Evaluation context, or None when there is no request.
    """
    if request is None:
        return None

    # Try to get middleware-extracted context first
    context = get_request_context(request)
    if context is not None:
        if context_key is not None:
            # Override targeting key
            targeting_value = _get_context_value(request, context_key)
            if targeting_value is not None:
                context = context.with_targeting_key(str(targeting_value))
        return context

    # Build basic context from request
    targeting_key: str | None = None
    if context_key is not None:
        targeting_value = _get_context_value(request, context_key)
        if targeting_value is not None:
            targeting_key = str(targeting_value)

    # Take the user ID from the authenticated user. Only None/"" count as
    # missing, so that a legitimate falsy ID such as 0 is not discarded.
    user_id: str | None = None
    user = _get_request_user(request)
    if user is not None:
        for attr in ("id", "user_id"):
            raw_id = getattr(user, attr, None)
            if raw_id is not None and raw_id != "":
                user_id = str(raw_id)
                break

    if targeting_key is None:
        targeting_key = user_id

    return EvaluationContext(
        targeting_key=targeting_key,
        user_id=user_id,
    )


def _get_request_user(request: Request[Any, Any, Any]) -> Any:
    """Return the user stored in the request scope, or None if absent.

    Reads ``scope["user"]`` directly rather than ``request.user``. Litestar's
    ``user`` property raises an exception that is not an ``AttributeError`` when
    no auth middleware populated the scope, so ``hasattr`` does not guard it.

    Args:
        request: The current request.

    Returns:
        The user object, or None.
    """
    scope = getattr(request, "scope", None)
    if scope is None:
        return None
    return scope.get("user")


def _get_context_value(request: Request[Any, Any, Any], key: str) -> Any:
    """Get a value from request for use as context.

    Lookup order: path params, query params, headers (``key`` as given, then with
    underscores replaced by dashes), then an attribute of the authenticated user.

    Args:
        request: The current request.
        key: The key to look up.

    Returns:
        The first value found, or None.
    """
    # Check path params
    if key in request.path_params:
        return request.path_params[key]

    # Check query params
    if key in request.query_params:
        return request.query_params[key]

    # Check headers
    header_value = request.headers.get(key) or request.headers.get(key.replace("_", "-"))
    if header_value is not None:
        return header_value

    # Check user attributes (safe even when no auth middleware is installed)
    user = _get_request_user(request)
    if user is not None:
        return getattr(user, key, None)

    return None