"""In-memory storage backend for feature flags."""
from __future__ import annotations

from collections.abc import Sequence
from datetime import UTC, datetime
from operator import attrgetter
from typing import TYPE_CHECKING
from uuid import UUID

from litestar_flags.types import FlagStatus

if TYPE_CHECKING:
    from litestar_flags.models.flag import FeatureFlag
    from litestar_flags.models.override import FlagOverride
    from litestar_flags.models.schedule import RolloutPhase, ScheduledFlagChange, TimeSchedule
    from litestar_flags.models.segment import Segment

__all__ = ["MemoryStorageBackend"]
[docs]
class MemoryStorageBackend:
"""In-memory storage backend for development and testing.
This backend stores all data in memory and is not persistent.
Ideal for development, testing, and simple single-instance deployments.
Example:
>>> storage = MemoryStorageBackend()
>>> await storage.create_flag(flag)
>>> flag = await storage.get_flag("my-feature")
"""
[docs]
def __init__(self) -> None:
"""Initialize the in-memory storage."""
self._flags: dict[str, FeatureFlag] = {}
self._flags_by_id: dict[UUID, FeatureFlag] = {}
self._overrides: dict[str, FlagOverride] = {}
self._scheduled_changes: dict[UUID, ScheduledFlagChange] = {}
self._time_schedules: dict[UUID, TimeSchedule] = {}
self._rollout_phases: dict[UUID, RolloutPhase] = {}
self._segments: dict[UUID, Segment] = {}
self._segments_by_name: dict[str, Segment] = {}
def _override_key(self, flag_id: UUID, entity_type: str, entity_id: str) -> str:
"""Generate a unique key for an override."""
return f"{flag_id}:{entity_type}:{entity_id}"
[docs]
async def get_flag(self, key: str) -> FeatureFlag | None:
"""Retrieve a single flag by key.
Args:
key: The unique flag key.
Returns:
The FeatureFlag if found, None otherwise.
"""
return self._flags.get(key)
[docs]
async def get_flags(self, keys: Sequence[str]) -> dict[str, FeatureFlag]:
"""Retrieve multiple flags by keys.
Args:
keys: Sequence of flag keys to retrieve.
Returns:
Dictionary mapping flag keys to FeatureFlag objects.
"""
return {key: flag for key in keys if (flag := self._flags.get(key)) is not None}
[docs]
async def get_all_active_flags(self) -> list[FeatureFlag]:
"""Retrieve all active flags.
Returns:
List of all FeatureFlag objects with ACTIVE status.
"""
return [flag for flag in self._flags.values() if flag.status == FlagStatus.ACTIVE]
[docs]
async def get_override(
self,
flag_id: UUID,
entity_type: str,
entity_id: str,
) -> FlagOverride | None:
"""Retrieve entity-specific override.
Args:
flag_id: The flag's UUID.
entity_type: Type of entity (e.g., "user", "organization").
entity_id: The entity's identifier.
Returns:
The FlagOverride if found and not expired, None otherwise.
"""
key = self._override_key(flag_id, entity_type, entity_id)
override = self._overrides.get(key)
if override is not None and override.is_expired(datetime.now(UTC)):
# Remove expired override
del self._overrides[key]
return None
return override
[docs]
async def get_overrides_for_entity(
self,
entity_type: str,
entity_id: str,
) -> list[FlagOverride]:
"""Retrieve all overrides for an entity.
Args:
entity_type: Type of entity (e.g., "user", "organization").
entity_id: The entity's identifier.
Returns:
List of non-expired overrides for the entity.
"""
now = datetime.now(UTC)
result = []
expired_keys = []
for key, override in self._overrides.items():
if override.entity_type == entity_type and override.entity_id == entity_id:
if override.is_expired(now):
expired_keys.append(key)
else:
result.append(override)
# Clean up expired overrides
for key in expired_keys:
del self._overrides[key]
return result
[docs]
async def create_flag(self, flag: FeatureFlag) -> FeatureFlag:
"""Create a new flag.
Args:
flag: The flag to create.
Returns:
The created flag.
Raises:
ValueError: If a flag with the same key already exists.
"""
if flag.key in self._flags:
raise ValueError(f"Flag with key '{flag.key}' already exists")
# Set timestamps if not present
now = datetime.now(UTC)
if flag.created_at is None:
flag.created_at = now # type: ignore[misc]
if flag.updated_at is None:
flag.updated_at = now # type: ignore[misc]
self._flags[flag.key] = flag
self._flags_by_id[flag.id] = flag
return flag
[docs]
async def update_flag(self, flag: FeatureFlag) -> FeatureFlag:
"""Update an existing flag.
Args:
flag: The flag with updated values.
Returns:
The updated flag.
Raises:
ValueError: If the flag does not exist.
"""
if flag.key not in self._flags:
raise ValueError(f"Flag with key '{flag.key}' not found")
flag.updated_at = datetime.now(UTC) # type: ignore[misc]
self._flags[flag.key] = flag
self._flags_by_id[flag.id] = flag
return flag
[docs]
async def delete_flag(self, key: str) -> bool:
"""Delete a flag by key.
Args:
key: The unique flag key.
Returns:
True if the flag was deleted, False if not found.
"""
flag = self._flags.pop(key, None)
if flag is not None:
self._flags_by_id.pop(flag.id, None)
# Remove associated overrides
keys_to_remove = [k for k in self._overrides if k.startswith(f"{flag.id}:")]
for k in keys_to_remove:
del self._overrides[k]
return True
return False
[docs]
async def create_override(self, override: FlagOverride) -> FlagOverride:
"""Create a new override.
Args:
override: The override to create.
Returns:
The created override.
"""
if override.flag_id is None:
raise ValueError("Override must have a flag_id")
key = self._override_key(override.flag_id, override.entity_type, override.entity_id)
now = datetime.now(UTC)
if override.created_at is None:
override.created_at = now # type: ignore[misc]
if override.updated_at is None:
override.updated_at = now # type: ignore[misc]
self._overrides[key] = override
return override
[docs]
async def delete_override(
self,
flag_id: UUID,
entity_type: str,
entity_id: str,
) -> bool:
"""Delete an override.
Args:
flag_id: The flag's UUID.
entity_type: Type of entity.
entity_id: The entity's identifier.
Returns:
True if the override was deleted, False if not found.
"""
key = self._override_key(flag_id, entity_type, entity_id)
if key in self._overrides:
del self._overrides[key]
return True
return False
[docs]
async def health_check(self) -> bool:
"""Check storage backend health.
Returns:
Always True for in-memory storage.
"""
return True
[docs]
async def close(self) -> None:
"""Close the storage backend.
Clears all data from memory.
"""
self._flags.clear()
self._flags_by_id.clear()
self._overrides.clear()
self._scheduled_changes.clear()
self._time_schedules.clear()
self._rollout_phases.clear()
self._segments.clear()
self._segments_by_name.clear()
def __len__(self) -> int:
"""Return the number of flags stored."""
return len(self._flags)
# Scheduled changes methods
[docs]
async def get_scheduled_changes(
self,
flag_id: UUID | None = None,
pending_only: bool = True,
) -> list[ScheduledFlagChange]:
"""Get scheduled changes, optionally filtered by flag and status.
Args:
flag_id: If provided, filter to changes for this flag only.
pending_only: If True, only return changes not yet executed.
Returns:
List of scheduled changes matching the criteria.
"""
result = []
for change in self._scheduled_changes.values():
# Filter by flag_id if provided
if flag_id is not None and change.flag_id != flag_id:
continue
# Filter by pending status if requested
if pending_only and change.executed:
continue
result.append(change)
# Sort by scheduled_at
result.sort(key=lambda c: c.scheduled_at)
return result
[docs]
async def create_scheduled_change(
self,
change: ScheduledFlagChange,
) -> ScheduledFlagChange:
"""Create a new scheduled change.
Args:
change: The scheduled change to create.
Returns:
The created scheduled change.
"""
now = datetime.now(UTC)
if change.created_at is None:
change.created_at = now # type: ignore[misc]
if change.updated_at is None:
change.updated_at = now # type: ignore[misc]
self._scheduled_changes[change.id] = change
return change
[docs]
async def update_scheduled_change(
self,
change: ScheduledFlagChange,
) -> ScheduledFlagChange:
"""Update a scheduled change (e.g., mark as executed).
Args:
change: The scheduled change with updated values.
Returns:
The updated scheduled change.
Raises:
ValueError: If the scheduled change does not exist.
"""
if change.id not in self._scheduled_changes:
raise ValueError(f"Scheduled change with id '{change.id}' not found")
change.updated_at = datetime.now(UTC) # type: ignore[misc]
self._scheduled_changes[change.id] = change
return change
# Time schedule methods
[docs]
async def get_time_schedules(
self,
flag_id: UUID | None = None,
) -> list[TimeSchedule]:
"""Get time schedules for a flag or all flags.
Args:
flag_id: If provided, filter to schedules for this flag only.
Returns:
List of time schedules matching the criteria.
"""
if flag_id is None:
return list(self._time_schedules.values())
return [schedule for schedule in self._time_schedules.values() if schedule.flag_id == flag_id]
[docs]
async def create_time_schedule(
self,
schedule: TimeSchedule,
) -> TimeSchedule:
"""Create a new time schedule.
Args:
schedule: The time schedule to create.
Returns:
The created time schedule.
"""
now = datetime.now(UTC)
if schedule.created_at is None:
schedule.created_at = now # type: ignore[misc]
if schedule.updated_at is None:
schedule.updated_at = now # type: ignore[misc]
self._time_schedules[schedule.id] = schedule
return schedule
[docs]
async def delete_time_schedule(self, schedule_id: UUID) -> bool:
"""Delete a time schedule.
Args:
schedule_id: The UUID of the time schedule to delete.
Returns:
True if the schedule was deleted, False if not found.
"""
if schedule_id in self._time_schedules:
del self._time_schedules[schedule_id]
return True
return False
# Rollout phase methods
[docs]
async def get_rollout_phases(self, flag_id: UUID) -> list[RolloutPhase]:
"""Get rollout phases for a flag.
Args:
flag_id: The UUID of the flag.
Returns:
List of rollout phases for the flag, ordered by phase number.
"""
phases = [phase for phase in self._rollout_phases.values() if phase.flag_id == flag_id]
# Sort by phase_number
phases.sort(key=lambda p: p.phase_number)
return phases
[docs]
async def create_rollout_phase(self, phase: RolloutPhase) -> RolloutPhase:
"""Create a new rollout phase.
Args:
phase: The rollout phase to create.
Returns:
The created rollout phase.
"""
now = datetime.now(UTC)
if phase.created_at is None:
phase.created_at = now # type: ignore[misc]
if phase.updated_at is None:
phase.updated_at = now # type: ignore[misc]
self._rollout_phases[phase.id] = phase
return phase
# Segment methods
[docs]
async def get_segment(self, segment_id: UUID) -> Segment | None:
"""Retrieve a segment by ID.
Args:
segment_id: The UUID of the segment.
Returns:
The Segment if found, None otherwise.
"""
return self._segments.get(segment_id)
[docs]
async def get_segment_by_name(self, name: str) -> Segment | None:
"""Retrieve a segment by name.
Args:
name: The unique segment name.
Returns:
The Segment if found, None otherwise.
"""
return self._segments_by_name.get(name)
[docs]
async def get_all_segments(self) -> list[Segment]:
"""Retrieve all segments.
Returns:
List of all Segment objects.
"""
return list(self._segments.values())
[docs]
async def get_child_segments(self, parent_id: UUID) -> list[Segment]:
"""Retrieve all child segments of a parent segment.
Args:
parent_id: The UUID of the parent segment.
Returns:
List of child Segment objects.
"""
return [segment for segment in self._segments.values() if segment.parent_segment_id == parent_id]
[docs]
async def create_segment(self, segment: Segment) -> Segment:
"""Create a new segment.
Args:
segment: The segment to create.
Returns:
The created segment.
Raises:
ValueError: If a segment with the same name already exists.
"""
if segment.name in self._segments_by_name:
raise ValueError(f"Segment with name '{segment.name}' already exists")
now = datetime.now(UTC)
if segment.created_at is None:
segment.created_at = now # type: ignore[misc]
if segment.updated_at is None:
segment.updated_at = now # type: ignore[misc]
self._segments[segment.id] = segment
self._segments_by_name[segment.name] = segment
return segment
[docs]
async def update_segment(self, segment: Segment) -> Segment:
"""Update an existing segment.
Args:
segment: The segment with updated values.
Returns:
The updated segment.
Raises:
ValueError: If the segment does not exist or name conflict occurs.
"""
if segment.id not in self._segments:
raise ValueError(f"Segment with id '{segment.id}' not found")
old_segment = self._segments[segment.id]
# Handle name change
if old_segment.name != segment.name:
if segment.name in self._segments_by_name:
raise ValueError(f"Segment with name '{segment.name}' already exists")
del self._segments_by_name[old_segment.name]
self._segments_by_name[segment.name] = segment
segment.updated_at = datetime.now(UTC) # type: ignore[misc]
self._segments[segment.id] = segment
return segment
[docs]
async def delete_segment(self, segment_id: UUID) -> bool:
"""Delete a segment by ID.
Args:
segment_id: The UUID of the segment to delete.
Returns:
True if the segment was deleted, False if not found.
"""
segment = self._segments.pop(segment_id, None)
if segment is not None:
self._segments_by_name.pop(segment.name, None)
return True
return False