← Back to Gap analysis

DSPEX GAP ANALYSIS 12 code

Documentation for DSPEX_GAP_ANALYSIS_12_code from the DSPEx repository.

12. NEW: Memory-Efficient Trajectory Management

defmodule DSPEx.Teleprompter.SIMBA.TrajectoryManager do
  @moduledoc """
  Memory-efficient trajectory management with selective storage and compression.

  Every stored trajectory gets exactly one lightweight summary (score, program
  type, execution time, success flag, content hash). Summaries are the source
  for statistics and, together with a hash set, for deduplication.

  Full trajectories are kept only while they are among the most recent ones or
  score above a threshold. "Compressing" an old trajectory simply drops the full
  record; its summary, created at store time, remains.

  ## Options (for `start_link/1`)

    * `:max_trajectories` - full trajectories kept after a cleanup pass
      (default `1000`).
    * `:compression_threshold` - a compression pass runs once more than this
      many full trajectories are held (default `500`).
    * `:keep_recent` - number of most recent full trajectories compression
      always keeps (default `100`).
    * `:keep_score_above` - compression also keeps trajectories whose numeric
      score is strictly greater than this (default `0.8`).
    * `:cleanup_interval` - milliseconds between periodic cleanup passes
      (default `30_000`).
  """

  use GenServer

  alias DSPEx.Teleprompter.SIMBA.Trajectory

  @type trajectory_summary :: %{
          score: float(),
          program_type: atom(),
          execution_time: non_neg_integer(),
          success: boolean(),
          hash: binary()
        }

  defstruct [
    :max_trajectories,
    :compression_threshold,
    keep_recent: 100,
    keep_score_above: 0.8,
    trajectories: [],
    summaries: [],
    hashes: MapSet.new(),
    total_stored: 0,
    memory_usage: 0
  ]

  @type t :: %__MODULE__{
          max_trajectories: pos_integer(),
          compression_threshold: pos_integer(),
          keep_recent: non_neg_integer(),
          keep_score_above: number(),
          trajectories: [Trajectory.t()],
          summaries: [trajectory_summary()],
          hashes: MapSet.t(binary()),
          total_stored: non_neg_integer(),
          memory_usage: non_neg_integer()
        }

  # Client API

  @doc "Starts the manager, registered under this module's name. See the moduledoc for options."
  @spec start_link(keyword()) :: GenServer.on_start()
  def start_link(opts \\ []) do
    GenServer.start_link(__MODULE__, opts, name: __MODULE__)
  end

  @doc """
  Stores `trajectory` unless an identical one (same inputs, outputs, program
  type and model config) was stored before.

  Returns `true` if it was stored, `false` if it was a duplicate.
  """
  @spec store_trajectory(Trajectory.t()) :: boolean()
  def store_trajectory(trajectory) do
    GenServer.call(__MODULE__, {:store_trajectory, trajectory})
  end

  @doc "Returns up to `count` full trajectories, most recently stored first."
  @spec get_recent_trajectories(integer()) :: [Trajectory.t()]
  def get_recent_trajectories(count \\ 100) do
    GenServer.call(__MODULE__, {:get_recent, count})
  end

  @doc "Returns counts, a memory estimate and score statistics over all summaries."
  @spec get_trajectory_statistics() :: map()
  def get_trajectory_statistics() do
    GenServer.call(__MODULE__, :get_statistics)
  end

  @doc "Asynchronously trims full trajectories down to `:max_trajectories`."
  @spec cleanup_old_trajectories() :: :ok
  def cleanup_old_trajectories() do
    GenServer.cast(__MODULE__, :cleanup)
  end

  # Server Callbacks

  @impl GenServer
  def init(opts) do
    state = %__MODULE__{
      max_trajectories: Keyword.get(opts, :max_trajectories, 1000),
      compression_threshold: Keyword.get(opts, :compression_threshold, 500),
      keep_recent: Keyword.get(opts, :keep_recent, 100),
      keep_score_above: Keyword.get(opts, :keep_score_above, 0.8)
    }

    # Periodic cleanup arrives as a plain :cleanup message (see handle_info/2).
    interval = Keyword.get(opts, :cleanup_interval, 30_000)
    {:ok, _tref} = :timer.send_interval(interval, :cleanup)

    {:ok, state}
  end

  @impl GenServer
  def handle_call({:store_trajectory, trajectory}, _from, state) do
    {updated_state, stored?} = store_trajectory_internal(state, trajectory)
    {:reply, stored?, updated_state}
  end

  def handle_call({:get_recent, count}, _from, state) do
    {:reply, Enum.take(state.trajectories, count), state}
  end

  def handle_call(:get_statistics, _from, state) do
    {:reply, calculate_statistics(state), state}
  end

  @impl GenServer
  def handle_cast(:cleanup, state) do
    {:noreply, cleanup_trajectories(state)}
  end

  @impl GenServer
  def handle_info(:cleanup, state) do
    {:noreply, cleanup_trajectories(state)}
  end

  # Internal Functions

  # Deduplicates by content hash, records the trajectory and its single summary,
  # then compresses if the full-trajectory list has grown past the threshold.
  defp store_trajectory_internal(state, trajectory) do
    hash = calculate_trajectory_hash(trajectory)

    # O(log n) membership check instead of scanning every summary.
    if MapSet.member?(state.hashes, hash) do
      {state, false}
    else
      summary = create_trajectory_summary(trajectory, hash)

      updated_state = %{
        state
        | trajectories: [trajectory | state.trajectories],
          summaries: [summary | state.summaries],
          hashes: MapSet.put(state.hashes, hash),
          total_stored: state.total_stored + 1
      }

      final_state =
        if length(updated_state.trajectories) > state.compression_threshold do
          compress_old_trajectories(updated_state)
        else
          updated_state
        end

      {refresh_memory_usage(final_state), true}
    end
  end

  defp calculate_trajectory_hash(trajectory) do
    hash_data = %{
      inputs: trajectory.inputs,
      outputs: trajectory.outputs,
      program_type: trajectory.metadata[:program_type],
      model_config: trajectory.model_config
    }

    # :deterministic makes equal terms encode to equal bytes, which the
    # content-hash deduplication relies on.
    hash_data
    |> :erlang.term_to_binary([:deterministic])
    |> then(&:crypto.hash(:sha256, &1))
  end

  defp create_trajectory_summary(trajectory, hash) do
    %{
      score: trajectory.score,
      program_type: trajectory.metadata[:program_type] || :unknown,
      execution_time: trajectory.duration,
      success: trajectory.success,
      hash: hash
    }
  end

  # Drops full trajectories that are neither among the `keep_recent` newest nor
  # high-scoring. No new summaries are created: each trajectory already got its
  # summary when it was stored, so adding another would double-count it.
  defp compress_old_trajectories(state) do
    kept =
      state.trajectories
      |> Enum.with_index()
      |> Enum.filter(fn {trajectory, index} ->
        index < state.keep_recent or high_scoring?(trajectory, state.keep_score_above)
      end)
      |> Enum.map(fn {trajectory, _index} -> trajectory end)

    %{state | trajectories: kept}
  end

  # Only numeric scores can qualify. Under Erlang term ordering `nil > 0.8` is
  # true, which would otherwise pin unscored trajectories forever.
  defp high_scoring?(%{score: score}, threshold) when is_number(score), do: score > threshold
  defp high_scoring?(_trajectory, _threshold), do: false

  # Keeps the newest `max_trajectories` full trajectories; summaries are untouched.
  defp cleanup_trajectories(state) do
    if length(state.trajectories) > state.max_trajectories do
      refresh_memory_usage(%{
        state
        | trajectories: Enum.take(state.trajectories, state.max_trajectories)
      })
    else
      state
    end
  end

  defp calculate_statistics(state) do
    # Summaries hold exactly one entry per stored trajectory, so they are the
    # complete, non-duplicated score source. Non-numeric scores are skipped so
    # Enum.sum/1 cannot crash on them.
    scores = for %{score: score} when is_number(score) <- state.summaries, do: score

    %{
      total_trajectories: length(state.trajectories),
      total_summaries: length(state.summaries),
      total_stored: state.total_stored,
      memory_usage_estimate: estimate_memory_usage(state),
      score_statistics: score_statistics(scores)
    }
  end

  defp score_statistics([]), do: %{}

  defp score_statistics(scores) do
    count = length(scores)
    {min, max} = Enum.min_max(scores)
    %{mean: Enum.sum(scores) / count, min: min, max: max, count: count}
  end

  # Keeps the struct's memory_usage field in sync with the rough estimate.
  defp refresh_memory_usage(state) do
    %{state | memory_usage: estimate_memory_usage(state)}
  end

  defp estimate_memory_usage(state) do
    # Rough estimate: ~1KB per full trajectory, ~100B per summary.
    trajectory_size = length(state.trajectories) * 1024
    summary_size = length(state.summaries) * 100
    trajectory_size + summary_size
  end
end