blob: 7bd077d7979734533ef5fa4c15f944aad38c713e [file] [log] [blame]
// SPDX-License-Identifier: GPL-3.0-or-later OR AGPL-3.0-or-later
// Copyright (C) 2025 Red Hat, Inc.
use crate::bench_args::BenchArgs;
use crate::config::{Config, EndpointTypeConfig};
use crate::conflict_resolver::{Conflict, ConflictResolver};
use crate::git_utils::{ContextLines, GitUtils};
use crate::prob::logprob_to_prob;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
/// One benchmark case loaded from the CSV database: a single unified-diff
/// hunk plus the code it must be applied to and the expected result.
#[derive(Debug)]
pub struct TestEntry {
    /// Unified-diff hunk text to replay (`@@ … @@` headers plus `+`/`-`/` ` lines).
    patch: String,
    /// The local version of the file contents the patch conflicts with.
    code: String,
    /// Commit that introduced the patch; used to look up the full git diff.
    patch_commit_hash: String,
    /// Commit the code was taken from; `load_database` stores it with a
    /// trailing `^` (presumably git parent-commit notation — confirm at caller).
    code_commit_hash: String,
    /// Expected file contents after the patch is applied correctly.
    patched_code: String,
    /// Path of the file the hunk applies to.
    filename: String,
}
/// Outcome of resolving one entry with one model.
///
/// NOTE: field order defines the checkpoint CSV column order (records are
/// written with `csv::Writer::serialize`), so do not reorder fields.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct TestResult {
    /// Index of the `TestEntry` this result belongs to.
    entry_index: usize,
    /// Model (endpoint/variant) name that produced the resolution.
    model: String,
    /// Exact match with the expected patched code.
    correct: bool,
    /// Match after indentation-preserving whitespace normalization.
    correct_aligned: bool,
    /// Match after collapsing all whitespace.
    correct_stripped: bool,
    /// Resolution wall-clock time in seconds (0.0 for error results).
    duration: f64,
    /// Total tokens reported by the model, if available.
    tokens: Option<u64>,
    /// Log-probability reported for the resolution, if available.
    logprob: Option<f64>,
    /// Diff between produced and expected code when they differ; `None` on success.
    failed_patched_code: Option<String>,
    /// True when the resolver failed to produce any resolution for this model.
    error: bool,
    /// See `TestEntry::patch_commit_hash`.
    patch_commit_hash: String,
    /// See `TestEntry::code_commit_hash`.
    code_commit_hash: String,
}
/// Buckets used to aggregate per-result logprob samples in
/// `Bench::calculate_stats`; the variant's discriminant is used as an array
/// index. `COUNT` is a sentinel equal to the number of real variants (used
/// as an array length), not a category itself.
#[derive(Debug)]
pub enum LogprobType {
    /// Every result that carries a logprob.
    Global,
    /// Results where the resolver errored out.
    Errors,
    /// Results failing even the whitespace-stripped comparison.
    Incorrect,
    /// Results matching after collapsing all whitespace.
    CorrectStripped,
    /// Results matching after indentation-preserving normalization.
    CorrectAligned,
    /// Exact matches.
    Correct,
    /// Number of variants above — must stay last.
    COUNT,
}
/// Aggregated per-model statistics derived from the collected `TestResult`s.
/// `f64::INFINITY` is used throughout as a "no samples" sentinel; consumers
/// check `is_finite()` before printing.
#[derive(Debug)]
struct ModelStats {
    /// Number of results counted for this model.
    total: usize,
    /// Exact matches.
    correct: usize,
    /// Matches after indentation-preserving whitespace normalization.
    correct_aligned: usize,
    /// Matches after collapsing all whitespace.
    correct_stripped: usize,
    /// Resolver errors.
    error: usize,
    /// correct / total.
    accuracy: f64,
    /// correct_aligned / total.
    accuracy_aligned: f64,
    /// correct_stripped / total.
    accuracy_stripped: f64,
    /// error / total.
    error_rate: f64,
    /// Mean token count over results that reported a token count.
    avg_tokens: f64,
    /// Mean logprob per bucket, indexed by `LogprobType as usize`.
    avg_logprob: [f64; LogprobType::COUNT as usize],
    /// Standard deviation per bucket, computed in probability space
    /// (samples pass through `logprob_to_prob` first, unlike the mean).
    std_logprob: [f64; LogprobType::COUNT as usize],
    /// Mean resolution duration in seconds.
    avg_duration: f64,
}
/// Benchmark driver: resolves every database entry with every configured
/// model, checkpoints per-result outcomes to CSV and aggregates per-model
/// statistics.
#[derive(Debug)]
pub struct Bench {
    /// All per-entry, per-model outcomes (in-memory mirror of the checkpoint).
    results: Vec<TestResult>,
    /// Aggregated statistics keyed by model name.
    model_stats: HashMap<String, ModelStats>,
    /// Cache of `git show` output keyed by commit hash; `run_test` keeps at
    /// most the current commit's diff in it.
    git_diffs: HashMap<String, String>,
    /// Matches unified-diff hunk headers ("@@ -a,b +c,d @@").
    line_number_re: regex::Regex,
}
impl Default for Bench {
fn default() -> Self {
Self::new()
}
}
impl Bench {
fn line_number_regex() -> regex::Regex {
regex::Regex::new(r"^@@ -\d+,\d+ \+\d+,\d+ @@").unwrap()
}
pub fn new() -> Self {
Bench {
results: Vec::new(),
model_stats: HashMap::new(),
git_diffs: HashMap::new(),
line_number_re: Self::line_number_regex(),
}
}
pub fn load_database<P: AsRef<Path>>(path: P) -> Result<Vec<TestEntry>> {
let file = File::open(path.as_ref())?;
let mut reader = csv::Reader::from_reader(file);
let mut entries = Vec::new();
for result in reader.records() {
let record = result?;
if record.len() < 6 {
continue;
}
let description = record
.get(2)
.ok_or_else(|| anyhow::anyhow!("Failed to get description from CSV record"))?;
let mut split_desc = description.splitn(2, " / ");
let code_commit_hash = split_desc.next().unwrap_or("").trim().to_string();
let code_commit_hash = format!("{}^", code_commit_hash);
let mut split_desc = split_desc.next().unwrap_or("").split('\n');
let patch_commit_hash = split_desc.next().unwrap_or("").trim().to_string();
let patch = record.get(3).unwrap_or("").to_string();
let code = record.get(4).unwrap_or("").to_string();
let patched_code = record.get(5).unwrap_or("").to_string();
let filename = split_desc.next().unwrap_or("").trim().to_string();
let entry = TestEntry {
patch,
code,
patch_commit_hash,
code_commit_hash,
patched_code,
filename,
};
entries.push(entry);
}
Ok(entries)
}
fn save_checkpoint(&mut self, args: &BenchArgs) -> Result<()> {
let file = File::create(&args.checkpoint_path)?;
let mut writer = csv::Writer::from_writer(file);
for result in &self.results {
writer.serialize(result)?;
}
writer.flush()?;
self.calculate_stats(args);
Ok(())
}
fn load_checkpoint(&mut self, args: &BenchArgs) -> Result<()> {
if !Path::new(&args.checkpoint_path).exists() {
return Ok(());
}
let file = File::open(&args.checkpoint_path)?;
let mut reader = csv::Reader::from_reader(file);
for result in reader.deserialize() {
let result: TestResult = result.with_context(|| "Failed to parse test result")?;
self.results.push(result);
}
self.calculate_stats(args);
Ok(())
}
fn calculate_stats(&mut self, args: &BenchArgs) {
// Initialize stats for all models
let mut model_totals = HashMap::new();
let mut model_correct = HashMap::new();
let mut model_correct_aligned = HashMap::new();
let mut model_correct_stripped = HashMap::new();
let mut model_tokens = HashMap::new();
let mut model_logprob = Vec::with_capacity(LogprobType::COUNT as usize);
for _ in 0..LogprobType::COUNT as usize {
model_logprob.push(HashMap::new());
}
let mut model_durations = HashMap::new();
let mut model_errors = HashMap::new();
// Collect all results by model
for result in self
.results
.iter()
.filter(|x| args.max_entries.is_none() || x.entry_index < args.max_entries.unwrap())
{
let model = &result.model;
*model_totals.entry(model.clone()).or_insert(0) += 1;
if result.correct {
*model_correct.entry(model.clone()).or_insert(0) += 1;
}
if result.correct_aligned {
*model_correct_aligned.entry(model.clone()).or_insert(0) += 1;
}
if result.correct_stripped {
*model_correct_stripped.entry(model.clone()).or_insert(0) += 1;
}
if let Some(tokens) = result.tokens {
model_tokens
.entry(model.clone())
.or_insert_with(Vec::new)
.push(tokens);
}
if let Some(logprob) = result.logprob {
model_logprob[LogprobType::Global as usize]
.entry(model.clone())
.or_insert_with(Vec::new)
.push(logprob);
if result.error {
assert!(!result.correct);
assert!(!result.correct_aligned);
assert!(!result.correct_stripped);
model_logprob[LogprobType::Errors as usize]
.entry(model.clone())
.or_insert_with(Vec::new)
.push(logprob);
}
if result.correct {
assert!(!result.error);
assert!(result.correct_aligned);
assert!(result.correct_stripped);
model_logprob[LogprobType::Correct as usize]
.entry(model.clone())
.or_insert_with(Vec::new)
.push(logprob);
}
if result.correct_aligned {
assert!(!result.error);
assert!(result.correct_aligned);
model_logprob[LogprobType::CorrectAligned as usize]
.entry(model.clone())
.or_insert_with(Vec::new)
.push(logprob);
}
if result.correct_stripped {
assert!(!result.error);
model_logprob[LogprobType::CorrectStripped as usize]
.entry(model.clone())
.or_insert_with(Vec::new)
.push(logprob);
} else {
assert!(!result.error);
model_logprob[LogprobType::Incorrect as usize]
.entry(model.clone())
.or_insert_with(Vec::new)
.push(logprob);
}
}
model_durations
.entry(model.clone())
.or_insert_with(Vec::new)
.push(result.duration);
if result.error {
*model_errors.entry(model.clone()).or_insert(0) += 1;
}
}
// Calculate final stats
self.model_stats.clear();
for (model, total) in model_totals {
let correct = model_correct.get(&model).copied().unwrap_or(0);
let accuracy = correct as f64 / total as f64;
let correct_aligned = model_correct_aligned.get(&model).copied().unwrap_or(0);
let accuracy_aligned = correct_aligned as f64 / total as f64;
let correct_stripped = model_correct_stripped.get(&model).copied().unwrap_or(0);
let accuracy_stripped = correct_stripped as f64 / total as f64;
let error = model_errors.get(&model).copied().unwrap_or(0);
let error_rate = error as f64 / total as f64;
let avg_tokens = model_tokens
.get(&model)
.map(|tokens| tokens.iter().sum::<u64>() as f64 / tokens.len() as f64)
.unwrap_or(f64::INFINITY);
let avg_logprob = [
model_logprob[LogprobType::Global as usize]
.get(&model)
.map(|logprob| logprob.iter().sum::<f64>() / logprob.len() as f64)
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::Errors as usize]
.get(&model)
.map(|logprob| logprob.iter().sum::<f64>() / logprob.len() as f64)
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::Incorrect as usize]
.get(&model)
.map(|logprob| logprob.iter().sum::<f64>() / logprob.len() as f64)
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::CorrectStripped as usize]
.get(&model)
.map(|logprob| logprob.iter().sum::<f64>() / logprob.len() as f64)
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::CorrectAligned as usize]
.get(&model)
.map(|logprob| logprob.iter().sum::<f64>() / logprob.len() as f64)
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::Correct as usize]
.get(&model)
.map(|logprob| logprob.iter().sum::<f64>() / logprob.len() as f64)
.unwrap_or(f64::INFINITY),
];
let std_logprob = [
model_logprob[LogprobType::Global as usize]
.get(&model)
.map(|logprob| {
let avg = logprob.iter().map(|x| logprob_to_prob(*x)).sum::<f64>()
/ logprob.len() as f64;
let variance = logprob
.iter()
.map(|x| (logprob_to_prob(*x) - avg).powi(2))
.sum::<f64>()
/ logprob.len() as f64;
variance.sqrt()
})
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::Errors as usize]
.get(&model)
.map(|logprob| {
let avg = logprob.iter().map(|x| logprob_to_prob(*x)).sum::<f64>()
/ logprob.len() as f64;
let variance = logprob
.iter()
.map(|x| (logprob_to_prob(*x) - avg).powi(2))
.sum::<f64>()
/ logprob.len() as f64;
variance.sqrt()
})
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::Incorrect as usize]
.get(&model)
.map(|logprob| {
let avg = logprob.iter().map(|x| logprob_to_prob(*x)).sum::<f64>()
/ logprob.len() as f64;
let variance = logprob
.iter()
.map(|x| (logprob_to_prob(*x) - avg).powi(2))
.sum::<f64>()
/ logprob.len() as f64;
variance.sqrt()
})
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::CorrectStripped as usize]
.get(&model)
.map(|logprob| {
let avg = logprob.iter().map(|x| logprob_to_prob(*x)).sum::<f64>()
/ logprob.len() as f64;
let variance = logprob
.iter()
.map(|x| (logprob_to_prob(*x) - avg).powi(2))
.sum::<f64>()
/ logprob.len() as f64;
variance.sqrt()
})
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::CorrectAligned as usize]
.get(&model)
.map(|logprob| {
let avg = logprob.iter().map(|x| logprob_to_prob(*x)).sum::<f64>()
/ logprob.len() as f64;
let variance = logprob
.iter()
.map(|x| (logprob_to_prob(*x) - avg).powi(2))
.sum::<f64>()
/ logprob.len() as f64;
variance.sqrt()
})
.unwrap_or(f64::INFINITY),
model_logprob[LogprobType::Correct as usize]
.get(&model)
.map(|logprob| {
let avg = logprob.iter().map(|x| logprob_to_prob(*x)).sum::<f64>()
/ logprob.len() as f64;
let variance = logprob
.iter()
.map(|x| (logprob_to_prob(*x) - avg).powi(2))
.sum::<f64>()
/ logprob.len() as f64;
variance.sqrt()
})
.unwrap_or(f64::INFINITY),
];
let avg_duration = model_durations
.get(&model)
.map(|durations| durations.iter().sum::<f64>() / durations.len() as f64)
.unwrap_or(f64::INFINITY);
self.model_stats.insert(
model,
ModelStats {
total,
correct,
correct_aligned,
correct_stripped,
error,
accuracy,
accuracy_aligned,
accuracy_stripped,
error_rate,
avg_tokens,
avg_logprob,
std_logprob,
avg_duration,
},
);
}
self.print_results();
}
fn print_results(&self) {
println!("\n=== MODEL ACCURACY RESULTS ===");
if self.model_stats.is_empty() {
println!("No results available");
return;
}
let mut sorted_stats: Vec<_> = self.model_stats.iter().collect();
sorted_stats.sort_by(|a, b| b.1.accuracy.partial_cmp(&a.1.accuracy).unwrap());
for (model, stats) in sorted_stats {
println!("\nModel: {}", model);
println!(
" Accuracy: {:.2}% ({}/{})",
stats.accuracy * 100.0,
stats.correct,
stats.total
);
if stats.accuracy_aligned.is_finite() {
println!(
" Accuracy (aligned): {:.2}% ({}/{})",
stats.accuracy_aligned * 100.0,
stats.correct_aligned,
stats.total
);
}
if stats.accuracy_stripped.is_finite() {
println!(
" Accuracy (stripped): {:.2}% ({}/{})",
stats.accuracy_stripped * 100.0,
stats.correct_stripped,
stats.total
);
}
if stats.error_rate.is_finite() {
println!(
" Error Rate: {:.2}% ({}/{})",
stats.error_rate * 100.0,
stats.error,
stats.total
);
}
if stats.avg_tokens.is_finite() {
println!(" Average tokens: {:.2}", stats.avg_tokens);
}
if stats.avg_duration.is_finite() {
println!(" Average duration: {:.2} s", stats.avg_duration);
}
let logprob_type_names = [
"Average prob",
"Average prob (errors)",
"Average prob (incorrect)",
"Average prob (stripped)",
"Average prob (aligned)",
"Average prob (correct)",
];
for (i, &value) in stats.avg_logprob.iter().enumerate() {
if value.is_finite() {
println!(
" {}: {:.1}% (+- {:.1})",
logprob_type_names[i],
logprob_to_prob(value),
stats.std_logprob[i]
);
}
}
}
}
    /// Run the benchmark over `entries`: resolve each entry with every
    /// configured model, record one `TestResult` per model, and periodically
    /// write a CSV checkpoint so an interrupted run can resume.
    ///
    /// SIGINT/SIGTERM are intercepted so a checkpoint is saved before the
    /// process exits with the conventional 130/143 exit codes.
    pub async fn run_test(
        &mut self,
        config: &Config,
        entries: &[TestEntry],
        args: BenchArgs,
    ) -> Result<()> {
        println!("Running statistics test on {} entries", entries.len());
        let context_lines = ContextLines {
            code_context_lines: args.code_context_lines,
            diff_context_lines: args.diff_context_lines,
            patch_context_lines: args.patch_context_lines,
        };
        // Create a new GitUtils instance to find the commit hash
        let git_utils = GitUtils::new(context_lines.clone(), false);
        // Load existing checkpoint (no-op when none exists).
        self.load_checkpoint(&args)?;
        println!(
            "Loaded {} existing results from checkpoint",
            self.results.len()
        );
        let model_names = self.get_all_model_names(config);
        // Capacity-1 channel: the signal task parks an exit code here and the
        // main loop polls it between entries so the checkpoint stays coherent.
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        tokio::spawn(async move {
            loop {
                let exit_code;
                // NOTE(review): the SIGTERM stream is re-registered on every
                // loop iteration; registering once before the loop would
                // suffice — confirm before changing.
                let mut sigterm =
                    tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                        .unwrap();
                tokio::select! {
                    _ = tokio::signal::ctrl_c() => {
                        exit_code = 130;
                    }
                    _ = sigterm.recv() => {
                        exit_code = 143;
                    }
                }
                // If the send fails (a signal is already pending or the
                // receiver is gone), give up on a clean shutdown and exit now.
                tx.try_send(exit_code)
                    .unwrap_or_else(|_| std::process::exit(exit_code));
                println!("Saving checkpoint and exiting...");
            }
        });
        // True when results were added since the last checkpoint write.
        let mut modified = false;
        for (i, entry) in entries
            .iter()
            .enumerate()
            .take(args.max_entries.unwrap_or(usize::MAX))
        {
            // Skip entries that are already processed: every model already
            // has a recorded result for this entry index.
            if self
                .results
                .iter()
                .filter(|r| model_names.contains(&r.model) && r.entry_index == i)
                .count()
                == model_names.len()
            {
                continue;
            }
            let processing_msg = format!("Processing entry {} of {}...", i + 1, entries.len());
            log::info!("{}", processing_msg);
            println!("{}", processing_msg);
            // Create conflict from test entry
            let git_diff = self.git_diffs.get(&entry.patch_commit_hash).cloned();
            let git_diff = git_diff.or_else(|| {
                // cache only the current commit
                self.git_diffs.clear();
                // Find the commit hash from the patch_commit_hash
                let commit_hash = &entry.patch_commit_hash;
                // Extract the diff from git
                // Try each directory
                if let Some(diff) =
                    self.git_show_dirs(&git_utils, &args.git_directories, commit_hash, None)
                {
                    // Store the diff for future use
                    self.git_diffs.insert(commit_hash.clone(), diff.clone());
                    return Some(diff);
                }
                // The database entry references a commit missing from every
                // configured git directory — cannot continue meaningfully.
                panic!("Git diff for commit {} not found", commit_hash);
            });
            let conflict = self.create_conflict_from_entry(entry)?;
            let resolver = ConflictResolver::new(context_lines.clone(), config, git_diff, true);
            let resolved_conflicts = resolver.resolve_conflicts(&[conflict]).await;
            match resolved_conflicts {
                Ok((resolved_conflicts, resolved_errors)) => {
                    // At least one resolution or one error must have happened.
                    assert!(!(resolved_conflicts.is_empty() && resolved_errors.errors.is_empty()));
                    // Record one error result per failure reported per model.
                    for (model_name, error_count) in resolved_errors.errors.iter() {
                        let test_result = TestResult {
                            entry_index: i,
                            model: model_name.clone(),
                            correct: false,
                            correct_aligned: false,
                            correct_stripped: false,
                            duration: 0.0,
                            tokens: None,
                            logprob: None,
                            failed_patched_code: None,
                            error: true,
                            patch_commit_hash: entry.patch_commit_hash.clone(),
                            code_commit_hash: entry.code_commit_hash.clone(),
                        };
                        for _ in 0..*error_count {
                            self.results.push(test_result.clone());
                        }
                    }
                    // Record one result per successful resolution, grading it
                    // exact / aligned / stripped against the expected code.
                    for resolved_conflict in resolved_conflicts.iter() {
                        let test_result = TestResult {
                            entry_index: i,
                            model: resolved_conflict.model.clone(),
                            correct: resolved_conflict.resolved_version == entry.patched_code,
                            correct_aligned: self
                                .aligned(&resolved_conflict.resolved_version, &entry.patched_code),
                            correct_stripped: self
                                .stripped(&resolved_conflict.resolved_version, &entry.patched_code),
                            duration: resolved_conflict.duration,
                            tokens: resolved_conflict.total_tokens,
                            logprob: resolved_conflict.logprob,
                            // Keep a diff of the failure for later inspection.
                            failed_patched_code: if resolved_conflict.resolved_version
                                == entry.patched_code
                            {
                                None
                            } else {
                                Some(ConflictResolver::create_diff(
                                    &resolved_conflict.resolved_version,
                                    &entry.patched_code.clone(),
                                    1,
                                ))
                            },
                            error: false,
                            patch_commit_hash: entry.patch_commit_hash.clone(),
                            code_commit_hash: entry.code_commit_hash.clone(),
                        };
                        self.results.push(test_result);
                    }
                }
                Err(e) => anyhow::bail!("Failed to resolve conflicts: {}", e),
            };
            modified = true;
            // A signal arrived while this entry was processed: persist and exit.
            if let Ok(exit_code) = rx.try_recv() {
                self.save_checkpoint(&args)?;
                std::process::exit(exit_code);
            }
            // Save checkpoint periodically
            if (i + 1) % args.checkpoint_interval == 0 {
                self.save_checkpoint(&args)?;
                modified = false;
            }
        }
        // Save final checkpoint
        if modified {
            self.save_checkpoint(&args)?;
        }
        Ok(())
    }
fn git_show_dirs(
&self,
git_utils: &GitUtils,
dirs: &[String],
commit_hash: &str,
filename: Option<&str>,
) -> Option<String> {
for dir in dirs {
if let Ok(Some(diff)) = git_utils.git_show_in_dir(commit_hash, Some(dir), filename) {
return Some(diff);
}
}
None
}
fn get_all_model_names(&mut self, config: &Config) -> std::collections::HashSet<String> {
let mut model_names = std::collections::HashSet::new();
// Collect all model names from endpoints configuration
for endpoint in config.get_all_endpoints() {
match &endpoint.config {
EndpointTypeConfig::OpenAI { variants, .. }
| EndpointTypeConfig::Anthropic { variants, .. } => {
if let Some(variants) = variants {
for variant in variants.iter() {
let variant_name = if let Some(variant) = &*variant.name {
format!("{} ({})", endpoint.name, variant)
} else {
endpoint.name.clone()
};
model_names.insert(variant_name);
}
} else {
// No variants, just the endpoint name
model_names.insert(endpoint.name.clone());
}
}
EndpointTypeConfig::Patchpal { .. } => {
// For patchpal, we have 3 variants (as per existing logic)
model_names.insert(endpoint.name.to_string());
for y in 1..3 {
model_names.insert(format!("{} (#{})", endpoint.name, y));
}
}
}
}
assert!(!model_names.is_empty());
model_names
}
fn stripped(&self, resolved: &str, expected: &str) -> bool {
self.__stripped(resolved) == self.__stripped(expected)
}
fn __stripped(&self, s: &str) -> String {
s.split_whitespace().collect::<Vec<_>>().join(" ")
}
fn aligned(&self, resolved: &str, expected: &str) -> bool {
self.__aligned(resolved) == self.__aligned(expected)
}
fn __aligned(&self, s: &str) -> String {
s.lines()
.filter(|line| line.chars().any(|c| !c.is_whitespace()))
.map(|line| {
let mut result = String::new();
let mut seen_non_whitespace = false;
let mut last_was_whitespace = false;
for c in line.chars() {
if c.is_whitespace() {
if !seen_non_whitespace {
result.push(c);
} else {
last_was_whitespace = true;
}
} else {
if last_was_whitespace {
result.push(' ');
}
result.push(c);
seen_non_whitespace = true;
}
}
result
})
.collect::<Vec<_>>()
.join("\n")
}
fn create_conflict_from_entry(&self, entry: &TestEntry) -> Result<Conflict> {
let mut base_lines = Vec::new();
let mut remote_lines = Vec::new();
let mut nr_head_context_lines = 0;
let mut found_first_change = false;
let mut line_count = 0;
for line in entry.patch.split_inclusive('\n') {
if self.line_number_re.is_match(line) {
continue;
}
if let Some(line) = line.strip_prefix('+') {
remote_lines.push(line.to_string());
if !found_first_change {
nr_head_context_lines = line_count;
found_first_change = true;
}
line_count = 0;
continue;
} else if let Some(line) = line.strip_prefix('-') {
base_lines.push(line.to_string());
if !found_first_change {
nr_head_context_lines = line_count;
found_first_change = true;
}
line_count = 0;
continue;
} else if let Some(line) = line.strip_prefix(' ') {
base_lines.push(line.to_string());
remote_lines.push(line.to_string());
line_count += 1;
continue;
}
panic!("malformed patch: {:?}", entry.patch);
}
let nr_tail_context_lines = line_count;
let base = base_lines.join("");
let remote = remote_lines.join("");
Ok(Conflict {
file_path: entry.filename.clone(),
local: entry.code.clone(),
base,
remote,
head_context: String::new(),
tail_context: String::new(),
start_line: 0,
remote_end: 0,
nr_head_context_lines,
nr_tail_context_lines,
marker_size: 0,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip a `TestResult` through serde JSON and verify that every
    /// field survives unchanged.
    #[test]
    fn test_test_result_serialization() {
        let original = TestResult {
            entry_index: 0,
            model: String::from("test_model"),
            correct: true,
            correct_aligned: true,
            correct_stripped: true,
            duration: 0.0,
            tokens: None,
            logprob: None,
            failed_patched_code: None,
            error: false,
            patch_commit_hash: String::from("abc123"),
            code_commit_hash: String::from("def456"),
        };
        let json = serde_json::to_string(&original).unwrap();
        let roundtrip: TestResult = serde_json::from_str(&json).unwrap();
        // TestResult does not derive PartialEq, so compare field by field.
        assert_eq!(original.entry_index, roundtrip.entry_index);
        assert_eq!(original.model, roundtrip.model);
        assert_eq!(original.correct, roundtrip.correct);
        assert_eq!(original.correct_aligned, roundtrip.correct_aligned);
        assert_eq!(original.correct_stripped, roundtrip.correct_stripped);
        assert_eq!(original.duration, roundtrip.duration);
        assert_eq!(original.tokens, roundtrip.tokens);
        assert_eq!(original.logprob, roundtrip.logprob);
        assert_eq!(original.failed_patched_code, roundtrip.failed_patched_code);
        assert_eq!(original.error, roundtrip.error);
        assert_eq!(original.patch_commit_hash, roundtrip.patch_commit_hash);
        assert_eq!(original.code_commit_hash, roundtrip.code_commit_hash);
    }
}
// Local Variables:
// rust-format-on-save: t
// End: