Source code for litestar_flags.storage.redis

"""Redis storage backend for feature flags."""

from __future__ import annotations

import json
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID

try:
    from redis.asyncio import Redis
except ImportError as e:
    raise ImportError("Redis backend requires 'redis'. Install with: pip install litestar-flags[redis]") from e

from litestar_flags.types import ChangeType, FlagStatus, FlagType, RecurrenceType

if TYPE_CHECKING:
    from litestar_flags.models.flag import FeatureFlag
    from litestar_flags.models.override import FlagOverride
    from litestar_flags.models.schedule import RolloutPhase, ScheduledFlagChange, TimeSchedule

__all__ = ["RedisStorageBackend"]


[docs] class RedisStorageBackend: """Redis storage backend for distributed feature flags. This backend stores feature flags in Redis, making it suitable for distributed deployments where multiple application instances need to share flag state. Data is stored as JSON strings with the following key patterns: - Flags: {prefix}flag:{key} - Overrides: {prefix}override:{flag_id}:{entity_type}:{entity_id} - Flag index: {prefix}flags (SET of all flag keys) Example: >>> storage = await RedisStorageBackend.create( ... url="redis://localhost:6379", ... prefix="ff:" ... ) >>> flag = await storage.get_flag("my-feature") """
[docs] def __init__(self, redis: Redis, prefix: str = "feature_flags:") -> None: # type: ignore[type-arg] """Initialize the Redis storage backend. Args: redis: The Redis client instance. prefix: Key prefix for all stored data. """ self._redis = redis self._prefix = prefix
[docs] @classmethod async def create( cls, url: str, prefix: str = "feature_flags:", **redis_kwargs: Any, ) -> RedisStorageBackend: """Create a new Redis storage backend. Args: url: Redis connection URL. prefix: Key prefix for all stored data. **redis_kwargs: Additional arguments for Redis.from_url(). Returns: Configured RedisStorageBackend instance. """ redis = Redis.from_url(url, decode_responses=True, **redis_kwargs) # Test connection await redis.ping() return cls(redis=redis, prefix=prefix)
def _flag_key(self, key: str) -> str: """Get the Redis key for a flag.""" return f"{self._prefix}flag:{key}" def _override_key(self, flag_id: UUID, entity_type: str, entity_id: str) -> str: """Get the Redis key for an override.""" return f"{self._prefix}override:{flag_id}:{entity_type}:{entity_id}" def _flags_index_key(self) -> str: """Get the Redis key for the flags index set.""" return f"{self._prefix}flags" def _scheduled_change_key(self, change_id: UUID) -> str: """Get the Redis key for a scheduled change.""" return f"{self._prefix}schedule:{change_id}" def _scheduled_changes_index_key(self, flag_id: UUID | None = None) -> str: """Get the Redis key for the scheduled changes index set.""" if flag_id is not None: return f"{self._prefix}schedules:flag:{flag_id}" return f"{self._prefix}schedules" def _time_schedule_key(self, schedule_id: UUID) -> str: """Get the Redis key for a time schedule.""" return f"{self._prefix}time_schedule:{schedule_id}" def _time_schedules_index_key(self, flag_id: UUID | None = None) -> str: """Get the Redis key for the time schedules index set.""" if flag_id is not None: return f"{self._prefix}time_schedules:flag:{flag_id}" return f"{self._prefix}time_schedules" def _rollout_phase_key(self, phase_id: UUID) -> str: """Get the Redis key for a rollout phase.""" return f"{self._prefix}rollout_phase:{phase_id}" def _rollout_phases_index_key(self, flag_id: UUID) -> str: """Get the Redis key for the rollout phases index set for a flag.""" return f"{self._prefix}rollout_phases:flag:{flag_id}" def _serialize_flag(self, flag: FeatureFlag) -> str: """Serialize a flag to JSON.""" data = { "id": str(flag.id), "key": flag.key, "name": flag.name, "description": flag.description, "flag_type": flag.flag_type.value, "status": flag.status.value, "default_enabled": flag.default_enabled, "default_value": flag.default_value, "tags": flag.tags, "metadata": flag.metadata_, "rules": [ { "id": str(r.id), "name": r.name, "description": r.description, "priority": r.priority, "enabled": r.enabled, "conditions": r.conditions, "serve_enabled": r.serve_enabled, "serve_value": r.serve_value, "rollout_percentage": r.rollout_percentage, } for r in (flag.rules or []) ], "variants": [ { "id": str(v.id), "key": v.key, "name": v.name, "description": v.description, "value": v.value, "weight": v.weight, } for v in (flag.variants or []) ], "created_at": flag.created_at.isoformat() if flag.created_at else None, "updated_at": flag.updated_at.isoformat() if flag.updated_at else None, } return json.dumps(data) def _deserialize_flag(self, data: str) -> FeatureFlag: """Deserialize a flag from JSON.""" from litestar_flags.models.flag import FeatureFlag from litestar_flags.models.rule import FlagRule from litestar_flags.models.variant import FlagVariant obj = json.loads(data) # Create rule objects rules = [ FlagRule( id=UUID(r["id"]), name=r["name"], description=r.get("description"), priority=r["priority"], enabled=r["enabled"], conditions=r["conditions"], serve_enabled=r["serve_enabled"], serve_value=r.get("serve_value"), rollout_percentage=r.get("rollout_percentage"), ) for r in obj.get("rules", []) ] # Create variant objects variants = [ FlagVariant( id=UUID(v["id"]), key=v["key"], name=v["name"], description=v.get("description"), value=v["value"], weight=v["weight"], ) for v in obj.get("variants", []) ] return FeatureFlag( id=UUID(obj["id"]), key=obj["key"], name=obj["name"], description=obj.get("description"), flag_type=FlagType(obj["flag_type"]), status=FlagStatus(obj["status"]), default_enabled=obj["default_enabled"], default_value=obj.get("default_value"), tags=obj.get("tags", []), metadata_=obj.get("metadata", {}), rules=rules, variants=variants, created_at=datetime.fromisoformat(obj["created_at"]) if obj.get("created_at") else None, updated_at=datetime.fromisoformat(obj["updated_at"]) if obj.get("updated_at") else None, ) def _serialize_override(self, override: FlagOverride) -> str: """Serialize an override to JSON.""" data = { "id": str(override.id), "flag_id": str(override.flag_id), "entity_type": override.entity_type, "entity_id": override.entity_id, "enabled": override.enabled, "value": override.value, "expires_at": override.expires_at.isoformat() if override.expires_at else None, "created_at": override.created_at.isoformat() if override.created_at else None, "updated_at": override.updated_at.isoformat() if override.updated_at else None, } return json.dumps(data) def _deserialize_override(self, data: str) -> FlagOverride: """Deserialize an override from JSON.""" from litestar_flags.models.override import FlagOverride obj = json.loads(data) return FlagOverride( id=UUID(obj["id"]), flag_id=UUID(obj["flag_id"]), entity_type=obj["entity_type"], entity_id=obj["entity_id"], enabled=obj["enabled"], value=obj.get("value"), expires_at=datetime.fromisoformat(obj["expires_at"]) if obj.get("expires_at") else None, created_at=datetime.fromisoformat(obj["created_at"]) if obj.get("created_at") else None, updated_at=datetime.fromisoformat(obj["updated_at"]) if obj.get("updated_at") else None, ) def _serialize_scheduled_change(self, change: ScheduledFlagChange) -> str: """Serialize a scheduled change to JSON.""" data = { "id": str(change.id), "flag_id": str(change.flag_id), "change_type": change.change_type.value, "scheduled_at": change.scheduled_at.isoformat(), "executed": change.executed, "executed_at": change.executed_at.isoformat() if change.executed_at else None, "new_value": change.new_value, "new_rollout_percentage": change.new_rollout_percentage, "created_by": change.created_by, "created_at": change.created_at.isoformat() if change.created_at else None, "updated_at": change.updated_at.isoformat() if change.updated_at else None, } return json.dumps(data) def _deserialize_scheduled_change(self, data: str) -> ScheduledFlagChange: """Deserialize a scheduled change from JSON.""" from litestar_flags.models.schedule import ScheduledFlagChange obj = json.loads(data) return ScheduledFlagChange( id=UUID(obj["id"]), flag_id=UUID(obj["flag_id"]), change_type=ChangeType(obj["change_type"]), scheduled_at=datetime.fromisoformat(obj["scheduled_at"]), executed=obj.get("executed", False), executed_at=datetime.fromisoformat(obj["executed_at"]) if obj.get("executed_at") else None, new_value=obj.get("new_value"), new_rollout_percentage=obj.get("new_rollout_percentage"), created_by=obj.get("created_by"), created_at=datetime.fromisoformat(obj["created_at"]) if obj.get("created_at") else None, updated_at=datetime.fromisoformat(obj["updated_at"]) if obj.get("updated_at") else None, ) def _serialize_time_schedule(self, schedule: TimeSchedule) -> str: """Serialize a time schedule to JSON.""" from datetime import time data = { "id": str(schedule.id), "flag_id": str(schedule.flag_id), "name": schedule.name, "recurrence_type": schedule.recurrence_type.value, "start_time": ( schedule.start_time.isoformat() if isinstance(schedule.start_time, time) else schedule.start_time ), "end_time": (schedule.end_time.isoformat() if isinstance(schedule.end_time, time) else schedule.end_time), "days_of_week": schedule.days_of_week, "days_of_month": schedule.days_of_month, "cron_expression": schedule.cron_expression, "timezone": schedule.timezone, "enabled": schedule.enabled, "created_at": schedule.created_at.isoformat() if schedule.created_at else None, "updated_at": schedule.updated_at.isoformat() if schedule.updated_at else None, } return json.dumps(data) def _deserialize_time_schedule(self, data: str) -> TimeSchedule: """Deserialize a time schedule from JSON.""" from datetime import time from litestar_flags.models.schedule import TimeSchedule obj = json.loads(data) # Handle time parsing start_time_str = obj["start_time"] end_time_str = obj["end_time"] start_time = time.fromisoformat(start_time_str) if start_time_str else time(0, 0) end_time = time.fromisoformat(end_time_str) if end_time_str else time(23, 59) return TimeSchedule( id=UUID(obj["id"]), flag_id=UUID(obj["flag_id"]), name=obj["name"], recurrence_type=RecurrenceType(obj["recurrence_type"]), start_time=start_time, end_time=end_time, days_of_week=obj.get("days_of_week"), days_of_month=obj.get("days_of_month"), cron_expression=obj.get("cron_expression"), timezone=obj.get("timezone", "UTC"), enabled=obj.get("enabled", True), created_at=datetime.fromisoformat(obj["created_at"]) if obj.get("created_at") else None, updated_at=datetime.fromisoformat(obj["updated_at"]) if obj.get("updated_at") else None, ) def _serialize_rollout_phase(self, phase: RolloutPhase) -> str: """Serialize a rollout phase to JSON.""" data = { "id": str(phase.id), "flag_id": str(phase.flag_id), "phase_number": phase.phase_number, "target_percentage": phase.target_percentage, "scheduled_at": phase.scheduled_at.isoformat(), "executed": phase.executed, "executed_at": phase.executed_at.isoformat() if phase.executed_at else None, "created_at": phase.created_at.isoformat() if phase.created_at else None, "updated_at": phase.updated_at.isoformat() if phase.updated_at else None, } return json.dumps(data) def _deserialize_rollout_phase(self, data: str) -> RolloutPhase: """Deserialize a rollout phase from JSON.""" from litestar_flags.models.schedule import RolloutPhase obj = json.loads(data) return RolloutPhase( id=UUID(obj["id"]), flag_id=UUID(obj["flag_id"]), phase_number=obj["phase_number"], target_percentage=obj["target_percentage"], scheduled_at=datetime.fromisoformat(obj["scheduled_at"]), executed=obj.get("executed", False), executed_at=datetime.fromisoformat(obj["executed_at"]) if obj.get("executed_at") else None, created_at=datetime.fromisoformat(obj["created_at"]) if obj.get("created_at") else None, updated_at=datetime.fromisoformat(obj["updated_at"]) if obj.get("updated_at") else None, )
[docs] async def get_flag(self, key: str) -> FeatureFlag | None: """Retrieve a single flag by key. Args: key: The unique flag key. Returns: The FeatureFlag if found, None otherwise. """ data = await self._redis.get(self._flag_key(key)) if data is None: return None return self._deserialize_flag(data)
[docs] async def get_flags(self, keys: Sequence[str]) -> dict[str, FeatureFlag]: """Retrieve multiple flags by keys. Args: keys: Sequence of flag keys to retrieve. Returns: Dictionary mapping flag keys to FeatureFlag objects. """ if not keys: return {} redis_keys = [self._flag_key(k) for k in keys] values = await self._redis.mget(redis_keys) result = {} for key, value in zip(keys, values, strict=False): if value is not None: result[key] = self._deserialize_flag(value) return result
[docs] async def get_all_active_flags(self) -> list[FeatureFlag]: """Retrieve all active flags. Returns: List of all FeatureFlag objects with ACTIVE status. """ # Get all flag keys from the index keys = await self._redis.smembers(self._flags_index_key()) if not keys: return [] # Get all flags flags = await self.get_flags(list(keys)) # Filter to active only return [f for f in flags.values() if f.status == FlagStatus.ACTIVE]
[docs] async def get_override( self, flag_id: UUID, entity_type: str, entity_id: str, ) -> FlagOverride | None: """Retrieve entity-specific override. Args: flag_id: The flag's UUID. entity_type: Type of entity (e.g., "user", "organization"). entity_id: The entity's identifier. Returns: The FlagOverride if found and not expired, None otherwise. """ key = self._override_key(flag_id, entity_type, entity_id) data = await self._redis.get(key) if data is None: return None override = self._deserialize_override(data) # Check expiration if override.is_expired(datetime.now(UTC)): await self._redis.delete(key) return None return override
[docs] async def create_flag(self, flag: FeatureFlag) -> FeatureFlag: """Create a new flag. Args: flag: The flag to create. Returns: The created flag. """ now = datetime.now(UTC) if flag.created_at is None: flag.created_at = now # type: ignore[misc] if flag.updated_at is None: flag.updated_at = now # type: ignore[misc] # Store flag await self._redis.set(self._flag_key(flag.key), self._serialize_flag(flag)) # Add to index await self._redis.sadd(self._flags_index_key(), flag.key) return flag
[docs] async def update_flag(self, flag: FeatureFlag) -> FeatureFlag: """Update an existing flag. Args: flag: The flag with updated values. Returns: The updated flag. """ flag.updated_at = datetime.now(UTC) # type: ignore[misc] await self._redis.set(self._flag_key(flag.key), self._serialize_flag(flag)) return flag
[docs] async def delete_flag(self, key: str) -> bool: """Delete a flag by key. Args: key: The unique flag key. Returns: True if the flag was deleted, False if not found. """ # Get flag to find its ID for override cleanup flag = await self.get_flag(key) if flag is None: return False # Delete flag await self._redis.delete(self._flag_key(key)) # Remove from index await self._redis.srem(self._flags_index_key(), key) # Note: Overrides are not automatically cleaned up # This could be handled with a pattern scan if needed return True
[docs] async def create_override(self, override: FlagOverride) -> FlagOverride: """Create a new override. Args: override: The override to create. Returns: The created override. """ if override.flag_id is None: raise ValueError("Override must have a flag_id") now = datetime.now(UTC) if override.created_at is None: override.created_at = now # type: ignore[misc] if override.updated_at is None: override.updated_at = now # type: ignore[misc] key = self._override_key(override.flag_id, override.entity_type, override.entity_id) await self._redis.set(key, self._serialize_override(override)) # Set TTL if expires_at is set if override.expires_at: ttl = int((override.expires_at - now).total_seconds()) if ttl > 0: await self._redis.expire(key, ttl) return override
[docs] async def delete_override( self, flag_id: UUID, entity_type: str, entity_id: str, ) -> bool: """Delete an override. Args: flag_id: The flag's UUID. entity_type: Type of entity. entity_id: The entity's identifier. Returns: True if the override was deleted, False if not found. """ key = self._override_key(flag_id, entity_type, entity_id) result = await self._redis.delete(key) return result > 0
[docs] async def health_check(self) -> bool: """Check storage backend health. Returns: True if the backend is healthy, False otherwise. """ try: await self._redis.ping() return True except Exception: return False
[docs] async def close(self) -> None: """Close Redis connections.""" await self._redis.aclose()
# Scheduled changes methods
[docs] async def get_scheduled_changes( self, flag_id: UUID | None = None, pending_only: bool = True, ) -> list[ScheduledFlagChange]: """Get scheduled changes, optionally filtered by flag and status. Args: flag_id: If provided, filter to changes for this flag only. pending_only: If True, only return changes not yet executed. Returns: List of scheduled changes matching the criteria. """ # Get all change IDs from the appropriate index index_key = self._scheduled_changes_index_key(flag_id) change_ids = await self._redis.smembers(index_key) if not change_ids: # If filtering by flag_id and no index, try the global index if flag_id is not None: global_ids = await self._redis.smembers(self._scheduled_changes_index_key()) change_ids = global_ids else: return [] # Get all changes result = [] for change_id in change_ids: key = self._scheduled_change_key(UUID(change_id)) data = await self._redis.get(key) if data: change = self._deserialize_scheduled_change(data) # Filter by flag_id if provided if flag_id is not None and change.flag_id != flag_id: continue # Filter by pending status if requested if pending_only and change.executed: continue result.append(change) # Sort by scheduled_at result.sort(key=lambda c: c.scheduled_at) return result
[docs] async def create_scheduled_change( self, change: ScheduledFlagChange, ) -> ScheduledFlagChange: """Create a new scheduled change. Args: change: The scheduled change to create. Returns: The created scheduled change. """ now = datetime.now(UTC) if change.created_at is None: change.created_at = now # type: ignore[misc] if change.updated_at is None: change.updated_at = now # type: ignore[misc] # Store the change key = self._scheduled_change_key(change.id) await self._redis.set(key, self._serialize_scheduled_change(change)) # Add to global index await self._redis.sadd(self._scheduled_changes_index_key(), str(change.id)) # Add to flag-specific index await self._redis.sadd(self._scheduled_changes_index_key(change.flag_id), str(change.id)) return change
[docs] async def update_scheduled_change( self, change: ScheduledFlagChange, ) -> ScheduledFlagChange: """Update a scheduled change (e.g., mark as executed). Args: change: The scheduled change with updated values. Returns: The updated scheduled change. """ change.updated_at = datetime.now(UTC) # type: ignore[misc] key = self._scheduled_change_key(change.id) await self._redis.set(key, self._serialize_scheduled_change(change)) return change
# Time schedule methods
[docs] async def get_time_schedules( self, flag_id: UUID | None = None, ) -> list[TimeSchedule]: """Get time schedules for a flag or all flags. Args: flag_id: If provided, filter to schedules for this flag only. Returns: List of time schedules matching the criteria. """ # Get all schedule IDs from the appropriate index index_key = self._time_schedules_index_key(flag_id) schedule_ids = await self._redis.smembers(index_key) if not schedule_ids: if flag_id is not None: # No schedules for this flag return [] # Try global index schedule_ids = await self._redis.smembers(self._time_schedules_index_key()) if not schedule_ids: return [] # Get all schedules result = [] for schedule_id in schedule_ids: key = self._time_schedule_key(UUID(schedule_id)) data = await self._redis.get(key) if data: schedule = self._deserialize_time_schedule(data) # Filter by flag_id if provided if flag_id is not None and schedule.flag_id != flag_id: continue result.append(schedule) return result
[docs] async def create_time_schedule( self, schedule: TimeSchedule, ) -> TimeSchedule: """Create a new time schedule. Args: schedule: The time schedule to create. Returns: The created time schedule. """ now = datetime.now(UTC) if schedule.created_at is None: schedule.created_at = now # type: ignore[misc] if schedule.updated_at is None: schedule.updated_at = now # type: ignore[misc] # Store the schedule key = self._time_schedule_key(schedule.id) await self._redis.set(key, self._serialize_time_schedule(schedule)) # Add to global index await self._redis.sadd(self._time_schedules_index_key(), str(schedule.id)) # Add to flag-specific index await self._redis.sadd(self._time_schedules_index_key(schedule.flag_id), str(schedule.id)) return schedule
[docs] async def delete_time_schedule(self, schedule_id: UUID) -> bool: """Delete a time schedule. Args: schedule_id: The UUID of the time schedule to delete. Returns: True if the schedule was deleted, False if not found. """ key = self._time_schedule_key(schedule_id) data = await self._redis.get(key) if data is None: return False schedule = self._deserialize_time_schedule(data) # Delete the schedule await self._redis.delete(key) # Remove from global index await self._redis.srem(self._time_schedules_index_key(), str(schedule_id)) # Remove from flag-specific index await self._redis.srem(self._time_schedules_index_key(schedule.flag_id), str(schedule_id)) return True
# Rollout phase methods
[docs] async def get_rollout_phases(self, flag_id: UUID) -> list[RolloutPhase]: """Get rollout phases for a flag. Args: flag_id: The UUID of the flag. Returns: List of rollout phases for the flag, ordered by phase number. """ # Get all phase IDs from the flag-specific index index_key = self._rollout_phases_index_key(flag_id) phase_ids = await self._redis.smembers(index_key) if not phase_ids: return [] # Get all phases result = [] for phase_id in phase_ids: key = self._rollout_phase_key(UUID(phase_id)) data = await self._redis.get(key) if data: phase = self._deserialize_rollout_phase(data) result.append(phase) # Sort by phase_number result.sort(key=lambda p: p.phase_number) return result
[docs] async def create_rollout_phase(self, phase: RolloutPhase) -> RolloutPhase: """Create a new rollout phase. Args: phase: The rollout phase to create. Returns: The created rollout phase. """ now = datetime.now(UTC) if phase.created_at is None: phase.created_at = now # type: ignore[misc] if phase.updated_at is None: phase.updated_at = now # type: ignore[misc] # Store the phase key = self._rollout_phase_key(phase.id) await self._redis.set(key, self._serialize_rollout_phase(phase)) # Add to flag-specific index await self._redis.sadd(self._rollout_phases_index_key(phase.flag_id), str(phase.id)) return phase