Source code for litestar_flags.analytics.aggregator

"""Analytics aggregator for computing metrics from feature flag evaluation events."""

from __future__ import annotations

from collections import Counter
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, Any

from litestar_flags.analytics.collectors.memory import InMemoryAnalyticsCollector
from litestar_flags.analytics.models import FlagEvaluationEvent
from litestar_flags.types import EvaluationReason

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession

__all__ = ["AnalyticsAggregator", "FlagMetrics"]


[docs] @dataclass(slots=True) class FlagMetrics: """Aggregated metrics for a feature flag. Contains computed statistics about flag evaluations including evaluation rate, unique users, distributions, and latency percentiles. Attributes: evaluation_rate: Evaluations per second in the measurement window. unique_users: Count of unique targeting keys in the window. variant_distribution: Count of evaluations per variant. reason_distribution: Count of evaluations per reason. error_rate: Percentage of evaluations that resulted in errors (0-100). latency_p50: 50th percentile latency in milliseconds. latency_p90: 90th percentile latency in milliseconds. latency_p99: 99th percentile latency in milliseconds. total_evaluations: Total number of evaluations in the window. window_start: Start of the measurement window. window_end: End of the measurement window. Example: >>> metrics = FlagMetrics( ... evaluation_rate=10.5, ... unique_users=150, ... variant_distribution={"control": 75, "treatment": 75}, ... reason_distribution={"SPLIT": 150}, ... error_rate=0.0, ... latency_p50=1.2, ... latency_p90=2.5, ... latency_p99=5.0, ... ) """ evaluation_rate: float = 0.0 unique_users: int = 0 variant_distribution: dict[str, int] = field(default_factory=dict) reason_distribution: dict[str, int] = field(default_factory=dict) error_rate: float = 0.0 latency_p50: float = 0.0 latency_p90: float = 0.0 latency_p99: float = 0.0 total_evaluations: int = 0 window_start: datetime | None = None window_end: datetime | None = None
[docs] def to_dict(self) -> dict[str, Any]: """Convert metrics to dictionary representation. Returns: Dictionary representation of the metrics. """ return { "evaluation_rate": self.evaluation_rate, "unique_users": self.unique_users, "variant_distribution": self.variant_distribution, "reason_distribution": self.reason_distribution, "error_rate": self.error_rate, "latency_p50": self.latency_p50, "latency_p90": self.latency_p90, "latency_p99": self.latency_p99, "total_evaluations": self.total_evaluations, "window_start": self.window_start.isoformat() if self.window_start else None, "window_end": self.window_end.isoformat() if self.window_end else None, }
[docs] class AnalyticsAggregator: """Aggregator for computing metrics from feature flag evaluation events. Supports multiple event sources including in-memory collectors and database sessions. Provides methods for computing various metrics including evaluation rates, unique users, distributions, and latencies. The aggregator uses window-based aggregation, only considering events within the specified time window for each metric calculation. Attributes: source: The event source (InMemoryAnalyticsCollector or AsyncSession). Example: >>> from litestar_flags.analytics import InMemoryAnalyticsCollector >>> collector = InMemoryAnalyticsCollector() >>> aggregator = AnalyticsAggregator(collector) >>> rate = await aggregator.get_evaluation_rate("my_flag", window_seconds=60) >>> metrics = await aggregator.get_flag_metrics("my_flag") """
[docs] def __init__( self, source: InMemoryAnalyticsCollector | AsyncSession, ) -> None: """Initialize the analytics aggregator. Args: source: The event source to aggregate from. Can be an InMemoryAnalyticsCollector for in-memory events or an AsyncSession for database-backed events. """ self._source = source self._is_memory_source = isinstance(source, InMemoryAnalyticsCollector)
async def _get_events_in_window( self, flag_key: str, window_seconds: int, ) -> list[FlagEvaluationEvent]: """Get events from the source within a time window. Args: flag_key: The flag key to filter events. window_seconds: Number of seconds to look back. Returns: List of events within the time window. """ since = datetime.now(UTC) - timedelta(seconds=window_seconds) if self._is_memory_source: collector = self._source if isinstance(collector, InMemoryAnalyticsCollector): # Get all events for the flag and filter by timestamp all_events = await collector.get_events(flag_key=flag_key) return [e for e in all_events if e.timestamp >= since] # For AsyncSession, query the database return await self._get_events_from_database(flag_key, window_seconds) async def _get_events_from_database( self, flag_key: str, window_seconds: int, ) -> list[FlagEvaluationEvent]: """Get events from database within a time window. Args: flag_key: The flag key to filter events. window_seconds: Number of seconds to look back. Returns: List of events within the time window. Note: This requires the analytics_events table to be defined. The implementation expects a table with columns matching the AnalyticsEventModel fields. """ if self._is_memory_source: return await self._get_events_in_window(flag_key, window_seconds) # Import here to avoid circular imports and optional dependency try: from sqlalchemy import select from litestar_flags.analytics.models import AnalyticsEventModel except ImportError: return [] session = self._source since = datetime.now(UTC) - timedelta(seconds=window_seconds) stmt = select(AnalyticsEventModel).where( AnalyticsEventModel.flag_key == flag_key, # type: ignore[arg-type] AnalyticsEventModel.timestamp >= since, # type: ignore[arg-type] ) result = await session.execute(stmt) # type: ignore[union-attr] rows = result.scalars().all() # Convert database models to FlagEvaluationEvent events = [] for row in rows: # Extract value from JSON storage value = row.value.get("value") if row.value else None events.append( FlagEvaluationEvent( timestamp=row.timestamp, flag_key=row.flag_key, value=value, reason=EvaluationReason(row.reason) if row.reason else EvaluationReason.DEFAULT, variant=row.variant, targeting_key=row.targeting_key, context_attributes=row.context_attributes or {}, evaluation_duration_ms=row.evaluation_duration_ms or 0.0, ) ) return events def _is_error_event(self, event: FlagEvaluationEvent) -> bool: """Check if an event represents an error. Args: event: The event to check. Returns: True if the event has ERROR reason. """ return event.reason == EvaluationReason.ERROR
[docs] async def get_evaluation_rate( self, flag_key: str, window_seconds: int = 60, ) -> float: """Calculate the evaluation rate for a flag. Args: flag_key: The key of the flag to measure. window_seconds: The time window in seconds (default: 60). Returns: Evaluations per second within the time window. """ events = await self._get_events_in_window(flag_key, window_seconds) if not events or window_seconds <= 0: return 0.0 return len(events) / window_seconds
[docs] async def get_unique_users( self, flag_key: str, window_seconds: int = 3600, ) -> int: """Count unique targeting keys for a flag. Args: flag_key: The key of the flag to measure. window_seconds: The time window in seconds (default: 3600). Returns: Count of unique targeting keys within the time window. """ events = await self._get_events_in_window(flag_key, window_seconds) unique_keys = {event.targeting_key for event in events if event.targeting_key} return len(unique_keys)
[docs] async def get_variant_distribution( self, flag_key: str, window_seconds: int = 3600, ) -> dict[str, int]: """Get the distribution of variants for a flag. Args: flag_key: The key of the flag to measure. window_seconds: The time window in seconds (default: 3600). Returns: Dictionary mapping variant names to evaluation counts. """ events = await self._get_events_in_window(flag_key, window_seconds) counter: Counter[str] = Counter() for event in events: variant = event.variant or "default" counter[variant] += 1 return dict(counter)
[docs] async def get_reason_distribution( self, flag_key: str, window_seconds: int = 3600, ) -> dict[str, int]: """Get the distribution of evaluation reasons for a flag. Args: flag_key: The key of the flag to measure. window_seconds: The time window in seconds (default: 3600). Returns: Dictionary mapping reason strings to evaluation counts. """ events = await self._get_events_in_window(flag_key, window_seconds) counter: Counter[str] = Counter() for event in events: reason_str = event.reason.value if isinstance(event.reason, EvaluationReason) else str(event.reason) counter[reason_str] += 1 return dict(counter)
[docs] async def get_error_rate( self, flag_key: str, window_seconds: int = 3600, ) -> float: """Calculate the error rate for a flag. Args: flag_key: The key of the flag to measure. window_seconds: The time window in seconds (default: 3600). Returns: Percentage of evaluations that resulted in errors (0-100). """ events = await self._get_events_in_window(flag_key, window_seconds) if not events: return 0.0 error_count = sum(1 for event in events if self._is_error_event(event)) return (error_count / len(events)) * 100
[docs] async def get_latency_percentiles( self, flag_key: str, percentiles: list[float] | None = None, ) -> dict[float, float]: """Calculate latency percentiles for a flag. Args: flag_key: The key of the flag to measure. percentiles: List of percentiles to calculate (default: [50, 90, 99]). Returns: Dictionary mapping percentile values to latencies in milliseconds. """ if percentiles is None: percentiles = [50.0, 90.0, 99.0] events = await self._get_events_in_window(flag_key, window_seconds=3600) latencies = [event.evaluation_duration_ms for event in events if event.evaluation_duration_ms > 0] if not latencies: return dict.fromkeys(percentiles, 0.0) if len(latencies) == 1: # With only one data point, all percentiles are the same return dict.fromkeys(percentiles, latencies[0]) # Calculate quantiles using linear interpolation result: dict[float, float] = {} sorted_latencies = sorted(latencies) for p in percentiles: if p <= 0 or p >= 100: result[p] = 0.0 continue index = (p / 100) * (len(sorted_latencies) - 1) lower_idx = int(index) upper_idx = min(lower_idx + 1, len(sorted_latencies) - 1) fraction = index - lower_idx # Linear interpolation result[p] = sorted_latencies[lower_idx] + fraction * ( sorted_latencies[upper_idx] - sorted_latencies[lower_idx] ) return result
[docs] async def get_flag_metrics( self, flag_key: str, window_seconds: int = 3600, ) -> FlagMetrics: """Get all aggregated metrics for a flag. This is a convenience method that computes all available metrics in a single call. Args: flag_key: The key of the flag to measure. window_seconds: The time window in seconds (default: 3600). Returns: FlagMetrics object containing all computed metrics. """ events = await self._get_events_in_window(flag_key, window_seconds) now = datetime.now(UTC) window_start = now - timedelta(seconds=window_seconds) if not events: return FlagMetrics( window_start=window_start, window_end=now, ) # Calculate all metrics from the same event set total = len(events) unique_keys = {event.targeting_key for event in events if event.targeting_key} # Variant distribution variant_counter: Counter[str] = Counter() for event in events: variant = event.variant or "default" variant_counter[variant] += 1 # Reason distribution reason_counter: Counter[str] = Counter() error_count = 0 for event in events: reason_str = event.reason.value if isinstance(event.reason, EvaluationReason) else str(event.reason) reason_counter[reason_str] += 1 if self._is_error_event(event): error_count += 1 # Latency percentiles latencies = [event.evaluation_duration_ms for event in events if event.evaluation_duration_ms > 0] p50, p90, p99 = 0.0, 0.0, 0.0 if latencies: sorted_latencies = sorted(latencies) n = len(sorted_latencies) if n == 1: p50 = p90 = p99 = sorted_latencies[0] else: p50 = self._percentile(sorted_latencies, 50) p90 = self._percentile(sorted_latencies, 90) p99 = self._percentile(sorted_latencies, 99) return FlagMetrics( evaluation_rate=total / window_seconds if window_seconds > 0 else 0.0, unique_users=len(unique_keys), variant_distribution=dict(variant_counter), reason_distribution=dict(reason_counter), error_rate=(error_count / total) * 100 if total > 0 else 0.0, latency_p50=p50, latency_p90=p90, latency_p99=p99, total_evaluations=total, window_start=window_start, window_end=now, )
def _percentile(self, sorted_data: list[float], p: float) -> float: """Calculate a percentile from sorted data. Args: sorted_data: Pre-sorted list of values. p: Percentile to calculate (0-100). Returns: The percentile value with linear interpolation. """ if not sorted_data: return 0.0 if len(sorted_data) == 1: return sorted_data[0] index = (p / 100) * (len(sorted_data) - 1) lower_idx = int(index) upper_idx = min(lower_idx + 1, len(sorted_data) - 1) fraction = index - lower_idx return sorted_data[lower_idx] + fraction * (sorted_data[upper_idx] - sorted_data[lower_idx])