"""Database storage backend using Advanced-Alchemy."""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
from uuid import UUID
from litestar_flags.models.base import HAS_ADVANCED_ALCHEMY
from litestar_flags.types import FlagStatus
if not HAS_ADVANCED_ALCHEMY:
raise ImportError(
"Database backend requires 'advanced-alchemy'. Install with: pip install litestar-flags[database]"
)
from advanced_alchemy.repository import SQLAlchemyAsyncRepository
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from litestar_flags.models.flag import FeatureFlag
from litestar_flags.models.override import FlagOverride
from litestar_flags.models.schedule import RolloutPhase, ScheduledFlagChange, TimeSchedule
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine
__all__ = [
"DatabaseStorageBackend",
"FeatureFlagRepository",
"FlagOverrideRepository",
"RolloutPhaseRepository",
"ScheduledFlagChangeRepository",
"TimeScheduleRepository",
]
[docs]
class FeatureFlagRepository(SQLAlchemyAsyncRepository[FeatureFlag]):
    """Async repository exposing key- and status-based lookups for feature flags."""

    model_type = FeatureFlag

    async def get_by_key(self, key: str) -> FeatureFlag | None:
        """Look up a single flag by its key.

        Args:
            key: Value compared against ``FeatureFlag.key``.

        Returns:
            The matching flag, or ``None`` when nothing matches.
        """
        key_matches = FeatureFlag.key == key
        return await self.get_one_or_none(key_matches)

    async def get_by_keys(self, keys: Sequence[str]) -> list[FeatureFlag]:
        """Look up every flag whose key appears in ``keys``.

        Args:
            keys: Flag keys to search for.

        Returns:
            The flags that were found; an empty list when ``keys`` is empty.
        """
        if keys:
            return await self.list(FeatureFlag.key.in_(keys))
        # Nothing to look up: skip the query entirely.
        return []

    async def get_active_flags(self) -> list[FeatureFlag]:
        """List the flags whose status equals ``FlagStatus.ACTIVE``.

        Returns:
            All flags currently marked active.
        """
        is_active = FeatureFlag.status == FlagStatus.ACTIVE
        return await self.list(is_active)
[docs]
class FlagOverrideRepository(SQLAlchemyAsyncRepository[FlagOverride]):
    """Async repository for looking up per-entity flag overrides."""

    model_type = FlagOverride

    async def get_override(
        self,
        flag_id: UUID,
        entity_type: str,
        entity_id: str,
    ) -> FlagOverride | None:
        """Find the override a given entity has for a given flag.

        Args:
            flag_id: UUID of the flag the override belongs to.
            entity_type: Type of entity the override targets.
            entity_id: Identifier of the targeted entity.

        Returns:
            The matching override, or ``None`` when none exists.
        """
        # All three conditions must hold for a row to match.
        criteria = (
            FlagOverride.flag_id == flag_id,
            FlagOverride.entity_type == entity_type,
            FlagOverride.entity_id == entity_id,
        )
        return await self.get_one_or_none(*criteria)
[docs]
class ScheduledFlagChangeRepository(SQLAlchemyAsyncRepository[ScheduledFlagChange]):
    """Repository for scheduled flag change CRUD operations."""

    model_type = ScheduledFlagChange

    async def get_pending_changes(
        self,
        flag_id: UUID | None = None,
    ) -> list[ScheduledFlagChange]:
        """Get pending (not yet executed) scheduled changes.

        Args:
            flag_id: If provided, filter to changes for this flag only.

        Returns:
            List of pending scheduled changes, ordered by scheduled_at.
        """
        filters = [ScheduledFlagChange.executed == False]  # noqa: E712
        if flag_id is not None:
            filters.append(ScheduledFlagChange.flag_id == flag_id)
        return await self.list(
            *filters,
            # The repository expects ``(attribute, is_descending)`` pairs;
            # ``False`` requests ascending order.
            order_by=[(ScheduledFlagChange.scheduled_at, False)],
        )

    async def get_all_changes(
        self,
        flag_id: UUID | None = None,
    ) -> list[ScheduledFlagChange]:
        """Get all scheduled changes (pending and executed).

        Args:
            flag_id: If provided, filter to changes for this flag only.

        Returns:
            List of scheduled changes, ordered by scheduled_at.
        """
        filters = []
        if flag_id is not None:
            filters.append(ScheduledFlagChange.flag_id == flag_id)
        return await self.list(
            *filters,
            # ``(attribute, is_descending)`` pair; ``False`` means ascending.
            order_by=[(ScheduledFlagChange.scheduled_at, False)],
        )
[docs]
class TimeScheduleRepository(SQLAlchemyAsyncRepository[TimeSchedule]):
    """Async repository for querying time schedules attached to flags."""

    model_type = TimeSchedule

    async def get_schedules_for_flag(self, flag_id: UUID) -> list[TimeSchedule]:
        """List every time schedule that belongs to one flag.

        Args:
            flag_id: UUID of the flag whose schedules are wanted.

        Returns:
            The time schedules whose ``flag_id`` matches.
        """
        belongs_to_flag = TimeSchedule.flag_id == flag_id
        return await self.list(belongs_to_flag)

    async def get_enabled_schedules(
        self,
        flag_id: UUID | None = None,
    ) -> list[TimeSchedule]:
        """List the time schedules that are enabled.

        Args:
            flag_id: When given, restrict the result to this flag's schedules.

        Returns:
            The enabled time schedules matching the criteria.
        """
        conditions = [TimeSchedule.enabled == True]  # noqa: E712
        if flag_id is not None:
            conditions.append(TimeSchedule.flag_id == flag_id)
        return await self.list(*conditions)
[docs]
class RolloutPhaseRepository(SQLAlchemyAsyncRepository[RolloutPhase]):
    """Repository for rollout phase CRUD operations."""

    model_type = RolloutPhase

    async def get_phases_for_flag(self, flag_id: UUID) -> list[RolloutPhase]:
        """Get all rollout phases for a flag.

        Args:
            flag_id: The flag's UUID.

        Returns:
            List of rollout phases, ordered by phase_number.
        """
        return await self.list(
            RolloutPhase.flag_id == flag_id,
            # The repository expects ``(attribute, is_descending)`` pairs;
            # ``False`` requests ascending order.
            order_by=[(RolloutPhase.phase_number, False)],
        )

    async def get_pending_phases(self, flag_id: UUID) -> list[RolloutPhase]:
        """Get pending (not yet executed) rollout phases for a flag.

        Args:
            flag_id: The flag's UUID.

        Returns:
            List of pending rollout phases, ordered by phase_number.
        """
        return await self.list(
            RolloutPhase.flag_id == flag_id,
            RolloutPhase.executed == False,  # noqa: E712
            # ``(attribute, is_descending)`` pair; ``False`` means ascending.
            order_by=[(RolloutPhase.phase_number, False)],
        )
[docs]
class DatabaseStorageBackend:
"""Database storage backend using Advanced-Alchemy.
This backend stores feature flags and related data in a relational
database using SQLAlchemy async operations.
Example:
>>> storage = await DatabaseStorageBackend.create(
... connection_string="postgresql+asyncpg://user:pass@localhost/db"
... )
>>> flag = await storage.get_flag("my-feature")
"""
[docs]
def __init__(
self,
engine: AsyncEngine,
session_maker: async_sessionmaker[AsyncSession],
) -> None:
"""Initialize the database storage backend.
Args:
engine: The SQLAlchemy async engine.
session_maker: The session maker factory.
"""
self._engine = engine
self._session_maker = session_maker
[docs]
@classmethod
async def create(
    cls,
    connection_string: str,
    table_prefix: str = "ff_",
    create_tables: bool = True,
    **engine_kwargs: object,
) -> DatabaseStorageBackend:
    """Create a new database storage backend.

    Args:
        connection_string: Database connection string.
        table_prefix: Prefix for table names (accepted but not currently used).
        create_tables: Whether to create tables on startup.
        **engine_kwargs: Additional arguments for create_async_engine.
            ``echo`` defaults to ``False`` when not supplied.

    Returns:
        Configured DatabaseStorageBackend instance.

    Raises:
        Exception: Any error raised while creating the tables is re-raised
            after the freshly created engine has been disposed.
    """
    engine_kwargs.setdefault("echo", False)
    engine = create_async_engine(connection_string, **engine_kwargs)

    if create_tables:
        try:
            from advanced_alchemy.base import orm_registry

            from litestar_flags.models.flag import FeatureFlag
            from litestar_flags.models.override import FlagOverride
            from litestar_flags.models.rule import FlagRule
            from litestar_flags.models.schedule import (
                RolloutPhase,
                ScheduledFlagChange,
                TimeSchedule,
            )
            from litestar_flags.models.variant import FlagVariant

            # The imports exist only to register the models; the tuple
            # below keeps linters from flagging them as unused.
            _ = (
                FeatureFlag,
                FlagOverride,
                FlagRule,
                FlagVariant,
                ScheduledFlagChange,
                TimeSchedule,
                RolloutPhase,
            )

            async with engine.begin() as conn:
                await conn.run_sync(orm_registry.metadata.create_all)
        except BaseException:
            # The caller never receives the engine on failure, so release
            # it here instead of leaking it.
            await engine.dispose()
            raise

    session_maker = async_sessionmaker(engine, expire_on_commit=False)
    return cls(engine=engine, session_maker=session_maker)
[docs]
async def get_flag(self, key: str) -> FeatureFlag | None:
    """Fetch one flag by key using a short-lived session.

    Args:
        key: The unique flag key to look up.

    Returns:
        The matching FeatureFlag, or ``None`` when no flag has that key.
    """
    async with self._session_maker() as db_session:
        flags = FeatureFlagRepository(session=db_session)
        found = await flags.get_by_key(key)
        return found
[docs]
async def get_flags(self, keys: Sequence[str]) -> dict[str, FeatureFlag]:
    """Fetch several flags at once and index them by key.

    Args:
        keys: Flag keys to retrieve.

    Returns:
        Mapping from each found flag's key to its FeatureFlag object.
    """
    async with self._session_maker() as db_session:
        found = await FeatureFlagRepository(session=db_session).get_by_keys(keys)
    # Index by key once the session has been released.
    return {item.key: item for item in found}
[docs]
async def get_all_active_flags(self) -> list[FeatureFlag]:
    """Fetch every flag the repository reports as active.

    Returns:
        List of FeatureFlag objects with ACTIVE status.
    """
    async with self._session_maker() as db_session:
        flags = FeatureFlagRepository(session=db_session)
        active = await flags.get_active_flags()
        return active
[docs]
async def get_override(
    self,
    flag_id: UUID,
    entity_type: str,
    entity_id: str,
) -> FlagOverride | None:
    """Fetch the override a specific entity has for a flag.

    Args:
        flag_id: UUID of the flag.
        entity_type: Type of entity (e.g., "user", "organization").
        entity_id: Identifier of the entity.

    Returns:
        The matching FlagOverride, or ``None`` when none exists.
    """
    async with self._session_maker() as db_session:
        overrides = FlagOverrideRepository(session=db_session)
        return await overrides.get_override(
            flag_id=flag_id,
            entity_type=entity_type,
            entity_id=entity_id,
        )
[docs]
async def create_flag(self, flag: FeatureFlag) -> FeatureFlag:
    """Persist a new flag.

    Args:
        flag: The flag to insert.

    Returns:
        The stored flag, refreshed from the database after commit.
    """
    async with self._session_maker() as db_session:
        stored = await FeatureFlagRepository(session=db_session).add(flag)
        await db_session.commit()
        # Reload so generated fields are present on the returned object.
        await db_session.refresh(stored)
    return stored
[docs]
async def update_flag(self, flag: FeatureFlag) -> FeatureFlag:
    """Write changes to an existing flag.

    Args:
        flag: The flag carrying the updated values.

    Returns:
        The updated flag, refreshed from the database after commit.
    """
    async with self._session_maker() as db_session:
        saved = await FeatureFlagRepository(session=db_session).update(flag)
        await db_session.commit()
        await db_session.refresh(saved)
    return saved
[docs]
async def delete_flag(self, key: str) -> bool:
    """Remove the flag identified by ``key``.

    Args:
        key: The unique flag key.

    Returns:
        True when a flag was found and deleted, False when no flag has that key.
    """
    async with self._session_maker() as db_session:
        flags = FeatureFlagRepository(session=db_session)
        existing = await flags.get_by_key(key)
        was_found = existing is not None
        if was_found:
            # Deletion goes by primary key, then is committed.
            await flags.delete(existing.id)
            await db_session.commit()
        return was_found
[docs]
async def create_override(self, override: FlagOverride) -> FlagOverride:
    """Persist a new override.

    Args:
        override: The override to insert.

    Returns:
        The stored override, refreshed from the database after commit.
    """
    async with self._session_maker() as db_session:
        stored = await FlagOverrideRepository(session=db_session).add(override)
        await db_session.commit()
        await db_session.refresh(stored)
    return stored
[docs]
async def delete_override(
    self,
    flag_id: UUID,
    entity_type: str,
    entity_id: str,
) -> bool:
    """Remove the override a specific entity has for a flag.

    Args:
        flag_id: UUID of the flag.
        entity_type: Type of entity.
        entity_id: Identifier of the entity.

    Returns:
        True when an override was found and deleted, False when none matched.
    """
    async with self._session_maker() as db_session:
        overrides = FlagOverrideRepository(session=db_session)
        existing = await overrides.get_override(flag_id, entity_type, entity_id)
        was_found = existing is not None
        if was_found:
            # Deletion goes by primary key, then is committed.
            await overrides.delete(existing.id)
            await db_session.commit()
        return was_found
[docs]
async def health_check(self) -> bool:
    """Check storage backend health.

    Opens a session and executes ``SELECT 1``. Any ``Exception`` raised
    along the way is logged at debug level and reported as unhealthy
    rather than propagated.

    Returns:
        True if the backend is healthy, False otherwise.
    """
    import logging

    try:
        async with self._session_maker() as session:
            await session.execute(select(1))
    except Exception:
        # Best-effort probe: report failure instead of raising, but keep
        # the cause available for diagnosis.
        logging.getLogger(__name__).debug(
            "Database health check failed",
            exc_info=True,
        )
        return False
    return True
[docs]
async def close(self) -> None:
    """Dispose of the engine held by this backend."""
    engine = self._engine
    await engine.dispose()
# Scheduled changes methods
[docs]
async def get_scheduled_changes(
    self,
    flag_id: UUID | None = None,
    pending_only: bool = True,
) -> list[ScheduledFlagChange]:
    """Fetch scheduled changes, optionally narrowed by flag and status.

    Args:
        flag_id: When given, restrict results to this flag's changes.
        pending_only: When True, return only changes not yet executed;
            otherwise return pending and executed changes alike.

    Returns:
        List of scheduled changes matching the criteria.
    """
    async with self._session_maker() as db_session:
        changes = ScheduledFlagChangeRepository(session=db_session)
        # Pick the repository query that matches the requested status filter.
        fetch = changes.get_pending_changes if pending_only else changes.get_all_changes
        return await fetch(flag_id)
[docs]
async def create_scheduled_change(
    self,
    change: ScheduledFlagChange,
) -> ScheduledFlagChange:
    """Persist a new scheduled change.

    Args:
        change: The scheduled change to insert.

    Returns:
        The stored scheduled change, refreshed from the database after commit.
    """
    async with self._session_maker() as db_session:
        stored = await ScheduledFlagChangeRepository(session=db_session).add(change)
        await db_session.commit()
        await db_session.refresh(stored)
    return stored
[docs]
async def update_scheduled_change(
    self,
    change: ScheduledFlagChange,
) -> ScheduledFlagChange:
    """Write changes to a scheduled change (e.g., mark it as executed).

    Args:
        change: The scheduled change carrying the updated values.

    Returns:
        The updated scheduled change, refreshed from the database after commit.
    """
    async with self._session_maker() as db_session:
        saved = await ScheduledFlagChangeRepository(session=db_session).update(change)
        await db_session.commit()
        await db_session.refresh(saved)
    return saved
# Time schedule methods
[docs]
async def get_time_schedules(
    self,
    flag_id: UUID | None = None,
) -> list[TimeSchedule]:
    """Fetch time schedules for one flag, or for every flag.

    Args:
        flag_id: When given, restrict results to this flag's schedules.

    Returns:
        List of time schedules matching the criteria.
    """
    async with self._session_maker() as db_session:
        schedules = TimeScheduleRepository(session=db_session)
        if flag_id is None:
            # No flag given: return every stored schedule.
            return await schedules.list()
        return await schedules.get_schedules_for_flag(flag_id)
[docs]
async def create_time_schedule(
    self,
    schedule: TimeSchedule,
) -> TimeSchedule:
    """Persist a new time schedule.

    Args:
        schedule: The time schedule to insert.

    Returns:
        The stored time schedule, refreshed from the database after commit.
    """
    async with self._session_maker() as db_session:
        stored = await TimeScheduleRepository(session=db_session).add(schedule)
        await db_session.commit()
        await db_session.refresh(stored)
    return stored
[docs]
async def delete_time_schedule(self, schedule_id: UUID) -> bool:
    """Delete a time schedule.

    Args:
        schedule_id: The UUID of the time schedule to delete.

    Returns:
        True if the schedule was deleted, False if not found.
    """
    async with self._session_maker() as session:
        repo = TimeScheduleRepository(session=session)
        # ``repo.get`` raises NotFoundError for an unknown id instead of
        # returning None, so the existence check uses ``get_one_or_none``
        # to honour the documented False return.
        # NOTE(review): assumes the primary-key attribute is ``id``, as on
        # the sibling models whose ``.id`` is passed to ``delete``.
        schedule = await repo.get_one_or_none(TimeSchedule.id == schedule_id)
        if schedule is None:
            return False
        await repo.delete(schedule_id)
        await session.commit()
        return True
# Rollout phase methods
[docs]
async def get_rollout_phases(self, flag_id: UUID) -> list[RolloutPhase]:
    """Fetch the rollout phases that belong to a flag.

    Args:
        flag_id: UUID of the flag.

    Returns:
        List of rollout phases for the flag, ordered by phase number.
    """
    async with self._session_maker() as db_session:
        phases = RolloutPhaseRepository(session=db_session)
        found = await phases.get_phases_for_flag(flag_id)
        return found
[docs]
async def create_rollout_phase(self, phase: RolloutPhase) -> RolloutPhase:
    """Persist a new rollout phase.

    Args:
        phase: The rollout phase to insert.

    Returns:
        The stored rollout phase, refreshed from the database after commit.
    """
    async with self._session_maker() as db_session:
        stored = await RolloutPhaseRepository(session=db_session).add(phase)
        await db_session.commit()
        await db_session.refresh(stored)
    return stored