# Source code for litestar_flags.client

"""Feature flag client for evaluation."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, TypeVar

from litestar_flags.context import EvaluationContext
from litestar_flags.engine import EvaluationEngine
from litestar_flags.results import EvaluationDetails
from litestar_flags.security import sanitize_error_message
from litestar_flags.types import ErrorCode, EvaluationReason, FlagType

if TYPE_CHECKING:
    from litestar_flags.bootstrap import BootstrapConfig
    from litestar_flags.cache import CacheProtocol, CacheStats
    from litestar_flags.models.flag import FeatureFlag
    from litestar_flags.protocols import StorageBackend
    from litestar_flags.rate_limit import RateLimiter

# Public API of this module: only the client class is exported.
__all__ = ["FeatureFlagClient"]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Generic value type: ties the type of the caller's default to the type of
# the value returned in EvaluationDetails by FeatureFlagClient._evaluate.
T = TypeVar("T")

# Cache key prefix for flags: external-cache keys are built as
# f"{_CACHE_KEY_PREFIX}{flag_key}", e.g. "flag:my-feature".
_CACHE_KEY_PREFIX = "flag:"


class FeatureFlagClient:
    """Main client for feature flag evaluation.

    Provides type-safe methods for all flag types with automatic caching and
    graceful degradation: the evaluation methods (``get_*_value``,
    ``get_*_details``, ``is_enabled``, ``get_all_flags`` and ``get_flags``)
    catch errors and fall back to the caller's default (or skip the failing
    flag in bulk mode) instead of raising.

    Example:
        >>> client = FeatureFlagClient(storage=MemoryStorageBackend())
        >>> enabled = await client.get_boolean_value("my-feature", default=False)
        >>> variant = await client.get_string_value("ab-test", default="control")
    """

    def __init__(
        self,
        storage: StorageBackend,
        default_context: EvaluationContext | None = None,
        rate_limiter: RateLimiter | None = None,
        cache: CacheProtocol | None = None,
    ) -> None:
        """Initialize the feature flag client.

        Args:
            storage: The storage backend for flag data.
            default_context: Default evaluation context to use when none is
                provided. A fresh, empty ``EvaluationContext`` is used if omitted.
            rate_limiter: Optional rate limiter to control evaluation throughput.
            cache: Optional cache for flag data. When provided, flag lookups
                will check the cache before hitting storage, and cache entries
                will be populated after storage reads.
        """
        self._storage = storage
        self._default_context = EvaluationContext() if default_context is None else default_context
        self._engine = EvaluationEngine()
        self._rate_limiter = rate_limiter
        self._cache = cache
        # In-memory snapshot filled by preload_flags(); consulted before the
        # external cache and the storage backend.
        self._preloaded_flags: dict[str, FeatureFlag] = {}
        self._closed = False

    @property
    def storage(self) -> StorageBackend:
        """Get the storage backend."""
        return self._storage

    @property
    def rate_limiter(self) -> RateLimiter | None:
        """Get the rate limiter."""
        return self._rate_limiter

    @property
    def cache(self) -> CacheProtocol | None:
        """Get the cache instance."""
        return self._cache

    def cache_stats(self) -> CacheStats | None:
        """Get cache statistics.

        Returns:
            CacheStats if a cache is configured, None otherwise.
        """
        if self._cache is None:
            return None
        return self._cache.stats()

    # Bootstrap and preload methods

    @classmethod
    async def bootstrap(
        cls,
        config: BootstrapConfig,
        storage: StorageBackend,
        default_context: EvaluationContext | None = None,
        rate_limiter: RateLimiter | None = None,
        cache: CacheProtocol | None = None,
    ) -> FeatureFlagClient:
        """Create a client with flags bootstrapped from a static source.

        Loads flags from the bootstrap configuration and stores them in the
        provided storage backend, then returns a configured client.

        Args:
            config: Bootstrap configuration specifying flag source.
            storage: Storage backend to populate with bootstrap flags.
            default_context: Default evaluation context.
            rate_limiter: Optional rate limiter.
            cache: Optional cache for flag data.

        Returns:
            Configured FeatureFlagClient with bootstrapped flags.

        Example:
            >>> config = BootstrapConfig(source=Path("flags.json"))
            >>> client = await FeatureFlagClient.bootstrap(
            ...     config=config,
            ...     storage=MemoryStorageBackend(),
            ... )
        """
        # Imported lazily so the bootstrap module is only loaded when used.
        from litestar_flags.bootstrap import BootstrapLoader

        loader = BootstrapLoader()
        flags = await loader.load(config)

        # Store bootstrap flags in the storage backend.
        for flag in flags:
            try:
                await storage.create_flag(flag)
            except ValueError:
                # Treated as "flag already exists", so update it instead.
                # NOTE(review): relies on backends raising ValueError for
                # duplicates — confirm for every StorageBackend.
                await storage.update_flag(flag)

        return cls(
            storage=storage,
            default_context=default_context,
            rate_limiter=rate_limiter,
            cache=cache,
        )

    async def preload_flags(
        self,
        flag_keys: list[str] | None = None,
    ) -> dict[str, FeatureFlag]:
        """Preload flags into the client's cache for faster evaluation.

        This method fetches flags from storage and caches them locally.
        Useful for warming up the client at startup to avoid cold-start
        latency on first evaluations.

        Args:
            flag_keys: Optional list of specific flag keys to preload.
                If None, preloads all active flags and replaces any previous
                snapshot; otherwise the fetched flags are merged into it.

        Returns:
            A copy of the whole preloaded-flag mapping keyed by flag key, or
            an empty dict if fetching failed (the error is logged).

        Example:
            >>> await client.preload_flags()  # Preload all flags
            >>> await client.preload_flags(["feature-a", "feature-b"])
        """
        try:
            if flag_keys is None:
                flags = await self._storage.get_all_active_flags()
                self._preloaded_flags = {flag.key: flag for flag in flags}
            else:
                flags_dict = await self._storage.get_flags(flag_keys)
                self._preloaded_flags.update(flags_dict)

            logger.info("Preloaded %d flags", len(self._preloaded_flags))
            return self._preloaded_flags.copy()
        except Exception as e:
            logger.error("Error preloading flags: %s", e)
            return {}

    def clear_preloaded_flags(self) -> None:
        """Clear the preloaded flags cache.

        Call this method when you want to force fresh flag fetches
        from the storage backend.
        """
        self._preloaded_flags.clear()
        logger.debug("Cleared preloaded flags cache")

    async def clear_cache(self) -> None:
        """Clear the external cache.

        This method clears all entries in the external cache if one is
        configured. Use this when you need to invalidate all cached flag data.

        Note:
            This does not clear the preloaded flags. Use clear_preloaded_flags()
            for that, or clear_all_caches() to clear both.
        """
        if self._cache is not None:
            await self._cache.clear()
            logger.debug("Cleared external cache")

    async def clear_all_caches(self) -> None:
        """Clear both preloaded flags and external cache.

        Convenience method to ensure all cached data is cleared.
        """
        self.clear_preloaded_flags()
        await self.clear_cache()

    async def invalidate_flag(self, flag_key: str) -> None:
        """Invalidate a specific flag from all caches.

        Args:
            flag_key: The flag key to invalidate.

        Errors raised by the external cache's ``delete`` are not caught.
        """
        # Remove from preloaded flags.
        self._preloaded_flags.pop(flag_key, None)

        # Remove from external cache.
        if self._cache is not None:
            await self._cache.delete(self._cache_key(flag_key))

        logger.debug("Invalidated flag '%s' from all caches", flag_key)

    @staticmethod
    def _cache_key(flag_key: str) -> str:
        """Return the external-cache key under which ``flag_key`` is stored."""
        return f"{_CACHE_KEY_PREFIX}{flag_key}"

    async def _get_flag_with_cache(self, flag_key: str) -> FeatureFlag | None:
        """Get a flag, checking preloaded cache and external cache first.

        The lookup order is:
        1. Preloaded flags (in-memory, set via preload_flags())
        2. External cache (if configured)
        3. Storage backend

        External-cache read/write errors (including failures to deserialize a
        cached entry) are logged and do not stop the lookup; storage errors
        propagate to the caller.

        Args:
            flag_key: The flag key to retrieve.

        Returns:
            The flag if found, None otherwise.
        """
        # Check preloaded flags first (fastest).
        if flag_key in self._preloaded_flags:
            return self._preloaded_flags[flag_key]

        cache_key = self._cache_key(flag_key)

        # Check external cache if configured.
        if self._cache is not None:
            try:
                cached = await self._cache.get(cache_key)
                if cached is not None:
                    # Reconstruct flag from cached data.
                    return self._deserialize_cached_flag(cached)
            except Exception as e:
                logger.warning("Cache get error for '%s': %s", flag_key, e)

        # Fall back to storage.
        flag = await self._storage.get_flag(flag_key)

        # Populate cache on successful storage read (misses are not cached).
        if flag is not None and self._cache is not None:
            try:
                serialized = self._serialize_flag_for_cache(flag)
                await self._cache.set(cache_key, serialized)
            except Exception as e:
                logger.warning("Cache set error for '%s': %s", flag_key, e)

        return flag

    def _serialize_flag_for_cache(self, flag: FeatureFlag) -> dict[str, Any]:
        """Serialize a flag for cache storage.

        UUIDs are stringified, enums are stored by ``.value`` and datetimes as
        ISO-8601 strings; ``None`` rules/variants become empty lists.

        Args:
            flag: The flag to serialize.

        Returns:
            Dictionary representation of the flag.
        """
        return {
            "id": str(flag.id),
            "key": flag.key,
            "name": flag.name,
            "description": flag.description,
            "flag_type": flag.flag_type.value,
            "status": flag.status.value,
            "default_enabled": flag.default_enabled,
            "default_value": flag.default_value,
            "tags": flag.tags,
            "metadata": flag.metadata_,
            "rules": [
                {
                    "id": str(r.id),
                    "name": r.name,
                    "description": r.description,
                    "priority": r.priority,
                    "enabled": r.enabled,
                    "conditions": r.conditions,
                    "serve_enabled": r.serve_enabled,
                    "serve_value": r.serve_value,
                    "rollout_percentage": r.rollout_percentage,
                }
                for r in (flag.rules or [])
            ],
            "variants": [
                {
                    "id": str(v.id),
                    "key": v.key,
                    "name": v.name,
                    "description": v.description,
                    "value": v.value,
                    "weight": v.weight,
                }
                for v in (flag.variants or [])
            ],
            "created_at": flag.created_at.isoformat() if flag.created_at else None,
            "updated_at": flag.updated_at.isoformat() if flag.updated_at else None,
        }

    def _deserialize_cached_flag(self, data: dict[str, Any]) -> FeatureFlag:
        """Deserialize a flag from cache storage.

        Inverse of ``_serialize_flag_for_cache``. Missing or malformed keys
        raise (KeyError/ValueError); callers are expected to handle that.

        Args:
            data: The cached dictionary representation.

        Returns:
            Reconstructed FeatureFlag object.
        """
        from datetime import datetime
        from uuid import UUID

        from litestar_flags.models.flag import FeatureFlag
        from litestar_flags.models.rule import FlagRule
        from litestar_flags.models.variant import FlagVariant
        from litestar_flags.types import FlagStatus

        # Create rule objects.
        rules = [
            FlagRule(
                id=UUID(r["id"]),
                name=r["name"],
                description=r.get("description"),
                priority=r["priority"],
                enabled=r["enabled"],
                conditions=r["conditions"],
                serve_enabled=r["serve_enabled"],
                serve_value=r.get("serve_value"),
                rollout_percentage=r.get("rollout_percentage"),
            )
            for r in data.get("rules", [])
        ]

        # Create variant objects.
        variants = [
            FlagVariant(
                id=UUID(v["id"]),
                key=v["key"],
                name=v["name"],
                description=v.get("description"),
                value=v["value"],
                weight=v["weight"],
            )
            for v in data.get("variants", [])
        ]

        created_at = data.get("created_at")
        updated_at = data.get("updated_at")

        return FeatureFlag(
            id=UUID(data["id"]),
            key=data["key"],
            name=data["name"],
            description=data.get("description"),
            flag_type=FlagType(data["flag_type"]),
            status=FlagStatus(data["status"]),
            default_enabled=data["default_enabled"],
            default_value=data.get("default_value"),
            tags=data.get("tags", []),
            metadata_=data.get("metadata", {}),
            rules=rules,
            variants=variants,
            created_at=datetime.fromisoformat(created_at) if created_at else None,
            updated_at=datetime.fromisoformat(updated_at) if updated_at else None,
        )

    # Boolean evaluation

    async def get_boolean_value(
        self,
        flag_key: str,
        default: bool = False,
        context: EvaluationContext | None = None,
    ) -> bool:
        """Evaluate a boolean flag.

        Args:
            flag_key: The unique flag key.
            default: Default value if flag is not found or evaluation fails.
            context: Optional evaluation context.

        Returns:
            The evaluated boolean value.
        """
        details = await self.get_boolean_details(flag_key, default, context)
        return details.value

    async def get_boolean_details(
        self,
        flag_key: str,
        default: bool = False,
        context: EvaluationContext | None = None,
    ) -> EvaluationDetails[bool]:
        """Evaluate a boolean flag with details.

        Args:
            flag_key: The unique flag key.
            default: Default value if flag is not found or evaluation fails.
            context: Optional evaluation context.

        Returns:
            EvaluationDetails containing the value and metadata.
        """
        return await self._evaluate(flag_key, default, FlagType.BOOLEAN, context)

    # String evaluation

    async def get_string_value(
        self,
        flag_key: str,
        default: str = "",
        context: EvaluationContext | None = None,
    ) -> str:
        """Evaluate a string flag.

        Args:
            flag_key: The unique flag key.
            default: Default value if flag is not found or evaluation fails.
            context: Optional evaluation context.

        Returns:
            The evaluated string value.
        """
        details = await self.get_string_details(flag_key, default, context)
        return details.value

    async def get_string_details(
        self,
        flag_key: str,
        default: str = "",
        context: EvaluationContext | None = None,
    ) -> EvaluationDetails[str]:
        """Evaluate a string flag with details.

        Args:
            flag_key: The unique flag key.
            default: Default value if flag is not found or evaluation fails.
            context: Optional evaluation context.

        Returns:
            EvaluationDetails containing the value and metadata.
        """
        return await self._evaluate(flag_key, default, FlagType.STRING, context)

    # Number evaluation

    async def get_number_value(
        self,
        flag_key: str,
        default: float = 0.0,
        context: EvaluationContext | None = None,
    ) -> float:
        """Evaluate a number flag.

        Args:
            flag_key: The unique flag key.
            default: Default value if flag is not found or evaluation fails.
            context: Optional evaluation context.

        Returns:
            The evaluated number value.
        """
        details = await self.get_number_details(flag_key, default, context)
        return details.value

    async def get_number_details(
        self,
        flag_key: str,
        default: float = 0.0,
        context: EvaluationContext | None = None,
    ) -> EvaluationDetails[float]:
        """Evaluate a number flag with details.

        Args:
            flag_key: The unique flag key.
            default: Default value if flag is not found or evaluation fails.
            context: Optional evaluation context.

        Returns:
            EvaluationDetails containing the value and metadata.
        """
        return await self._evaluate(flag_key, default, FlagType.NUMBER, context)

    # Object/JSON evaluation

    async def get_object_value(
        self,
        flag_key: str,
        default: dict[str, Any] | None = None,
        context: EvaluationContext | None = None,
    ) -> dict[str, Any]:
        """Evaluate an object/JSON flag.

        Args:
            flag_key: The unique flag key.
            default: Default value if flag is not found or evaluation fails.
                ``None`` means a fresh empty dict.
            context: Optional evaluation context.

        Returns:
            The evaluated object value.
        """
        # Fresh dict per call so callers never share a mutable default.
        fallback = {} if default is None else default
        details = await self.get_object_details(flag_key, fallback, context)
        return details.value

    async def get_object_details(
        self,
        flag_key: str,
        default: dict[str, Any],
        context: EvaluationContext | None = None,
    ) -> EvaluationDetails[dict[str, Any]]:
        """Evaluate an object/JSON flag with details.

        Args:
            flag_key: The unique flag key.
            default: Default value if flag is not found or evaluation fails.
            context: Optional evaluation context.

        Returns:
            EvaluationDetails containing the value and metadata.
        """
        return await self._evaluate(flag_key, default, FlagType.JSON, context)

    # Convenience methods

    async def is_enabled(
        self,
        flag_key: str,
        context: EvaluationContext | None = None,
    ) -> bool:
        """Check if a boolean flag is enabled.

        Shorthand for `get_boolean_value(flag_key, default=False, context)`.

        Args:
            flag_key: The unique flag key.
            context: Optional evaluation context.

        Returns:
            True if the flag is enabled, False otherwise.
        """
        return await self.get_boolean_value(flag_key, default=False, context=context)

    # Bulk evaluation

    async def get_all_flags(
        self,
        context: EvaluationContext | None = None,
    ) -> dict[str, EvaluationDetails[Any]]:
        """Evaluate all active flags.

        Reads directly from storage (no preload/external cache, no rate
        limiting). Flags whose evaluation fails are logged and omitted; if
        fetching fails, whatever was collected so far (possibly nothing) is
        returned. This method does not raise.

        Args:
            context: Optional evaluation context.

        Returns:
            Dictionary mapping flag keys to their evaluation details.
        """
        results: dict[str, EvaluationDetails[Any]] = {}

        try:
            # Inside the try so a failing context merge cannot escape.
            ctx = self._merge_context(context)
            flags = await self._storage.get_all_active_flags()
            for flag in flags:
                try:
                    results[flag.key] = await self._evaluate_flag(flag, ctx)
                except Exception as e:
                    # Skip failed evaluations in bulk mode.
                    logger.warning("Error evaluating flag '%s': %s", flag.key, e)
        except Exception as e:
            logger.error("Error fetching flags: %s", e)

        return results

    async def get_flags(
        self,
        flag_keys: list[str],
        context: EvaluationContext | None = None,
    ) -> dict[str, EvaluationDetails[Any]]:
        """Evaluate specific flags by key.

        Reads directly from storage (no preload/external cache, no rate
        limiting). Keys not returned by storage are absent from the result,
        and flags whose evaluation fails are logged and omitted. This method
        does not raise.

        Args:
            flag_keys: List of flag keys to evaluate.
            context: Optional evaluation context.

        Returns:
            Dictionary mapping flag keys to their evaluation details.
        """
        results: dict[str, EvaluationDetails[Any]] = {}

        try:
            # Inside the try so a failing context merge cannot escape.
            ctx = self._merge_context(context)
            flags = await self._storage.get_flags(flag_keys)
            for key, flag in flags.items():
                try:
                    results[key] = await self._evaluate_flag(flag, ctx)
                except Exception as e:
                    logger.warning("Error evaluating flag '%s': %s", key, e)
        except Exception as e:
            logger.error("Error fetching flags: %s", e)

        return results

    # Internal methods

    async def _evaluate(
        self,
        flag_key: str,
        default: T,
        expected_type: FlagType,
        context: EvaluationContext | None,
    ) -> EvaluationDetails[T]:
        """Core evaluation logic with error handling.

        This method NEVER throws exceptions - it always returns a result
        with the default value on error. That includes errors from merging
        the context, the rate limiter, cache/storage lookups and the engine.

        Args:
            flag_key: The flag key to evaluate.
            default: Default value on error or not found.
            expected_type: Expected flag type for validation.
            context: Optional evaluation context.

        Returns:
            EvaluationDetails with the evaluated or default value. Missing
            flags yield reason DEFAULT / FLAG_NOT_FOUND; type mismatches yield
            ERROR / TYPE_MISMATCH; any exception yields ERROR / GENERAL_ERROR
            with a sanitized message.
        """
        try:
            # Merged inside the try so a failing merge honours the
            # "never throws" contract above.
            ctx = self._merge_context(context)

            # Check rate limits if rate limiter is configured.
            if self._rate_limiter is not None:
                await self._rate_limiter.acquire(flag_key)

            # Use preload cache, external cache, then fall back to storage.
            flag = await self._get_flag_with_cache(flag_key)

            if flag is None:
                return EvaluationDetails(
                    value=default,
                    flag_key=flag_key,
                    reason=EvaluationReason.DEFAULT,
                    error_code=ErrorCode.FLAG_NOT_FOUND,
                    error_message=f"Flag '{flag_key}' not found",
                )

            # Type validation (skip for boolean as it's always compatible).
            if expected_type != FlagType.BOOLEAN and flag.flag_type != expected_type:
                return EvaluationDetails(
                    value=default,
                    flag_key=flag_key,
                    reason=EvaluationReason.ERROR,
                    error_code=ErrorCode.TYPE_MISMATCH,
                    error_message=f"Expected type '{expected_type.value}', got '{flag.flag_type.value}'",
                )

            result = await self._evaluate_flag(flag, ctx)

            # Re-wrap so the static type matches the caller's expected type.
            return EvaluationDetails(
                value=result.value,  # type: ignore[arg-type]
                flag_key=result.flag_key,
                reason=result.reason,
                variant=result.variant,
                error_code=result.error_code,
                error_message=result.error_message,
                flag_metadata=result.flag_metadata,
            )

        except Exception as e:
            # Import here to avoid circular imports.
            from litestar_flags.exceptions import RateLimitExceededError

            # Sanitize error message to prevent information disclosure.
            safe_error = sanitize_error_message(e)

            # Handle rate limit exceptions specially.
            if isinstance(e, RateLimitExceededError):
                logger.warning("Rate limit exceeded for flag '%s': %s", flag_key, safe_error)
                return EvaluationDetails(
                    value=default,
                    flag_key=flag_key,
                    reason=EvaluationReason.ERROR,
                    error_code=ErrorCode.GENERAL_ERROR,
                    error_message=f"Rate limit exceeded: {safe_error}",
                )

            logger.error("Error evaluating flag '%s': %s", flag_key, safe_error)
            return EvaluationDetails(
                value=default,
                flag_key=flag_key,
                reason=EvaluationReason.ERROR,
                error_code=ErrorCode.GENERAL_ERROR,
                error_message=safe_error,
            )

    async def _evaluate_flag(
        self,
        flag: FeatureFlag,
        context: EvaluationContext,
    ) -> EvaluationDetails[Any]:
        """Evaluate a single flag using the engine.

        Args:
            flag: The flag to evaluate.
            context: The evaluation context.

        Returns:
            EvaluationDetails from the engine.
        """
        return await self._engine.evaluate(flag, context, self._storage)

    def _merge_context(self, context: EvaluationContext | None) -> EvaluationContext:
        """Merge provided context with default context.

        Args:
            context: The provided context (may be None).

        Returns:
            The default context itself when ``context`` is None, otherwise
            ``default_context.merge(context)``.
        """
        if context is None:
            return self._default_context
        return self._default_context.merge(context)

    async def health_check(self) -> bool:
        """Check if the client and storage are healthy.

        Returns:
            False if the client is closed or the storage check raises,
            otherwise the storage backend's own health result.
        """
        if self._closed:
            return False
        try:
            return await self._storage.health_check()
        except Exception:
            return False

    async def close(self) -> None:
        """Close the client and release resources.

        Idempotent: the storage backend is closed only on the first call.
        """
        if not self._closed:
            self._closed = True
            await self._storage.close()

    async def __aenter__(self) -> FeatureFlagClient:
        """Async context manager entry."""
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Async context manager exit: closes the client."""
        await self.close()