Stage 2 Prompt 5: Prediction Pipeline System
OBJECTIVE
Implement a comprehensive prediction pipeline system that provides native Elixir execution of ML predictions with advanced monitoring, optimization, and multi-provider coordination. This system must deliver execution history tracking, performance metrics collection, adaptive strategy selection, and result validation while maintaining complete DSPy predict API compatibility and achieving 10x performance improvements through native concurrency and intelligent caching.
COMPLETE IMPLEMENTATION CONTEXT
PREDICTION PIPELINE ARCHITECTURE OVERVIEW
From Stage 2 Technical Specification:
┌─────────────────────────────────────────────────────────────┐
│ Prediction Pipeline System │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ ┌─────────────────┐ ┌──────────────┐│
│ │ Execution │ │ History │ │ Performance ││
│ │ Engine │ │ Tracking │ │ Metrics ││
│ │ - Strategies │ │ - Storage │ │ - Collection ││
│ │ - Coordination │ │ - Analysis │ │ - Analysis ││
│ │ - Optimization │ │ - Retrieval │ │ - Reporting ││
│ └─────────────────┘ └─────────────────┘ └──────────────┘│
│ │
│ ┌─────────────────┐ ┌─────────────────┐ ┌──────────────┐│
│ │ Provider │ │ Strategy │ │ Result ││
│ │ Coordination │ │ Selection │ │ Validation ││
│ │ - Multi-provider │ │ - Adaptive │ │ - Quality ││
│ │ - Fallbacks │ │ - Performance │ │ - Consistency││
│ │ - Load Balance │ │ - Cost-aware │ │ - Assessment ││
│ └─────────────────┘ └─────────────────┘ └──────────────┘│
└─────────────────────────────────────────────────────────────┘
DSPy PREDICTION SYSTEM ANALYSIS
From comprehensive DSPy source code analysis (predict/predict.py):
DSPy Predict Core Patterns:
# DSPy Predict base class with execution coordination
class Predict(Module):
def __init__(self, signature, **kwargs):
super().__init__()
self.signature = signature
self.demos = []
self.trace = []
self.config = kwargs
def forward(self, **kwargs):
"""Execute prediction with tracing."""
# Prepare prediction context
context = self._prepare_context(**kwargs)
# Execute prediction with strategy
with prediction_trace() as trace:
result = self._execute_prediction(context)
self.trace.append(trace)
# Validate and return result
validated_result = self._validate_result(result)
return validated_result
def _prepare_context(self, **kwargs):
"""Prepare execution context."""
context = {
'signature': self.signature,
'inputs': kwargs,
'demos': self.demos,
'config': self.config
}
return context
def _execute_prediction(self, context):
"""Execute prediction with selected strategy."""
strategy = self._select_strategy(context)
# Generate prompt
prompt = strategy.generate_prompt(context)
# Execute with provider
provider = self._get_provider()
raw_result = provider.generate(prompt, **self.config)
# Parse result
parsed_result = strategy.parse_result(raw_result, context)
return parsed_result
def _validate_result(self, result):
"""Validate prediction result."""
# Type validation
for field_name, field in self.signature.output_fields.items():
if field_name not in result:
raise ValueError(f"Missing output field: {field_name}")
# Validate type and constraints
field.validate(result[field_name])
return result
# Prediction execution tracking
class PredictionTrace:
def __init__(self):
self.start_time = time.time()
self.end_time = None
self.prompt = None
self.raw_output = None
self.parsed_output = None
self.provider = None
self.tokens_used = 0
self.errors = []
def complete(self):
self.end_time = time.time()
self.duration = self.end_time - self.start_time
def add_error(self, error):
self.errors.append({
'timestamp': time.time(),
'error': str(error),
'type': type(error).__name__
})
# Strategy pattern for prediction execution
class PredictionStrategy:
def generate_prompt(self, context):
"""Generate prompt from context."""
raise NotImplementedError
def parse_result(self, raw_result, context):
"""Parse raw result into structured output."""
raise NotImplementedError
# Example prediction strategies
class StandardStrategy(PredictionStrategy):
def generate_prompt(self, context):
# Build prompt with signature and demos
prompt_parts = []
# Add instructions
if context['signature'].instructions:
prompt_parts.append(context['signature'].instructions)
# Add demos
for demo in context['demos']:
prompt_parts.append(self._format_demo(demo))
# Add current input
prompt_parts.append(self._format_input(context['inputs']))
return "\n\n".join(prompt_parts)
def parse_result(self, raw_result, context):
# Extract fields from raw output
parsed = {}
for field_name, field in context['signature'].output_fields.items():
value = self._extract_field(raw_result, field_name, field)
parsed[field_name] = value
return parsed
Key DSPy Prediction Features:
- Execution Strategies - Multiple strategies for prompt generation and parsing
- Trace Tracking - Comprehensive execution history and debugging
- Provider Coordination - Flexible provider selection and execution
- Result Validation - Type checking and constraint validation
- Performance Monitoring - Token usage and execution time tracking
EXDANTIC INTEGRATION FOR PREDICTION VALIDATION
From ExDantic research for prediction result validation:
defmodule ExDantic.Prediction do
@moduledoc """
Prediction result validation with ExDantic integration.
"""
use ExDantic
# Prediction result schema
typed_schema "prediction_result" do
field :outputs, :map, required: true
field :metadata, PredictionMetadata
field :trace, PredictionTrace
# Custom validation
validate :outputs, &validate_output_fields/1
end
# Metadata schema
typed_schema "prediction_metadata" do
field :provider, :string, required: true
field :model, :string, required: true
field :temperature, :float
field :tokens_used, :integer
field :duration_ms, :integer
field :strategy, :string
end
# Trace entry schema
typed_schema "prediction_trace" do
field :timestamp, :utc_datetime
field :event_type, :string
field :details, :map
field :error, :string
end
end
NATIVE ELIXIR IMPLEMENTATION
CORE PREDICTION ENGINE
defmodule DSPex.Prediction.Engine do
@moduledoc """
Core prediction execution engine with strategy selection and monitoring.
Provides:
- Native prediction execution with multiple strategies
- Comprehensive execution history tracking
- Performance metrics collection and analysis
- Multi-provider coordination with fallbacks
- Adaptive strategy selection based on performance
- Result validation and quality assessment
"""
use GenServer
require Logger
alias DSPex.Prediction.{
Strategy,
History,
Metrics,
Validator,
ProviderCoordinator
}
@type prediction_opts :: [
strategy: atom(),
provider: atom(),
timeout: timeout(),
max_retries: non_neg_integer(),
cache: boolean(),
trace: boolean()
]
@type prediction_result :: %{
outputs: map(),
metadata: map(),
trace: list(map()),
metrics: map()
}
# Client API
@doc """
Start the prediction engine.
"""
def start_link(opts \\ []) do
GenServer.start_link(__MODULE__, opts, name: __MODULE__)
end
@doc """
Execute a prediction with the given signature and inputs.
"""
@spec predict(DSPex.Signature.t(), map(), prediction_opts()) ::
{:ok, prediction_result()} | {:error, term()}
def predict(signature, inputs, opts \\ []) do
GenServer.call(__MODULE__, {:predict, signature, inputs, opts},
Keyword.get(opts, :timeout, 30_000))
end
@doc """
Execute multiple predictions in parallel.
"""
@spec predict_batch([{DSPex.Signature.t(), map()}], prediction_opts()) ::
{:ok, [prediction_result()]} | {:error, term()}
def predict_batch(predictions, opts \\ []) do
GenServer.call(__MODULE__, {:predict_batch, predictions, opts},
Keyword.get(opts, :timeout, 60_000))
end
@doc """
Get prediction metrics for analysis.
"""
@spec get_metrics(keyword()) :: {:ok, map()} | {:error, term()}
def get_metrics(filters \\ []) do
GenServer.call(__MODULE__, {:get_metrics, filters})
end
# Server callbacks
@impl true
def init(opts) do
# Initialize components
{:ok, history} = History.start_link()
{:ok, metrics} = Metrics.start_link()
{:ok, coordinator} = ProviderCoordinator.start_link()
# Load strategies
strategies = load_strategies()
# Initialize caches
:ets.new(:prediction_cache, [:named_table, :public, :set])
:ets.new(:strategy_performance, [:named_table, :public, :set])
state = %{
history: history,
metrics: metrics,
coordinator: coordinator,
strategies: strategies,
opts: opts
}
{:ok, state}
end
@impl true
def handle_call({:predict, signature, inputs, opts}, _from, state) do
result = execute_prediction(signature, inputs, opts, state)
{:reply, result, state}
end
@impl true
def handle_call({:predict_batch, predictions, opts}, _from, state) do
result = execute_batch_predictions(predictions, opts, state)
{:reply, result, state}
end
@impl true
def handle_call({:get_metrics, filters}, _from, state) do
result = Metrics.get_metrics(state.metrics, filters)
{:reply, result, state}
end
# Private functions
  defp execute_prediction(signature, inputs, opts, state) do
    trace_id = generate_trace_id()
    start_time = System.monotonic_time(:millisecond)
    cache_key = generate_cache_key(signature, inputs, opts)

    case lookup_cache(cache_key, opts) do
      {:ok, cached} ->
        # Cache hit: return the previously built result, flagged as cached
        {:ok, put_in(cached, [:metadata, :from_cache], true)}

      :not_found ->
        with {:ok, context} <- prepare_context(signature, inputs, opts, cache_key, state),
             {:ok, strategy} <- select_strategy(context, state),
             {:ok, prompt} <- generate_prompt(strategy, context),
             {:ok, raw_result} <- execute_with_provider(prompt, context, state),
             {:ok, parsed_result} <- parse_result(strategy, raw_result, context),
             {:ok, validated_result} <- validate_result(parsed_result, signature) do
          # Record metrics
          duration = System.monotonic_time(:millisecond) - start_time
          record_success_metrics(trace_id, strategy, duration, state)

          # Build final result
          result = build_prediction_result(validated_result, trace_id, strategy, duration, context)

          # Store in history and cache for identical future requests
          History.record(state.history, result)
          cache_result(cache_key, result)

          {:ok, result}
        else
          {:error, reason} = error ->
            duration = System.monotonic_time(:millisecond) - start_time
            record_error_metrics(trace_id, reason, duration, state)
            error
        end
    end
  end
defp execute_batch_predictions(predictions, opts, state) do
# Execute predictions in parallel with controlled concurrency
max_concurrency = Keyword.get(opts, :max_concurrency, 10)
predictions
|> Task.async_stream(
fn {signature, inputs} ->
execute_prediction(signature, inputs, opts, state)
end,
max_concurrency: max_concurrency,
timeout: Keyword.get(opts, :timeout, 30_000)
)
|> Enum.reduce_while({:ok, []}, fn
{:ok, {:ok, result}}, {:ok, results} ->
{:cont, {:ok, [result | results]}}
{:ok, {:error, _} = error}, _acc ->
{:halt, error}
{:exit, reason}, _acc ->
{:halt, {:error, {:batch_execution_failed, reason}}}
end)
|> case do
{:ok, results} -> {:ok, Enum.reverse(results)}
error -> error
end
end
  defp prepare_context(signature, inputs, opts, cache_key, state) do
    context = %{
      signature: signature,
      inputs: inputs,
      opts: opts,
      demos: get_demos(signature, state),
      trace: [],
      cache_key: cache_key
    }

    {:ok, context}
  end

  defp lookup_cache(cache_key, opts) do
    # Only consult the cache when enabled (defaults to true)
    if Keyword.get(opts, :cache, true) do
      get_cached_result(cache_key)
    else
      :not_found
    end
  end
defp select_strategy(context, state) do
strategy_name = case Keyword.get(context.opts, :strategy) do
nil -> select_adaptive_strategy(context, state)
name -> name
end
case Map.get(state.strategies, strategy_name) do
nil -> {:error, {:unknown_strategy, strategy_name}}
strategy -> {:ok, strategy}
end
end
defp select_adaptive_strategy(context, state) do
# Select strategy based on performance history
signature_type = classify_signature(context.signature)
# Get performance stats for each strategy
strategy_stats = Enum.map(state.strategies, fn {name, _strategy} ->
stats = get_strategy_performance(name, signature_type)
{name, calculate_strategy_score(stats)}
end)
# Select best performing strategy
{best_strategy, _score} = Enum.max_by(strategy_stats, &elem(&1, 1))
best_strategy
end
defp generate_prompt(strategy, context) do
try do
prompt = Strategy.generate_prompt(strategy, context)
{:ok, prompt}
rescue
e -> {:error, {:prompt_generation_failed, Exception.message(e)}}
end
end
defp execute_with_provider(prompt, context, state) do
provider_opts = build_provider_opts(context)
ProviderCoordinator.execute(
state.coordinator,
prompt,
provider_opts
)
end
defp parse_result(strategy, raw_result, context) do
try do
parsed = Strategy.parse_result(strategy, raw_result, context)
{:ok, parsed}
rescue
e -> {:error, {:parsing_failed, Exception.message(e)}}
end
end
defp validate_result(parsed_result, signature) do
Validator.validate(parsed_result, signature)
end
  defp build_prediction_result(outputs, trace_id, strategy, duration, context) do
    %{
      outputs: outputs,
      inputs: context.inputs,
      metadata: %{
        trace_id: trace_id,
        signature_name: context.signature.name,
        # Strategies are plain modules, not structs, so store the module name
        strategy: strategy,
        duration_ms: duration,
        timestamp: DateTime.utc_now(),
        from_cache: Map.get(context, :from_cache, false)
      },
      trace: context.trace,
      metrics: %{
        tokens_used: calculate_tokens(context),
        prompt_length: String.length(context[:prompt] || ""),
        output_length: calculate_output_length(outputs)
      }
    }
  end
  defp record_success_metrics(trace_id, strategy, duration, state) do
    Metrics.record(state.metrics, %{
      event: :prediction_success,
      trace_id: trace_id,
      strategy: strategy,
      duration_ms: duration,
      timestamp: System.system_time(:millisecond)
    })

    # Update strategy performance (strategies are identified by module name)
    update_strategy_performance(strategy, :success, duration)
  end
defp record_error_metrics(trace_id, reason, duration, state) do
Metrics.record(state.metrics, %{
event: :prediction_error,
trace_id: trace_id,
error: inspect(reason),
duration_ms: duration,
timestamp: System.system_time(:millisecond)
})
end
defp get_cached_result(cache_key) do
case :ets.lookup(:prediction_cache, cache_key) do
[{^cache_key, result, expiry}] ->
if System.system_time(:second) < expiry do
{:ok, result}
else
:ets.delete(:prediction_cache, cache_key)
:not_found
end
[] -> :not_found
end
end
defp cache_result(cache_key, result, ttl_seconds \\ 300) do
expiry = System.system_time(:second) + ttl_seconds
:ets.insert(:prediction_cache, {cache_key, result, expiry})
end
defp generate_trace_id do
:crypto.strong_rand_bytes(16) |> Base.encode16(case: :lower)
end
defp generate_cache_key(signature, inputs, opts) do
data = {signature.name, inputs, Keyword.take(opts, [:strategy, :provider])}
:crypto.hash(:sha256, :erlang.term_to_binary(data))
|> Base.encode16(case: :lower)
end
defp load_strategies do
%{
standard: DSPex.Prediction.Strategy.Standard,
cot: DSPex.Prediction.Strategy.ChainOfThought,
react: DSPex.Prediction.Strategy.React,
program: DSPex.Prediction.Strategy.Program
}
end
end
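For orientation, a minimal usage sketch of the engine API above. The `%DSPex.Signature{}` value is hypothetical and mirrors the test fixtures later in this document; the option names come from the `prediction_opts` type:

# Hypothetical signature shaped like the test fixtures below
signature = %DSPex.Signature{
  name: "qa",
  instructions: "Answer questions concisely.",
  input_fields: %{question: %{type: :string, required: true}},
  output_fields: %{answer: %{type: :string, required: true}}
}

{:ok, _pid} = DSPex.Prediction.Engine.start_link()

# Single prediction with an explicit strategy and caching enabled
{:ok, result} =
  DSPex.Prediction.Engine.predict(signature, %{question: "What is 2+2?"},
    strategy: :standard,
    cache: true
  )

result.outputs.answer
result.metadata.duration_ms

# Batch execution with bounded concurrency
{:ok, results} =
  DSPex.Prediction.Engine.predict_batch(
    [{signature, %{question: "What is 2+2?"}}, {signature, %{question: "Capital of France?"}}],
    max_concurrency: 5
  )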
PREDICTION STRATEGY SYSTEM
defmodule DSPex.Prediction.Strategy do
  @moduledoc """
  Behaviour and implementations for prediction execution strategies.
  """

  @callback generate_prompt(context :: map()) :: String.t()
  @callback parse_result(raw_result :: String.t(), context :: map()) :: map()
  @callback supports_signature?(signature :: DSPex.Signature.t()) :: boolean()

  @doc "Dispatch prompt generation to the given strategy module."
  def generate_prompt(strategy_module, context), do: strategy_module.generate_prompt(context)

  @doc "Dispatch result parsing to the given strategy module."
  def parse_result(strategy_module, raw_result, context),
    do: strategy_module.parse_result(raw_result, context)

  defmacro __using__(_opts) do
    quote do
      @behaviour DSPex.Prediction.Strategy

      def supports_signature?(_signature), do: true

      defoverridable supports_signature?: 1
    end
  end
end
defmodule DSPex.Prediction.Strategy.Standard do
@moduledoc """
Standard prediction strategy with basic prompt generation and parsing.
"""
use DSPex.Prediction.Strategy
  @impl true
  def generate_prompt(context) do
    %{signature: signature, inputs: inputs, demos: demos} = context

    # Rebinding inside `if` does not escape its scope in Elixir, so build the
    # prompt parts with explicit list construction instead of conditional rebinding.
    instruction_parts = if signature.instructions, do: [signature.instructions], else: []
    demo_parts = if demos == [], do: [], else: [format_demos(demos, signature)]

    prompt_parts =
      instruction_parts ++
        demo_parts ++
        [format_inputs(inputs, signature), format_output_prompts(signature)]

    Enum.join(prompt_parts, "\n\n")
  end
@impl true
def parse_result(raw_result, context) do
%{signature: signature} = context
# Parse each output field
signature.output_fields
|> Enum.reduce(%{}, fn {field_name, field_spec}, acc ->
value = extract_field_value(raw_result, field_name, field_spec)
Map.put(acc, field_name, value)
end)
end
defp format_demos(demos, signature) do
demos
|> Enum.with_index(1)
|> Enum.map(fn {demo, idx} ->
"""
Example #{idx}:
#{format_demo_inputs(demo.inputs, signature)}
#{format_demo_outputs(demo.outputs, signature)}
"""
end)
|> Enum.join("\n")
end
defp format_inputs(inputs, signature) do
signature.input_fields
|> Enum.map(fn {field_name, field_spec} ->
value = Map.get(inputs, field_name, "")
"#{field_spec.desc || String.capitalize(to_string(field_name))}: #{value}"
end)
|> Enum.join("\n")
end
defp format_output_prompts(signature) do
signature.output_fields
|> Enum.map(fn {field_name, field_spec} ->
"#{field_spec.desc || String.capitalize(to_string(field_name))}:"
end)
|> Enum.join("\n")
end
defp extract_field_value(raw_result, field_name, field_spec) do
# Use regex or string parsing to extract field value
pattern = ~r/#{Regex.escape(field_spec.desc || to_string(field_name))}:\s*(.+?)(?:\n|$)/i
case Regex.run(pattern, raw_result) do
[_, value] -> String.trim(value)
_ -> extract_with_fallback(raw_result, field_name)
end
end
defp extract_with_fallback(raw_result, field_name) do
# Fallback extraction strategies
lines = String.split(raw_result, "\n")
# Look for field name in various formats
Enum.find_value(lines, fn line ->
cond do
String.contains?(line, "#{field_name}:") ->
String.split(line, ":", parts: 2) |> List.last() |> String.trim()
String.contains?(line, String.capitalize(to_string(field_name))) ->
String.split(line, ":", parts: 2) |> List.last() |> String.trim()
true -> nil
end
end) || ""
end
end
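To make the prompt format concrete, a small sketch that exercises the Standard strategy directly. The signature here is a bare map (hypothetical) carrying only the fields this strategy reads:

signature = %{
  instructions: "Answer the question.",
  input_fields: %{question: %{desc: "Question"}},
  output_fields: %{answer: %{desc: "Answer"}}
}

context = %{signature: signature, inputs: %{question: "What is 2+2?"}, demos: []}

prompt = DSPex.Prediction.Strategy.Standard.generate_prompt(context)
# Roughly:
#   Answer the question.
#
#   Question: What is 2+2?
#
#   Answer:

parsed = DSPex.Prediction.Strategy.Standard.parse_result("Answer: 4", context)
# => %{answer: "4"}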
EXECUTION HISTORY TRACKING
defmodule DSPex.Prediction.History do
@moduledoc """
Prediction execution history tracking and analysis.
Stores comprehensive execution history for:
- Debugging and analysis
- Performance optimization
- Demo selection
- Failure analysis
"""
use GenServer
require Logger
@table_name :prediction_history
@max_history_size 10_000
@cleanup_interval :timer.minutes(5)
# Client API
def start_link(opts \\ []) do
GenServer.start_link(__MODULE__, opts, name: __MODULE__)
end
@doc """
Record a prediction execution.
"""
def record(server \\ __MODULE__, prediction_result) do
GenServer.cast(server, {:record, prediction_result})
end
@doc """
Query prediction history with filters.
"""
def query(server \\ __MODULE__, filters \\ []) do
GenServer.call(server, {:query, filters})
end
@doc """
Get execution statistics.
"""
def get_stats(server \\ __MODULE__, options \\ []) do
GenServer.call(server, {:get_stats, options})
end
@doc """
Find similar predictions for demo selection.
"""
def find_similar(server \\ __MODULE__, signature, inputs, limit \\ 5) do
GenServer.call(server, {:find_similar, signature, inputs, limit})
end
# Server callbacks
@impl true
def init(opts) do
# Create ETS table for history storage
table = :ets.new(@table_name, [
:ordered_set,
:public,
{:read_concurrency, true},
{:write_concurrency, true}
])
# Schedule periodic cleanup
schedule_cleanup()
state = %{
table: table,
opts: opts,
stats: init_stats()
}
{:ok, state}
end
@impl true
def handle_cast({:record, prediction_result}, state) do
# Store in ETS with timestamp key
key = {System.system_time(:microsecond), prediction_result.metadata.trace_id}
entry = %{
key: key,
timestamp: prediction_result.metadata.timestamp,
signature_name: prediction_result.metadata.signature_name,
inputs: prediction_result.inputs,
outputs: prediction_result.outputs,
strategy: prediction_result.metadata.strategy,
duration_ms: prediction_result.metadata.duration_ms,
tokens_used: prediction_result.metrics.tokens_used,
success: true,
trace: prediction_result.trace
}
:ets.insert(state.table, {key, entry})
# Update stats
state = update_stats(state, entry)
{:noreply, state}
end
@impl true
def handle_call({:query, filters}, _from, state) do
results = query_history(state.table, filters)
{:reply, {:ok, results}, state}
end
@impl true
def handle_call({:get_stats, options}, _from, state) do
stats = calculate_stats(state, options)
{:reply, {:ok, stats}, state}
end
@impl true
def handle_call({:find_similar, signature, inputs, limit}, _from, state) do
similar = find_similar_predictions(state.table, signature, inputs, limit)
{:reply, {:ok, similar}, state}
end
@impl true
def handle_info(:cleanup, state) do
# Remove old entries beyond max size
cleanup_old_entries(state.table)
schedule_cleanup()
{:noreply, state}
end
# Private functions
defp init_stats do
%{
total_predictions: 0,
success_count: 0,
error_count: 0,
total_duration_ms: 0,
total_tokens: 0,
strategy_counts: %{},
hourly_counts: init_hourly_counts()
}
end
defp update_stats(state, entry) do
stats = state.stats
updated_stats = %{
stats |
total_predictions: stats.total_predictions + 1,
success_count: stats.success_count + if(entry.success, do: 1, else: 0),
error_count: stats.error_count + if(entry.success, do: 0, else: 1),
total_duration_ms: stats.total_duration_ms + entry.duration_ms,
total_tokens: stats.total_tokens + entry.tokens_used,
strategy_counts: Map.update(
stats.strategy_counts,
entry.strategy,
1,
&(&1 + 1)
),
hourly_counts: update_hourly_counts(stats.hourly_counts, entry.timestamp)
}
%{state | stats: updated_stats}
end
  defp query_history(table, filters) do
    # Build match spec from filters
    match_spec = build_match_spec(filters)

    # :ets.select/3 returns {matches, continuation} or :"$end_of_table"
    limit = Keyword.get(filters, :limit, 100)

    case :ets.select(table, match_spec, limit) do
      :"$end_of_table" -> []
      {entries, _continuation} -> entries
    end
  end
  defp find_similar_predictions(table, signature, inputs, limit) do
    # Match successful entries with the same signature name via guard conditions;
    # map patterns with the `=` operator are not valid inside a match spec literal.
    match_spec = [
      {
        {:_, :"$1"},
        [
          {:==, {:map_get, :signature_name, :"$1"}, signature.name},
          {:==, {:map_get, :success, :"$1"}, true}
        ],
        [:"$1"]
      }
    ]

    table
    |> :ets.select(match_spec)
    |> Enum.map(fn entry ->
      # Calculate similarity score
      {calculate_similarity(entry.inputs, inputs), entry}
    end)
    |> Enum.sort_by(&elem(&1, 0), :desc)
    |> Enum.take(limit)
    |> Enum.map(&elem(&1, 1))
  end
defp calculate_similarity(inputs1, inputs2) do
# Simple similarity based on common keys and values
keys1 = MapSet.new(Map.keys(inputs1))
keys2 = MapSet.new(Map.keys(inputs2))
common_keys = MapSet.intersection(keys1, keys2)
if MapSet.size(common_keys) == 0 do
0.0
else
matching_values = Enum.count(common_keys, fn key ->
Map.get(inputs1, key) == Map.get(inputs2, key)
end)
matching_values / MapSet.size(common_keys)
end
end
defp cleanup_old_entries(table) do
# Get table size
size = :ets.info(table, :size)
if size > @max_history_size do
# Calculate how many to remove
to_remove = size - @max_history_size
# Get oldest entries
oldest = :ets.select(table, [{{:"$1", :_}, [], [:"$1"]}], to_remove)
# Delete them
Enum.each(oldest, &:ets.delete(table, &1))
Logger.info("Cleaned up #{to_remove} old prediction history entries")
end
end
defp schedule_cleanup do
Process.send_after(self(), :cleanup, @cleanup_interval)
end
  defp build_match_spec(filters) do
    # Bind the stored entry map to :"$1" and return it directly
    conditions = build_conditions(filters)

    [
      {
        {:_, :"$1"},
        conditions,
        [:"$1"]
      }
    ]
  end
defp build_conditions(filters) do
Enum.reduce(filters, [], fn
{:signature_name, name}, acc ->
[{:==, {:map_get, :signature_name, :"$1"}, name} | acc]
{:strategy, strategy}, acc ->
[{:==, {:map_get, :strategy, :"$1"}, strategy} | acc]
{:min_duration, min_ms}, acc ->
[{:>=, {:map_get, :duration_ms, :"$1"}, min_ms} | acc]
{:after, timestamp}, acc ->
[{:>, {:map_get, :timestamp, :"$1"}, timestamp} | acc]
_, acc -> acc
end)
end
end
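A brief usage sketch of the history server, assuming it is already running (the engine starts it in init/1). The server argument is passed explicitly because both parameters of query/2 and find_similar/4 have defaults; `signature` is assumed to be any struct or map exposing a `name` field:

alias DSPex.Prediction.History

# Recent predictions for a signature
{:ok, entries} = History.query(History, signature_name: "qa", limit: 20)

# Prior successful predictions with similar inputs, e.g. for demo selection
{:ok, similar} = History.find_similar(History, signature, %{question: "What is 2+2?"}, 3)

demos = Enum.map(similar, &%{inputs: &1.inputs, outputs: &1.outputs})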
PERFORMANCE METRICS COLLECTION
defmodule DSPex.Prediction.Metrics do
@moduledoc """
Comprehensive metrics collection and analysis for prediction pipeline.
Tracks:
- Execution performance metrics
- Provider usage and costs
- Strategy effectiveness
- Error rates and patterns
- Resource utilization
"""
use GenServer
require Logger
alias DSPex.Telemetry
@metrics_table :prediction_metrics
@aggregation_interval :timer.seconds(30)
# Metric types
@counter_metrics ~w(
predictions_total
predictions_success
predictions_error
tokens_used
cache_hits
cache_misses
)a
@histogram_metrics ~w(
prediction_duration_ms
prompt_generation_duration_ms
parsing_duration_ms
validation_duration_ms
queue_wait_time_ms
)a
@gauge_metrics ~w(
active_predictions
queue_size
cache_size
memory_usage_mb
)a
# Client API
def start_link(opts \\ []) do
GenServer.start_link(__MODULE__, opts, name: __MODULE__)
end
@doc """
Record a metric event.
"""
def record(server \\ __MODULE__, event) do
GenServer.cast(server, {:record, event})
end
@doc """
Get current metrics with optional filters.
"""
def get_metrics(server \\ __MODULE__, filters \\ []) do
GenServer.call(server, {:get_metrics, filters})
end
@doc """
Get metric aggregations over time windows.
"""
def get_aggregations(server \\ __MODULE__, metric, window) do
GenServer.call(server, {:get_aggregations, metric, window})
end
# Server callbacks
@impl true
def init(opts) do
# Create metrics storage
    # Named table so the counter/histogram/gauge helpers can reference it by name
    metrics_table = :ets.new(@metrics_table, [
      :set,
      :public,
      :named_table,
      {:write_concurrency, true}
    ])
# Initialize metric stores
init_metric_stores(metrics_table)
# Setup telemetry handlers
setup_telemetry_handlers()
# Schedule aggregation
schedule_aggregation()
state = %{
metrics_table: metrics_table,
opts: opts,
aggregations: %{},
start_time: System.system_time(:second)
}
{:ok, state}
end
@impl true
def handle_cast({:record, event}, state) do
record_event(event, state)
{:noreply, state}
end
@impl true
def handle_call({:get_metrics, filters}, _from, state) do
metrics = collect_metrics(state, filters)
{:reply, {:ok, metrics}, state}
end
@impl true
def handle_call({:get_aggregations, metric, window}, _from, state) do
aggregations = get_metric_aggregations(state, metric, window)
{:reply, {:ok, aggregations}, state}
end
@impl true
def handle_info(:aggregate, state) do
state = perform_aggregation(state)
schedule_aggregation()
{:noreply, state}
end
# Private functions
  defp init_metric_stores(table) do
    # Initialize counters as single-slot :counters references
    Enum.each(@counter_metrics, fn metric ->
      :ets.insert(table, {metric, :counters.new(1, [:atomics])})
    end)
# Initialize histograms
Enum.each(@histogram_metrics, fn metric ->
:ets.insert(table, {metric, init_histogram()})
end)
# Initialize gauges
Enum.each(@gauge_metrics, fn metric ->
:ets.insert(table, {metric, 0})
end)
end
defp setup_telemetry_handlers do
# Attach to prediction pipeline events
:telemetry.attach_many(
"ash-dspy-prediction-metrics",
[
[:dspex, :prediction, :start],
[:dspex, :prediction, :stop],
[:dspex, :prediction, :exception],
[:dspex, :provider, :request, :start],
[:dspex, :provider, :request, :stop],
[:dspex, :cache, :hit],
[:dspex, :cache, :miss]
],
&handle_telemetry_event/4,
nil
)
end
defp handle_telemetry_event(event_name, measurements, metadata, _config) do
case event_name do
[:dspex, :prediction, :start] ->
increment_gauge(:active_predictions, 1)
[:dspex, :prediction, :stop] ->
increment_gauge(:active_predictions, -1)
increment_counter(:predictions_success)
record_histogram(:prediction_duration_ms, measurements.duration)
[:dspex, :prediction, :exception] ->
increment_gauge(:active_predictions, -1)
increment_counter(:predictions_error)
[:dspex, :provider, :request, :stop] ->
if tokens = measurements[:tokens_used] do
increment_counter(:tokens_used, tokens)
end
[:dspex, :cache, :hit] ->
increment_counter(:cache_hits)
[:dspex, :cache, :miss] ->
increment_counter(:cache_misses)
_ -> :ok
end
end
  defp record_event(event, state) do
    case event do
      %{event: :prediction_success} = data ->
        increment_counter(:predictions_total)
        increment_counter(:predictions_success)
        record_histogram(:prediction_duration_ms, data.duration_ms)

        # Record strategy-specific metrics
        strategy_key = {:strategy_success, data.strategy}
        increment_dynamic_counter(state.metrics_table, strategy_key)

      %{event: :prediction_error} = data ->
        increment_counter(:predictions_total)
        increment_counter(:predictions_error)

        # Record error-specific metrics (the engine reports the error under :error)
        error_key = {:error_type, Map.get(data, :error_type, data[:error])}
        increment_dynamic_counter(state.metrics_table, error_key)

      %{event: :prompt_generated} = data ->
        record_histogram(:prompt_generation_duration_ms, data.duration_ms)

      %{event: :result_parsed} = data ->
        record_histogram(:parsing_duration_ms, data.duration_ms)

      %{event: :validation_completed} = data ->
        record_histogram(:validation_duration_ms, data.duration_ms)

      _ -> :ok
    end
  end
defp collect_metrics(state, filters) do
# Collect all metrics based on filters
base_metrics = %{
counters: collect_counters(state.metrics_table),
histograms: collect_histograms(state.metrics_table),
gauges: collect_gauges(state.metrics_table),
uptime_seconds: System.system_time(:second) - state.start_time
}
# Add derived metrics
derive_metrics(base_metrics)
end
defp derive_metrics(base_metrics) do
counters = base_metrics.counters
success_rate = if counters.predictions_total > 0 do
counters.predictions_success / counters.predictions_total * 100
else
0.0
end
cache_hit_rate = if (counters.cache_hits + counters.cache_misses) > 0 do
counters.cache_hits / (counters.cache_hits + counters.cache_misses) * 100
else
0.0
end
Map.merge(base_metrics, %{
derived: %{
success_rate: success_rate,
cache_hit_rate: cache_hit_rate,
avg_tokens_per_prediction: safe_divide(
counters.tokens_used,
counters.predictions_total
)
}
})
end
defp perform_aggregation(state) do
# Aggregate metrics into time windows
current_time = System.system_time(:second)
# Collect current values
metrics_snapshot = collect_metrics(state, [])
# Store in aggregations
updated_aggregations = Map.update(
state.aggregations,
current_time,
metrics_snapshot,
fn _ -> metrics_snapshot end
)
# Cleanup old aggregations (keep last hour)
cutoff_time = current_time - 3600
cleaned_aggregations = Map.reject(updated_aggregations, fn {time, _} ->
time < cutoff_time
end)
%{state | aggregations: cleaned_aggregations}
end
defp get_metric_aggregations(state, metric, window) do
# Calculate aggregations for specific metric over time window
current_time = System.system_time(:second)
start_time = current_time - window
state.aggregations
|> Enum.filter(fn {time, _} -> time >= start_time end)
|> Enum.map(fn {time, snapshot} ->
value = get_metric_value(snapshot, metric)
{time, value}
end)
|> Enum.sort_by(&elem(&1, 0))
end
# Counter operations
defp increment_counter(metric, value \\ 1) do
[{_, counter_ref}] = :ets.lookup(@metrics_table, metric)
:counters.add(counter_ref, 1, value)
end
defp increment_dynamic_counter(table, key, value \\ 1) do
:ets.update_counter(table, key, value, {key, 0})
end
# Histogram operations
defp init_histogram do
%{
count: 0,
sum: 0,
min: nil,
max: nil,
buckets: init_buckets()
}
end
  defp record_histogram(metric, value) do
    # :ets.update_element/3 cannot take an update function, so read-modify-write
    case :ets.lookup(@metrics_table, metric) do
      [{^metric, hist}] ->
        :ets.insert(@metrics_table, {metric, update_histogram(hist, value)})

      [] ->
        :ets.insert(@metrics_table, {metric, update_histogram(init_histogram(), value)})
    end
  end
defp update_histogram(hist, value) do
%{
hist |
count: hist.count + 1,
sum: hist.sum + value,
min: min(hist.min || value, value),
max: max(hist.max || value, value),
buckets: update_buckets(hist.buckets, value)
}
end
# Gauge operations
defp increment_gauge(metric, value) do
:ets.update_counter(@metrics_table, metric, value, {metric, 0})
end
defp set_gauge(metric, value) do
:ets.insert(@metrics_table, {metric, value})
end
defp schedule_aggregation do
Process.send_after(self(), :aggregate, @aggregation_interval)
end
defp safe_divide(_num, 0), do: 0.0
defp safe_divide(num, denom), do: num / denom
end
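The telemetry handlers above assume that the prediction pipeline emits the corresponding events. That emission is not shown elsewhere in this prompt, so here is a minimal, hypothetical sketch of a wrapper producing the [:dspex, :prediction, ...] start/stop/exception events the metrics module attaches to:

defmodule DSPex.Prediction.TelemetrySpan do
  @moduledoc false

  # Illustrative only: wraps a prediction function in the telemetry events that
  # DSPex.Prediction.Metrics handles. Durations are reported in milliseconds to
  # match the histogram units used above.
  def span(metadata, fun) do
    start = System.monotonic_time(:millisecond)
    :telemetry.execute([:dspex, :prediction, :start], %{system_time: System.system_time()}, metadata)

    try do
      result = fun.()
      duration = System.monotonic_time(:millisecond) - start
      :telemetry.execute([:dspex, :prediction, :stop], %{duration: duration}, metadata)
      result
    rescue
      e ->
        duration = System.monotonic_time(:millisecond) - start

        :telemetry.execute(
          [:dspex, :prediction, :exception],
          %{duration: duration},
          Map.put(metadata, :error, e)
        )

        reraise e, __STACKTRACE__
    end
  end
end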
PROVIDER COORDINATION
defmodule DSPex.Prediction.ProviderCoordinator do
@moduledoc """
Multi-provider coordination with intelligent routing and fallbacks.
Features:
- Provider health monitoring
- Load balancing across providers
- Automatic fallback on failures
- Cost-aware routing
- Rate limit management
"""
use GenServer
require Logger
alias DSPex.Providers
@health_check_interval :timer.seconds(30)
@provider_timeout 30_000
# Client API
def start_link(opts \\ []) do
GenServer.start_link(__MODULE__, opts, name: __MODULE__)
end
@doc """
Execute request with automatic provider selection and fallbacks.
"""
def execute(server \\ __MODULE__, prompt, opts \\ []) do
GenServer.call(server, {:execute, prompt, opts},
Keyword.get(opts, :timeout, @provider_timeout))
end
@doc """
Get provider health status.
"""
def get_health_status(server \\ __MODULE__) do
GenServer.call(server, :get_health_status)
end
@doc """
Update provider configuration.
"""
def update_provider_config(server \\ __MODULE__, provider, config) do
GenServer.call(server, {:update_provider_config, provider, config})
end
# Server callbacks
@impl true
def init(opts) do
# Initialize provider states
providers = init_providers(opts)
# Schedule health checks
schedule_health_check()
state = %{
providers: providers,
routing_strategy: Keyword.get(opts, :routing_strategy, :least_loaded),
opts: opts
}
{:ok, state}
end
@impl true
  def handle_call({:execute, prompt, opts}, from, state) do
    # Select provider based on strategy
    case select_provider(state, opts) do
      {:ok, provider} ->
        # Execute asynchronously
        task = Task.async(fn ->
          execute_with_provider(provider, prompt, opts)
        end)

        # Track the request alongside the task so fallbacks can re-execute it
        state = track_execution(state, provider, task, from, prompt, opts)
        {:noreply, state}

      {:error, :no_available_providers} = error ->
        {:reply, error, state}
    end
  end
@impl true
def handle_call(:get_health_status, _from, state) do
status = collect_health_status(state.providers)
{:reply, {:ok, status}, state}
end
@impl true
def handle_call({:update_provider_config, provider, config}, _from, state) do
updated_providers = Map.update!(state.providers, provider, fn p ->
%{p | config: Map.merge(p.config, config)}
end)
{:reply, :ok, %{state | providers: updated_providers}}
end
@impl true
  def handle_info({ref, result}, state) when is_reference(ref) do
    # Handle task completion
    case pop_task(state, ref) do
      {nil, state} ->
        {:noreply, state}

      {{provider, from, _prompt, _opts}, state} ->
        # Update provider stats
        state = update_provider_stats(state, provider, result)
        # Reply to caller
        GenServer.reply(from, result)
        {:noreply, state}
    end
  end
@impl true
  def handle_info({:DOWN, ref, :process, _pid, reason}, state) do
    # Handle task failure
    case pop_task(state, ref) do
      {nil, state} ->
        {:noreply, state}

      {{provider, from, prompt, opts}, state} ->
        # Mark provider as failed
        state = mark_provider_failure(state, provider, reason)
        # Retry the original request on a fallback provider
        handle_fallback(from, prompt, opts, provider, reason, state)
    end
  end
@impl true
def handle_info(:health_check, state) do
state = perform_health_checks(state)
schedule_health_check()
{:noreply, state}
end
# Private functions
defp init_providers(opts) do
provider_configs = Keyword.get(opts, :providers, default_providers())
Map.new(provider_configs, fn {name, config} ->
{name, %{
name: name,
module: config.module,
config: config.config,
health: :healthy,
stats: init_provider_stats(),
last_health_check: nil,
rate_limiter: init_rate_limiter(config)
}}
end)
end
defp select_provider(state, opts) do
# Filter available providers
available = state.providers
|> Enum.filter(fn {_name, provider} ->
provider.health == :healthy and
check_rate_limit(provider.rate_limiter)
end)
|> Enum.map(&elem(&1, 1))
if available == [] do
{:error, :no_available_providers}
else
# Apply routing strategy
provider = case state.routing_strategy do
:round_robin -> select_round_robin(available, state)
:least_loaded -> select_least_loaded(available)
:fastest -> select_fastest(available)
:cost_optimized -> select_cost_optimized(available, opts)
custom when is_function(custom) -> custom.(available, opts)
end
{:ok, provider}
end
end
defp execute_with_provider(provider, prompt, opts) do
start_time = System.monotonic_time(:millisecond)
try do
# Merge provider config with request opts
provider_opts = Map.merge(provider.config, Keyword.get(opts, :provider_opts, %{}))
# Execute request
result = apply(provider.module, :generate, [prompt, provider_opts])
duration = System.monotonic_time(:millisecond) - start_time
# Emit telemetry
:telemetry.execute(
[:dspex, :provider, :request, :stop],
%{duration: duration, tokens_used: result[:usage][:total_tokens]},
%{provider: provider.name, success: true}
)
{:ok, result}
rescue
e ->
duration = System.monotonic_time(:millisecond) - start_time
:telemetry.execute(
[:dspex, :provider, :request, :stop],
%{duration: duration},
%{provider: provider.name, success: false, error: e}
)
{:error, {:provider_error, provider.name, Exception.message(e)}}
end
end
  defp handle_fallback(from, prompt, opts, failed_provider, reason, state) do
    # Exclude the failed provider and retry with the next best one
    remaining_providers = Map.delete(state.providers, failed_provider.name)
    temp_state = %{state | providers: remaining_providers}

    case select_provider(temp_state, opts) do
      {:ok, fallback_provider} ->
        Logger.warning("Falling back from #{failed_provider.name} to #{fallback_provider.name}")

        # Re-execute the original request with the fallback provider
        task = Task.async(fn ->
          execute_with_provider(fallback_provider, prompt, opts)
        end)

        {:noreply, track_execution(state, fallback_provider, task, from, prompt, opts)}

      {:error, :no_available_providers} ->
        GenServer.reply(from, {:error, {:all_providers_failed, reason}})
        {:noreply, state}
    end
  end
defp update_provider_stats(state, provider, {:ok, _result}) do
update_in(state.providers[provider.name].stats, fn stats ->
%{stats |
total_requests: stats.total_requests + 1,
successful_requests: stats.successful_requests + 1,
consecutive_failures: 0,
last_success: System.system_time(:second)
}
end)
end
defp update_provider_stats(state, provider, {:error, _reason}) do
update_in(state.providers[provider.name].stats, fn stats ->
%{stats |
total_requests: stats.total_requests + 1,
failed_requests: stats.failed_requests + 1,
consecutive_failures: stats.consecutive_failures + 1,
last_failure: System.system_time(:second)
}
end)
end
defp mark_provider_failure(state, provider, reason) do
state = update_provider_stats(state, provider, {:error, reason})
# Check if provider should be marked unhealthy
provider = state.providers[provider.name]
if provider.stats.consecutive_failures >= 3 do
Logger.error("Provider #{provider.name} marked unhealthy after consecutive failures")
put_in(state.providers[provider.name].health, :unhealthy)
else
state
end
end
defp perform_health_checks(state) do
# Check each provider's health
Enum.reduce(state.providers, state, fn {name, provider}, acc ->
if should_check_health?(provider) do
check_provider_health(acc, name)
else
acc
end
end)
end
defp check_provider_health(state, provider_name) do
provider = state.providers[provider_name]
# Simple health check - try a minimal request
health_check_result = try do
apply(provider.module, :health_check, [provider.config])
rescue
_ -> :unhealthy
end
# Update provider health status
put_in(state.providers[provider_name], %{
provider |
health: health_check_result,
last_health_check: System.system_time(:second)
})
end
defp should_check_health?(provider) do
case provider.last_health_check do
nil -> true
last_check ->
# Check if enough time has passed
System.system_time(:second) - last_check > 30
end
end
defp select_least_loaded(providers) do
Enum.min_by(providers, fn provider ->
provider.stats.active_requests
end)
end
defp select_fastest(providers) do
Enum.min_by(providers, fn provider ->
if provider.stats.total_requests > 0 do
provider.stats.total_duration / provider.stats.total_requests
else
0
end
end)
end
  defp select_cost_optimized(providers, opts) do
    min_quality = Keyword.get(opts, :min_quality, 0.8)

    providers
    |> Enum.filter(fn provider ->
      (provider.config[:quality_score] || 0.0) >= min_quality
    end)
    |> Enum.min_by(fn provider ->
      provider.config[:cost_per_token] || 1.0
    end)
  end
defp init_provider_stats do
%{
total_requests: 0,
successful_requests: 0,
failed_requests: 0,
active_requests: 0,
total_duration: 0,
consecutive_failures: 0,
last_success: nil,
last_failure: nil
}
end
defp init_rate_limiter(config) do
%{
max_requests_per_minute: config[:rate_limit] || 1000,
current_window_start: System.system_time(:second),
current_window_count: 0
}
end
defp check_rate_limit(rate_limiter) do
current_time = System.system_time(:second)
window_start = div(current_time, 60) * 60
if window_start > rate_limiter.current_window_start do
# New window
true
else
# Check current window
rate_limiter.current_window_count < rate_limiter.max_requests_per_minute
end
end
defp schedule_health_check do
Process.send_after(self(), :health_check, @health_check_interval)
end
defp default_providers do
[
openai: %{
module: DSPex.Providers.OpenAI,
config: %{
api_key: System.get_env("OPENAI_API_KEY"),
model: "gpt-4",
quality_score: 0.95,
cost_per_token: 0.00003
}
},
anthropic: %{
module: DSPex.Providers.Anthropic,
config: %{
api_key: System.get_env("ANTHROPIC_API_KEY"),
model: "claude-3-opus",
quality_score: 0.98,
cost_per_token: 0.00005
}
}
]
end
end
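The coordinator invokes provider.module.generate/2 and provider.module.health_check/1 through apply/3, so every configured provider module is expected to expose that contract. A hypothetical stub provider illustrating the expected shapes (the real OpenAI/Anthropic modules would call their HTTP APIs instead):

defmodule DSPex.Providers.Stub do
  @moduledoc false

  @doc "Return a completion map; the coordinator reads tokens via result[:usage][:total_tokens]."
  def generate(prompt, config) do
    %{
      text: "stubbed completion for: " <> String.slice(prompt, 0, 40),
      usage: %{total_tokens: div(byte_size(prompt), 4)},
      model: config[:model]
    }
  end

  @doc "Cheap liveness probe; the coordinator stores the return value as the provider's health."
  def health_check(_config) do
    # A real provider might ping a lightweight endpoint here
    :healthy
  end
end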
RESULT VALIDATION AND QUALITY ASSESSMENT
defmodule DSPex.Prediction.Validator do
@moduledoc """
Comprehensive result validation and quality assessment.
Validates:
- Type correctness
- Field constraints
- Output completeness
- Quality metrics
- Consistency checks
"""
alias DSPex.Types
@doc """
Validate prediction outputs against signature.
"""
def validate(outputs, signature) do
with :ok <- validate_required_fields(outputs, signature),
:ok <- validate_field_types(outputs, signature),
:ok <- validate_field_constraints(outputs, signature),
:ok <- validate_output_quality(outputs, signature) do
{:ok, outputs}
end
end
@doc """
Assess output quality with scoring.
"""
def assess_quality(outputs, signature) do
scores = %{
completeness: assess_completeness(outputs, signature),
coherence: assess_coherence(outputs),
relevance: assess_relevance(outputs, signature),
consistency: assess_consistency(outputs, signature)
}
overall_score = calculate_overall_score(scores)
%{
scores: scores,
overall: overall_score,
passed: overall_score >= signature.min_quality_score
}
end
defp validate_required_fields(outputs, signature) do
missing_fields = signature.output_fields
|> Enum.filter(fn {name, spec} -> spec.required end)
|> Enum.map(&elem(&1, 0))
|> Enum.reject(&Map.has_key?(outputs, &1))
if missing_fields == [] do
:ok
else
{:error, {:missing_required_fields, missing_fields}}
end
end
defp validate_field_types(outputs, signature) do
errors = outputs
|> Enum.flat_map(fn {field_name, value} ->
case Map.get(signature.output_fields, field_name) do
nil -> []
field_spec ->
case Types.validate(value, field_spec.type) do
:ok -> []
{:error, reason} -> [{field_name, reason}]
end
end
end)
if errors == [] do
:ok
else
{:error, {:type_validation_failed, errors}}
end
end
defp validate_field_constraints(outputs, signature) do
errors = outputs
|> Enum.flat_map(fn {field_name, value} ->
case Map.get(signature.output_fields, field_name) do
nil -> []
field_spec ->
validate_constraints(value, field_spec.constraints, field_name)
end
end)
if errors == [] do
:ok
else
{:error, {:constraint_validation_failed, errors}}
end
end
defp validate_constraints(value, constraints, field_name) do
constraints
|> Enum.flat_map(fn
{:min_length, min} when is_binary(value) ->
if String.length(value) >= min do
[]
else
[{field_name, {:min_length, min, String.length(value)}}]
end
{:max_length, max} when is_binary(value) ->
if String.length(value) <= max do
[]
else
[{field_name, {:max_length, max, String.length(value)}}]
end
{:pattern, regex} when is_binary(value) ->
if Regex.match?(regex, value) do
[]
else
[{field_name, {:pattern_mismatch, regex}}]
end
{:range, {min, max}} when is_number(value) ->
if value >= min and value <= max do
[]
else
[{field_name, {:out_of_range, {min, max}, value}}]
end
_ -> []
end)
end
defp validate_output_quality(outputs, signature) do
if signature.quality_validation do
quality = assess_quality(outputs, signature)
if quality.passed do
:ok
else
{:error, {:quality_check_failed, quality}}
end
else
:ok
end
end
defp assess_completeness(outputs, signature) do
total_fields = map_size(signature.output_fields)
present_fields = Enum.count(outputs, fn {k, v} ->
Map.has_key?(signature.output_fields, k) and not is_nil(v) and v != ""
end)
present_fields / total_fields * 100
end
defp assess_coherence(outputs) do
# Check text coherence for string outputs
text_values = outputs
|> Map.values()
|> Enum.filter(&is_binary/1)
if text_values == [] do
100.0
else
# Simple coherence check based on sentence structure
coherence_scores = Enum.map(text_values, &calculate_text_coherence/1)
Enum.sum(coherence_scores) / length(coherence_scores)
end
end
defp calculate_text_coherence(text) do
# Basic coherence scoring
sentences = String.split(text, ~r/[.!?]+/)
|> Enum.map(&String.trim/1)
|> Enum.reject(&(&1 == ""))
if length(sentences) == 0 do
50.0
else
# Check for proper sentence structure
valid_sentences = Enum.count(sentences, fn sentence ->
words = String.split(sentence, " ")
length(words) >= 3 and String.match?(hd(words), ~r/^[A-Z]/)
end)
valid_sentences / length(sentences) * 100
end
end
defp assess_relevance(outputs, signature) do
# Check if outputs are relevant to the signature context
# This is a simplified implementation
expected_keywords = extract_keywords(signature)
output_text = outputs
|> Map.values()
|> Enum.filter(&is_binary/1)
|> Enum.join(" ")
|> String.downcase()
matching_keywords = Enum.count(expected_keywords, fn keyword ->
String.contains?(output_text, keyword)
end)
if length(expected_keywords) == 0 do
100.0
else
matching_keywords / length(expected_keywords) * 100
end
end
defp assess_consistency(outputs, _signature) do
# Check for internal consistency in outputs
# Look for contradictions or inconsistencies
# This is a simplified check - in production, you'd want
# more sophisticated consistency checking
100.0
end
defp calculate_overall_score(scores) do
# Weighted average of individual scores
weights = %{
completeness: 0.3,
coherence: 0.3,
relevance: 0.2,
consistency: 0.2
}
Enum.reduce(scores, 0.0, fn {metric, score}, acc ->
acc + score * Map.get(weights, metric, 0.25)
end)
end
defp extract_keywords(signature) do
# Extract keywords from signature for relevance checking
text_parts = [
signature.instructions || "",
signature.desc || ""
]
text_parts
|> Enum.join(" ")
|> String.downcase()
|> String.split(~r/\W+/)
|> Enum.filter(fn word -> String.length(word) > 3 end)
|> Enum.uniq()
end
end
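A short usage sketch of the validator. The signature map is hypothetical and carries only the fields these functions read (output_fields entries with required, type, and constraints, plus the quality_validation flag); type checks delegate to the DSPex.Types module referenced above:

signature = %{
  output_fields: %{
    answer: %{required: true, type: :string, constraints: [min_length: 1, max_length: 200]}
  },
  quality_validation: false
}

# Passes: required field present, type and constraints satisfied
{:ok, _outputs} = DSPex.Prediction.Validator.validate(%{answer: "Paris"}, signature)

# Fails: required field missing
{:error, {:missing_required_fields, [:answer]}} =
  DSPex.Prediction.Validator.validate(%{}, signature)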
ADVANCED FEATURES
ADAPTIVE STRATEGY SELECTION
defmodule DSPex.Prediction.AdaptiveStrategy do
@moduledoc """
Machine learning-based adaptive strategy selection.
Learns from execution history to select optimal strategies
based on signature characteristics and performance data.
"""
use GenServer
require Logger
@learning_rate 0.1
@exploration_rate 0.1
@min_samples 10
def start_link(opts \\ []) do
GenServer.start_link(__MODULE__, opts, name: __MODULE__)
end
def select_strategy(signature, context) do
GenServer.call(__MODULE__, {:select_strategy, signature, context})
end
def record_outcome(signature, strategy, performance) do
GenServer.cast(__MODULE__, {:record_outcome, signature, strategy, performance})
end
@impl true
def init(_opts) do
# Initialize strategy performance model
model = %{
features: [],
weights: %{},
performance_history: %{}
}
{:ok, model}
end
@impl true
def handle_call({:select_strategy, signature, context}, _from, model) do
features = extract_features(signature, context)
# Epsilon-greedy selection
strategy = if :rand.uniform() < @exploration_rate do
# Explore: random strategy
random_strategy()
else
# Exploit: best predicted strategy
predict_best_strategy(features, model)
end
{:reply, strategy, model}
end
@impl true
def handle_cast({:record_outcome, signature, strategy, performance}, model) do
features = extract_features(signature, %{})
# Update model with new outcome
updated_model = update_model(model, features, strategy, performance)
{:noreply, updated_model}
end
  defp extract_features(signature, context) do
    # Encode boolean indicators as 0/1 so the similarity math stays numeric
    %{
      input_field_count: map_size(signature.input_fields),
      output_field_count: map_size(signature.output_fields),
      has_instructions: if(is_nil(signature.instructions), do: 0, else: 1),
      instruction_length: String.length(signature.instructions || ""),
      requires_reasoning: if(signature_requires_reasoning?(signature), do: 1, else: 0),
      complexity_score: calculate_complexity(signature),
      context_size: map_size(context)
    }
  end
defp predict_best_strategy(features, model) do
# Calculate scores for each strategy
scores = available_strategies()
|> Enum.map(fn strategy ->
score = calculate_strategy_score(strategy, features, model)
{strategy, score}
end)
|> Enum.sort_by(&elem(&1, 1), :desc)
# Return best strategy
{best_strategy, _score} = hd(scores)
best_strategy
end
defp calculate_strategy_score(strategy, features, model) do
# Get historical performance for this strategy
history = Map.get(model.performance_history, strategy, [])
if length(history) < @min_samples do
# Not enough data, use default score
0.5
else
# Calculate weighted score based on feature similarity
similar_performances = history
|> Enum.map(fn {hist_features, performance} ->
similarity = calculate_similarity(features, hist_features)
{similarity, performance}
end)
|> Enum.sort_by(&elem(&1, 0), :desc)
|> Enum.take(10)
      # Weighted average of similar performances
      # (Enum.sum/1 takes a single enumerable, so map before summing)
      total_weight = similar_performances |> Enum.map(&elem(&1, 0)) |> Enum.sum()

      if total_weight > 0 do
        weighted_sum =
          similar_performances
          |> Enum.map(fn {sim, perf} -> sim * perf end)
          |> Enum.sum()

        weighted_sum / total_weight
      else
        0.5
      end
end
end
defp update_model(model, features, strategy, performance) do
# Add to performance history
history = Map.get(model.performance_history, strategy, [])
updated_history = [{features, performance} | history] |> Enum.take(1000)
%{model |
performance_history: Map.put(
model.performance_history,
strategy,
updated_history
)
}
end
  defp calculate_similarity(features1, features2) do
    # Cosine similarity between numeric feature vectors
    keys = Map.keys(features1)

    dot_product =
      keys
      |> Enum.map(fn key -> Map.get(features1, key, 0) * Map.get(features2, key, 0) end)
      |> Enum.sum()

    mag1 =
      keys
      |> Enum.map(fn key ->
        v = Map.get(features1, key, 0)
        v * v
      end)
      |> Enum.sum()
      |> :math.sqrt()

    mag2 =
      keys
      |> Enum.map(fn key ->
        v = Map.get(features2, key, 0)
        v * v
      end)
      |> Enum.sum()
      |> :math.sqrt()

    if mag1 * mag2 > 0 do
      dot_product / (mag1 * mag2)
    else
      0.0
    end
  end
defp signature_requires_reasoning?(signature) do
# Check if signature likely requires reasoning
indicators = [
"reasoning", "think", "explain", "analyze", "consider",
"step", "conclusion", "because", "therefore"
]
text = [
signature.instructions || "",
signature.desc || ""
]
|> Enum.join(" ")
|> String.downcase()
Enum.any?(indicators, &String.contains?(text, &1))
end
defp calculate_complexity(signature) do
# Simple complexity scoring
base_score = map_size(signature.input_fields) + map_size(signature.output_fields)
instruction_complexity = if signature.instructions do
String.length(signature.instructions) / 100
else
0
end
base_score + instruction_complexity
end
defp available_strategies do
[:standard, :cot, :react, :program]
end
defp random_strategy do
Enum.random(available_strategies())
end
end
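The selector only improves if callers close the loop by reporting outcomes. A sketch of that feedback cycle, where `signature` and `inputs` are assumed from the earlier examples and the 0.0-1.0 score is an illustrative heuristic, not a prescribed formula:

{:ok, _pid} = DSPex.Prediction.AdaptiveStrategy.start_link()

strategy = DSPex.Prediction.AdaptiveStrategy.select_strategy(signature, %{})

# Execute with the chosen strategy, then convert the outcome into a score:
# full credit for success, penalized by latency; zero on failure.
score =
  case DSPex.Prediction.Engine.predict(signature, inputs, strategy: strategy) do
    {:ok, result} -> max(0.0, 1.0 - result.metadata.duration_ms / 10_000)
    {:error, _reason} -> 0.0
  end

DSPex.Prediction.AdaptiveStrategy.record_outcome(signature, strategy, score)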
INTEGRATION AND TESTING
INTEGRATION WITH ASH FRAMEWORK
defmodule DSPex.Prediction.AshIntegration do
@moduledoc """
Integration with Ash framework for prediction operations.
"""
use Ash.Resource.Change
alias DSPex.Prediction.Engine
@doc """
Ash change for executing predictions.
"""
def change(changeset, opts, _context) do
signature = Keyword.fetch!(opts, :signature)
input_mapping = Keyword.get(opts, :input_mapping, %{})
output_field = Keyword.get(opts, :output_field, :prediction_result)
Ash.Changeset.before_action(changeset, fn changeset ->
# Extract inputs from changeset
inputs = extract_inputs(changeset, input_mapping)
# Execute prediction
case Engine.predict(signature, inputs) do
{:ok, result} ->
Ash.Changeset.force_change_attribute(
changeset,
output_field,
result
)
{:error, reason} ->
Ash.Changeset.add_error(
changeset,
field: output_field,
message: "Prediction failed: #{inspect(reason)}"
)
end
end)
end
defp extract_inputs(changeset, mapping) do
mapping
|> Enum.reduce(%{}, fn {input_field, changeset_field}, acc ->
value = Ash.Changeset.get_attribute(changeset, changeset_field)
Map.put(acc, input_field, value)
end)
end
end
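A sketch of wiring the change into an Ash resource action. The resource, attribute names, and signature module are hypothetical, and newer Ash versions additionally require a domain option on use Ash.Resource:

defmodule MyApp.Support.Question do
  use Ash.Resource, data_layer: Ash.DataLayer.Ets

  attributes do
    uuid_primary_key :id
    attribute :question, :string
    attribute :prediction_result, :map
  end

  actions do
    defaults [:read]

    create :ask do
      accept [:question]

      change {DSPex.Prediction.AshIntegration,
              signature: MyApp.Signatures.QA,
              input_mapping: %{question: :question},
              output_field: :prediction_result}
    end
  end
end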
COMPREHENSIVE TESTING
defmodule DSPex.Prediction.EngineTest do
use ExUnit.Case, async: true
alias DSPex.Prediction.Engine
alias DSPex.Signature
setup do
# Start test instance
{:ok, engine} = Engine.start_link(test_mode: true)
# Create test signature
signature = %Signature{
name: "test_qa",
instructions: "Answer questions concisely.",
input_fields: %{
question: %{type: :string, required: true}
},
output_fields: %{
answer: %{type: :string, required: true}
}
}
%{engine: engine, signature: signature}
end
describe "predict/3" do
test "executes successful prediction", %{signature: signature} do
inputs = %{question: "What is 2+2?"}
assert {:ok, result} = Engine.predict(signature, inputs)
assert is_map(result.outputs)
assert Map.has_key?(result.outputs, :answer)
assert is_binary(result.outputs.answer)
end
test "validates required inputs", %{signature: signature} do
inputs = %{} # Missing required question
assert {:error, {:missing_required_fields, [:question]}} =
Engine.predict(signature, inputs)
end
test "respects timeout option", %{signature: signature} do
inputs = %{question: "Simulate timeout"}
opts = [timeout: 100, mock_delay: 200]
assert {:error, :timeout} = Engine.predict(signature, inputs, opts)
end
test "uses caching when enabled", %{signature: signature} do
inputs = %{question: "What is the capital of France?"}
opts = [cache: true]
# First call
assert {:ok, result1} = Engine.predict(signature, inputs, opts)
refute result1.metadata.from_cache
# Second call should hit cache
assert {:ok, result2} = Engine.predict(signature, inputs, opts)
assert result2.metadata.from_cache
assert result1.outputs == result2.outputs
end
end
describe "predict_batch/2" do
test "executes batch predictions", %{signature: signature} do
predictions = [
{signature, %{question: "What is 2+2?"}},
{signature, %{question: "What is the capital of France?"}},
{signature, %{question: "Explain quantum physics"}}
]
assert {:ok, results} = Engine.predict_batch(predictions)
assert length(results) == 3
assert Enum.all?(results, &match?(%{outputs: %{answer: _}}, &1))
end
test "handles partial failures in batch", %{signature: signature} do
predictions = [
{signature, %{question: "Valid question"}},
{signature, %{}}, # Invalid - missing required field
{signature, %{question: "Another valid question"}}
]
assert {:error, {:missing_required_fields, [:question]}} =
Engine.predict_batch(predictions)
end
end
describe "adaptive strategy selection" do
test "selects appropriate strategy based on signature", %{engine: engine} do
# Create reasoning signature
reasoning_signature = %Signature{
name: "reasoning_task",
instructions: "Think step by step and explain your reasoning.",
input_fields: %{
problem: %{type: :string, required: true}
},
output_fields: %{
reasoning: %{type: :string, required: true},
answer: %{type: :string, required: true}
}
}
inputs = %{problem: "Complex math problem"}
# Should select CoT strategy for reasoning
assert {:ok, result} = Engine.predict(reasoning_signature, inputs)
assert result.metadata.strategy == DSPex.Prediction.Strategy.ChainOfThought
end
end
describe "performance metrics" do
test "collects execution metrics", %{signature: signature} do
# Execute several predictions
for i <- 1..5 do
inputs = %{question: "Question #{i}"}
Engine.predict(signature, inputs)
end
# Check metrics
assert {:ok, metrics} = Engine.get_metrics()
assert metrics.counters.predictions_total >= 5
assert metrics.counters.predictions_success >= 5
assert metrics.histograms.prediction_duration_ms.count >= 5
end
end
end
PERFORMANCE CONSIDERATIONS
1. CACHING STRATEGY
- Intelligent cache key generation based on signature and inputs
- TTL-based expiration with configurable durations
- Memory-aware eviction when cache grows too large
2. PARALLEL EXECUTION
- Batch predictions executed with controlled concurrency
- Task.async_stream for efficient parallel processing
- Backpressure handling to prevent overload
3. PROVIDER OPTIMIZATION
- Connection pooling per provider
- Request coalescing for similar predictions
- Automatic retry with exponential backoff (see the sketch after this list)
4. MEMORY MANAGEMENT
- ETS tables with size limits
- Periodic cleanup of old data
- Streaming for large result sets
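The automatic retry behaviour listed under provider optimization is not implemented in the coordinator above; a minimal sketch of how exponential backoff could be layered around a provider call (module name and delay values are illustrative):

defmodule DSPex.Prediction.Retry do
  @moduledoc false

  @doc """
  Retry `fun` up to `max_attempts` times, doubling the delay after each
  failure (100ms, 200ms, 400ms, ...). Returns the first {:ok, _} or the
  last error.
  """
  def with_backoff(fun, max_attempts \\ 3, base_delay_ms \\ 100) do
    do_retry(fun, 1, max_attempts, base_delay_ms)
  end

  defp do_retry(fun, attempt, max_attempts, base_delay_ms) do
    case fun.() do
      {:ok, _} = ok ->
        ok

      {:error, _} = error when attempt >= max_attempts ->
        error

      {:error, _} ->
        Process.sleep(base_delay_ms * Integer.pow(2, attempt - 1))
        do_retry(fun, attempt + 1, max_attempts, base_delay_ms)
    end
  end
end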
CONFIGURATION
# config/config.exs
config :dspex, :prediction,
# Engine configuration
max_concurrent_predictions: 50,
default_timeout: 30_000,
cache_ttl: 300,
# History configuration
max_history_size: 10_000,
history_cleanup_interval: 300_000,
# Metrics configuration
metrics_aggregation_interval: 30_000,
metrics_retention_period: 3600,
# Provider configuration
providers: [
openai: [
module: DSPex.Providers.OpenAI,
api_key: {:system, "OPENAI_API_KEY"},
default_model: "gpt-4",
rate_limit: 1000
],
anthropic: [
module: DSPex.Providers.Anthropic,
api_key: {:system, "ANTHROPIC_API_KEY"},
default_model: "claude-3-opus",
rate_limit: 500
]
],
# Strategy configuration
routing_strategy: :least_loaded,
adaptive_learning: true,
exploration_rate: 0.1
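For reference, a sketch of how these settings might be read at runtime; the keys mirror the block above, and the {:system, var} resolution is an assumption about how the API key tuples are meant to be handled:

prediction_config = Application.get_env(:dspex, :prediction, [])

default_timeout = Keyword.get(prediction_config, :default_timeout, 30_000)
cache_ttl = Keyword.get(prediction_config, :cache_ttl, 300)
providers = Keyword.get(prediction_config, :providers, [])

# Resolve {:system, "VAR"} tuples to environment values at runtime
resolve = fn
  {:system, var} -> System.get_env(var)
  value -> value
end

api_keys =
  for {name, opts} <- providers, into: %{} do
    {name, resolve.(Keyword.get(opts, :api_key))}
  end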
This implementation provides a comprehensive prediction pipeline system with advanced monitoring, multi-provider coordination, adaptive strategy selection, and extensive performance optimization capabilities.