import numpy as np

def display_comparison_results(comparison_results):
    """
    Print a comparison table, explanatory notes and a recommendation.

    Every metrics dict is read for 'accuracy', 'precision', 'recall',
    'f1_score', 'mcc', 'auc_roc', 'avg_inference_time' and
    'valid_predictions'. 'label_mismatch' is optional. 'balanced_accuracy'
    and 'latency_p99' are only read when two or more models have
    valid_predictions > 0.

    Args:
        comparison_results (dict): Maps model name to its metrics dict.

    Returns:
        None. All output is written to stdout.
    """
    rule = "=" * 100
    print("\n" + rule)
    print("MODEL PERFORMANCE COMPARISON")
    print(rule)

    # Create header
    header = f"{'Model':<15} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'MCC':<10} {'AUC-ROC':<10} {'Avg Time':<11} {'Status':<10}"
    print(header)
    print("-" * len(header))

    # One table row per model; remember which models need a NOTE afterwards
    label_mismatch_models = []
    wrong_prediction_models = []
    for model_name, metrics in comparison_results.items():
        accuracy = f"{metrics['accuracy']:.2%}"
        precision = f"{metrics['precision']:.3f}"
        recall = f"{metrics['recall']:.3f}"
        f1_text = f"{metrics['f1_score']:.3f}"
        mcc = f"{metrics['mcc']:.3f}"
        auc_roc = f"{metrics['auc_roc']:.3f}"
        avg_time = f"{metrics['avg_inference_time']:.3f}s"
        status = "Good" if metrics['valid_predictions'] > 0 else "Error"

        print(f"{model_name:<15} {accuracy:<10} {precision:<10} {recall:<10} {f1_text:<10} {mcc:<10} {auc_roc:<10} {avg_time:<11} {status:<10}")

        if metrics.get('label_mismatch', False):
            label_mismatch_models.append(model_name)
        if metrics['accuracy'] < 1.0:
            wrong_prediction_models.append(model_name)

    # NOTE for models whose precision, recall and F1 are all exactly zero
    has_zero_metrics = any(
        m['precision'] == 0.0 and m['recall'] == 0.0 and m['f1_score'] == 0.0
        for m in comparison_results.values()
    )
    if has_zero_metrics:
        print("\nNOTE: Precision, Recall, F1-Score, and MCC show 0.000 for challenging datasets")
        print("      This occurs when models predict classes not present in ambiguous ground truth labels.")
        print("      Accuracy remains meaningful, but other metrics become undefined.")

    # The wrong-prediction NOTE is only shown when no label-mismatch NOTE is
    if label_mismatch_models:
        _print_error_note(
            "\nNOTE: The following models predicted classes not present in the ground truth:",
            label_mismatch_models,
            comparison_results,
        )
    elif wrong_prediction_models:
        _print_error_note(
            "\nNOTE: The following models made incorrect predictions:",
            wrong_prediction_models,
            comparison_results,
        )

    print("\n" + rule)

    # Only models that produced at least one valid prediction are ranked
    valid_models = {k: v for k, v in comparison_results.items() if v['valid_predictions'] > 0}

    if not valid_models:
        print("No valid model results found. Please check your endpoint names.")
        return

    if len(valid_models) == 1:
        model_name = next(iter(valid_models))
        print("SINGLE MODEL EVALUATION:")
        print(f"   → {model_name} completed evaluation successfully")
        print("   → No comparison available (only one model tested)")
        return

    def best_by(key, pick=max):
        # Ties resolve to the first model in insertion order
        return pick(valid_models, key=lambda name: valid_models[name][key])

    def p99_to_avg(name):
        # 0.001 keeps the ratio finite when the average time is zero
        entry = valid_models[name]
        return entry['latency_p99'] / (entry['avg_inference_time'] + 0.001)

    best_accuracy = best_by('accuracy')
    print(f"BEST ACCURACY: {best_accuracy} ({valid_models[best_accuracy]['accuracy']:.2%})")

    best_balanced = best_by('balanced_accuracy')
    print(f"BEST BALANCED ACCURACY: {best_balanced} ({valid_models[best_balanced]['balanced_accuracy']:.2%})")

    best_f1 = best_by('f1_score')
    print(f"BEST F1-SCORE: {best_f1} ({valid_models[best_f1]['f1_score']:.3f})")

    best_mcc = best_by('mcc')
    print(f"BEST MCC: {best_mcc} ({valid_models[best_mcc]['mcc']:.3f})")

    best_auc = best_by('auc_roc')
    print(f"BEST AUC-ROC: {best_auc} ({valid_models[best_auc]['auc_roc']:.3f})")

    fastest = best_by('avg_inference_time', pick=min)
    print(f"FASTEST MODEL: {fastest} ({valid_models[fastest]['avg_inference_time']:.3f}s avg)")

    # Most consistent model (lowest p99/avg ratio)
    most_consistent = min(valid_models, key=p99_to_avg)
    print(f"MOST CONSISTENT: {most_consistent} (p99/avg ratio: {p99_to_avg(most_consistent):.1f})")

    print("\n" + rule)
    print("RECOMMENDATION:")

    weighted_scores = _weighted_scores(valid_models)
    best_overall = max(weighted_scores, key=weighted_scores.get)
    print(f"   → {best_overall} is the best overall model (balanced performance and speed)")

    if best_accuracy != best_overall:
        print(f"   → {best_accuracy} for highest accuracy")

    if best_f1 != best_overall and best_f1 != best_accuracy:
        print(f"   → {best_f1} for best F1-score")

    if fastest != best_overall and fastest != best_accuracy and fastest != best_f1:
        print(f"   → Consider {fastest} if inference speed is critical")


def _print_error_note(title, model_names, comparison_results):
    """Print a NOTE title followed by one incorrect-prediction line per model."""
    print(title)
    for model in model_names:
        metrics = comparison_results[model]
        # The error count is reconstructed from accuracy and the sample count
        total_samples = metrics['valid_predictions']
        correct_samples = int(round(metrics['accuracy'] * total_samples))
        error_count = total_samples - correct_samples
        print(f"  - {model}: Made {error_count} incorrect prediction(s) out of {total_samples} samples")


def _weighted_scores(valid_models):
    """Return {model name: weighted score} combining quality metrics and speed."""
    epsilon = 1e-10  # Small value to prevent division by zero
    entries = list(valid_models.values())

    # Loop-invariant reference values, computed once
    top_accuracy = max(m['accuracy'] for m in entries)
    top_balanced = max(m['balanced_accuracy'] for m in entries)
    top_f1 = max(m['f1_score'] for m in entries)
    top_auc = max(m['auc_roc'] for m in entries)
    quickest = min(m['avg_inference_time'] for m in entries)

    scores = {}
    for name, metrics in valid_models.items():
        acc_norm = metrics['accuracy'] / (top_accuracy + epsilon)
        bal_acc_norm = metrics['balanced_accuracy'] / (top_balanced + epsilon)
        f1_norm = metrics['f1_score'] / (top_f1 + epsilon)
        mcc_norm = (metrics['mcc'] + 1) / 2  # MCC ranges from -1 to 1, normalize to 0-1
        auc_norm = metrics['auc_roc'] / (top_auc + epsilon)
        # Invert time so lower is better
        time_norm = quickest / (metrics['avg_inference_time'] + 0.001)

        scores[name] = (0.25 * acc_norm) + (0.15 * bal_acc_norm) + (0.2 * f1_norm) + (0.2 * mcc_norm) + (0.1 * auc_norm) + (0.1 * time_norm)
    return scores

def plot_confusion_matrices_all_datasets(all_results):
    """
    Plot one confusion-matrix heatmap per (model, dataset) pair.

    Rows are the models present in every dataset, in the order they appear
    in the first dataset; columns are datasets. Each cell reads
    'confusion_matrix' and 'accuracy' from the model's metrics dict.

    Args:
        all_results (dict): Dictionary with dataset names as keys and comparison results as values

    Returns:
        None. Prints a message and returns early when there is no model
        shared by all datasets (including when all_results is empty).
    """
    dataset_names = list(all_results.keys())

    # Models present in every dataset. Order follows the first dataset so the
    # row layout is stable between runs (iterating a set of str is not).
    first_results = all_results[dataset_names[0]] if dataset_names else {}
    shared = set(first_results.keys())
    for dataset_results in all_results.values():
        shared &= set(dataset_results.keys())
    common_models = [name for name in first_results if name in shared]

    if not common_models:
        print("No common models found across all datasets.")
        return

    n_models = len(common_models)
    n_datasets = len(dataset_names)

    # squeeze=False always yields a 2D axes array, even for a 1x1 grid.
    # Extra figure height leaves room for the accuracy text under each matrix.
    fig, axes = plt.subplots(
        n_models, n_datasets,
        figsize=(5*n_datasets, 4.5*n_models),
        squeeze=False,
    )

    for i, model_name in enumerate(common_models):
        for j, dataset_name in enumerate(dataset_names):
            ax = axes[i, j]
            metrics = all_results[dataset_name][model_name]

            # Tick labels are fixed to two classes
            # NOTE(review): assumes a 2x2 integer confusion matrix - confirm
            sns.heatmap(
                metrics['confusion_matrix'],
                annot=True,
                fmt='d',
                cmap='Blues',
                xticklabels=['Negative', 'Positive'],
                yticklabels=['Negative', 'Positive'],
                ax=ax
            )

            # Accuracy is drawn below the axes, in axes coordinates
            ax.text(0.5, -0.25, f"Accuracy: {metrics['accuracy']:.2%}",
                    horizontalalignment='center',
                    transform=ax.transAxes)

            # Dataset title on the top row only
            if i == 0:
                ax.set_title(f"{dataset_name} Dataset")

            # x-label on the bottom row only; padding clears the accuracy text
            if i == n_models - 1:
                ax.set_xlabel('Predicted Label', labelpad=20)
            else:
                ax.set_xlabel('')

            # y-label (with model name) on the first column only
            if j == 0:
                ax.set_ylabel(f"{model_name}\nTrue Label")
            else:
                ax.set_ylabel('')

    plt.tight_layout(h_pad=1.0, w_pad=0.5)
    plt.subplots_adjust(bottom=0.15)  # Add more space at the bottom
    plt.suptitle("Confusion Matrices Across All Datasets", y=1.02, fontsize=16)
    plt.show()

def plot_roc_curves_all_datasets(all_results):
    """
    Plot ROC curves for all models, one panel per dataset.

    A model is drawn only if it has valid_predictions > 0, an 'auc_roc'
    above 0.55, both 'true_labels' and 'pos_probs' in its metrics, and more
    than one distinct value in 'true_labels'. Panels with nothing to draw
    show a placeholder message instead.

    Args:
        all_results (dict): Dictionary with dataset names as keys and comparison results as values

    Returns:
        None. Prints a message and returns early when all_results is empty.
    """
    dataset_names = list(all_results.keys())

    if not dataset_names:
        # plt.subplots cannot build a grid with zero columns
        print("No datasets available for ROC plotting.")
        return

    # squeeze=False always returns a 2D array; take its single row
    fig, axes = plt.subplots(
        1, len(dataset_names),
        figsize=(6*len(dataset_names), 5),
        squeeze=False,
    )
    axes = axes[0]

    for ax, dataset_name in zip(axes, dataset_names):
        comparison_results = all_results[dataset_name]

        # Valid models whose AUC-ROC is meaningful (> 0.55)
        meaningful_models = {
            k: v for k, v in comparison_results.items()
            if v['valid_predictions'] > 0 and v['auc_roc'] > 0.55
        }

        if not meaningful_models:
            ax.text(0.5, 0.5, f"No meaningful ROC curves\nfor {dataset_name} dataset",
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_title(f"{dataset_name} Dataset")
            continue

        has_curves = False
        for model_name, metrics in meaningful_models.items():
            if 'true_labels' not in metrics or 'pos_probs' not in metrics:
                continue
            # A ROC curve needs both classes in the true labels
            if len(np.unique(metrics['true_labels'])) <= 1:
                continue
            try:
                fpr, tpr, _ = roc_curve(metrics['true_labels'], metrics['pos_probs'])
                roc_auc = metrics['auc_roc']
                ax.plot(fpr, tpr, lw=2, label=f'{model_name} (AUC = {roc_auc:.3f})')
                has_curves = True
            except Exception as exc:
                # Best effort: skip this model and report why, but let
                # KeyboardInterrupt/SystemExit through (bare except did not)
                print(f"Could not plot ROC curve for {model_name} on {dataset_name} dataset: {exc}")

        if not has_curves:
            ax.text(0.5, 0.5, f"No valid ROC curves\nfor {dataset_name} dataset",
                    ha='center', va='center', transform=ax.transAxes)
            # Title the panel like the other placeholder so it stays identifiable
            ax.set_title(f"{dataset_name} Dataset")
            continue

        # Plot diagonal line (random classifier)
        ax.plot([0, 1], [0, 1], 'k--', lw=2)

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(f"ROC Curves - {dataset_name} Dataset")
        ax.legend(loc="lower right")
        ax.grid(True)

    plt.tight_layout()
    plt.show()

def plot_confidence_distribution_all_datasets(all_results):
    """
    Plot confidence-score histograms per (model, dataset) pair.

    Rows are the models present in every dataset, in the order they appear
    in the first dataset; columns are datasets. When a cell has a non-empty
    'confidences' entry, its values are split into correct and incorrect
    predictions by comparing 'true_labels' with 'pred_labels' element-wise.

    Args:
        all_results (dict): Dictionary with dataset names as keys and comparison results as values

    Returns:
        None. Prints a message and returns early when there is no model
        shared by all datasets (including when all_results is empty).
    """
    dataset_names = list(all_results.keys())

    # Models present in every dataset. Order follows the first dataset so the
    # row layout is stable between runs (iterating a set of str is not).
    first_results = all_results[dataset_names[0]] if dataset_names else {}
    shared = set(first_results.keys())
    for dataset_results in all_results.values():
        shared &= set(dataset_results.keys())
    common_models = [name for name in first_results if name in shared]

    if not common_models:
        print("No common models found across all datasets.")
        return

    n_models = len(common_models)
    n_datasets = len(dataset_names)

    # squeeze=False always yields a 2D axes array, even for a 1x1 grid
    fig, axes = plt.subplots(
        n_models, n_datasets,
        figsize=(5*n_datasets, 4*n_models),
        squeeze=False,
    )

    for i, model_name in enumerate(common_models):
        for j, dataset_name in enumerate(dataset_names):
            ax = axes[i, j]
            metrics = all_results[dataset_name][model_name]

            raw_confidences = metrics.get('confidences')
            if raw_confidences is not None and len(raw_confidences) > 0:
                confidences = np.asarray(raw_confidences)

                # Convert to arrays first: comparing two plain lists with ==
                # yields a single bool, not an element-wise mask
                is_correct = np.asarray(metrics['true_labels']) == np.asarray(metrics['pred_labels'])

                for mask, color, label in (
                    (is_correct, 'green', 'Correct'),
                    (~is_correct, 'red', 'Incorrect'),
                ):
                    if mask.any():
                        sns.histplot(confidences[mask], ax=ax, bins=10, color=color, alpha=0.7, label=label)

                ax.legend(loc='upper right', fontsize='small')

                # Add mean confidence as text
                ax.text(0.05, 0.95, f"Mean: {np.mean(confidences):.2f}", transform=ax.transAxes,
                        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
            else:
                ax.text(0.5, 0.5, "No confidence data available",
                        ha='center', va='center', transform=ax.transAxes)

            # Titles and labels are applied to every cell, including
            # placeholders, and after histplot so its default labels are replaced
            if i == 0:
                ax.set_title(f"{dataset_name} Dataset")

            # x-label on the bottom row only
            if i == n_models - 1:
                ax.set_xlabel('Confidence Score')
            else:
                ax.set_xlabel('')

            # y-label (with model name) on the first column only
            if j == 0:
                ax.set_ylabel(f"{model_name}\nCount")
            else:
                ax.set_ylabel('')

    plt.tight_layout()
    plt.suptitle("Confidence Distributions Across All Datasets", y=1.02, fontsize=16)
    plt.show()

def plot_latency_comparison_all_datasets(all_results):
    """
    Plot grouped latency bars (average, P50, P90, P99), one panel per dataset.

    Only models present in every dataset and with valid_predictions > 0 in
    the panel's dataset are drawn, sorted by name. Each model's metrics are
    read for 'avg_inference_time', 'latency_p50', 'latency_p90' and
    'latency_p99'.

    Args:
        all_results (dict): Dictionary with dataset names as keys and comparison results as values

    Returns:
        None. Prints a message and returns early when there is no model
        shared by all datasets (including when all_results is empty).
    """
    dataset_names = list(all_results.keys())

    # Models present in every dataset; kept as a set for cheap membership tests
    first_results = all_results[dataset_names[0]] if dataset_names else {}
    common_models = set(first_results.keys())
    for dataset_results in all_results.values():
        common_models &= set(dataset_results.keys())

    if not common_models:
        print("No common models found across all datasets.")
        return

    # squeeze=False always returns a 2D array; take its single row
    fig, axes = plt.subplots(
        1, len(dataset_names),
        figsize=(6*len(dataset_names), 5),
        squeeze=False,
    )
    axes = axes[0]

    width = 0.2
    # (offset in bar widths, metrics key, legend label) for each bar series
    series = (
        (-1.5, 'avg_inference_time', 'Average'),
        (-0.5, 'latency_p50', 'P50'),
        (0.5, 'latency_p90', 'P90'),
        (1.5, 'latency_p99', 'P99'),
    )

    for i, dataset_name in enumerate(dataset_names):
        ax = axes[i]
        comparison_results = all_results[dataset_name]
        valid_models = {
            k: v for k, v in comparison_results.items()
            if k in common_models and v['valid_predictions'] > 0
        }

        if not valid_models:
            ax.text(0.5, 0.5, f"No valid models for\n{dataset_name} dataset",
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_title(f"{dataset_name} Dataset")
            continue

        model_names_sorted = sorted(valid_models.keys())
        x = np.arange(len(model_names_sorted))

        for offset, key, label in series:
            heights = [valid_models[name][key] for name in model_names_sorted]
            ax.bar(x + width * offset, heights, width, label=label)

        ax.set_ylabel('Latency (seconds)')
        ax.set_title(f"Latency - {dataset_name} Dataset")
        ax.set_xticks(x)
        ax.set_xticklabels(model_names_sorted, rotation=45, ha='right')

        # Only show legend on first subplot
        if i == 0:
            ax.legend()

        ax.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()


def cross_dataset_comparison(all_results):
    """
    Compare model performance across different datasets.

    Prints one table per metric (models as rows, datasets as columns) for
    the models present in every dataset. When more than one such model
    exists, it also prints robustness scores and names the most robust
    model. Ties within 1e-6 are broken by the lowest mean
    'avg_inference_time', then alphabetically.

    Args:
        all_results (dict): Dictionary with dataset names as keys and comparison results as values

    Returns:
        tuple | None: (metrics, robustness_scores). metrics maps a metric
        label to {model: [value per dataset]}; robustness_scores maps
        model to its score. Returns None when no model is shared by all
        datasets (including when all_results is empty).
    """
    print("\n" + "=" * 60)
    print("CROSS-DATASET PERFORMANCE COMPARISON")
    print("=" * 60)

    dataset_names = list(all_results.keys())

    # Models present in every dataset. Order follows the first dataset so the
    # table rows and tie ordering are stable between runs.
    first_results = all_results[dataset_names[0]] if dataset_names else {}
    shared = set(first_results.keys())
    for dataset_results in all_results.values():
        shared &= set(dataset_results.keys())
    common_models = [name for name in first_results if name in shared]

    if not common_models:
        print("No common models found across all datasets.")
        return

    # Display label -> key in each model's metrics dict
    metric_keys = {
        'Accuracy': 'accuracy',
        'Balanced Accuracy': 'balanced_accuracy',
        'F1-Score': 'f1_score',
        'MCC': 'mcc',
        'AUC-ROC': 'auc_roc',
    }
    metrics = {
        label: {
            model: [all_results[dataset][model][key] for dataset in dataset_names]
            for model in common_models
        }
        for label, key in metric_keys.items()
    }

    # Display metrics as tables: models as rows, datasets as columns
    for metric_name, metric_values in metrics.items():
        print(f"\n{metric_name}:")
        df = pd.DataFrame.from_dict(metric_values, orient='index', columns=dataset_names)
        print(df.round(3))

    # The score is computed for every common model; the report below is only
    # printed when there is more than one model to compare
    robustness_scores = {
        model: _robustness_score(metrics['Accuracy'][model], metrics['F1-Score'][model])
        for model in common_models
    }

    if len(common_models) > 1:
        print("\n" + "=" * 60)
        print("MODEL ROBUSTNESS SCORES")
        print("=" * 60)
        print("Higher scores indicate better consistency across datasets:")
        print("Robustness = 0.25×(1-CV_accuracy) + 0.25×(1-CV_f1) + 0.25×min_accuracy + 0.25×min_f1")
        print("Where CV = coefficient of variation (lower variation = more consistent)")
        print()

        # Sort by robustness score (descending)
        sorted_models = sorted(robustness_scores, key=robustness_scores.get, reverse=True)
        for model in sorted_models:
            print(f"{model}: {robustness_scores[model]:.3f}")

        max_score = robustness_scores[sorted_models[0]]
        tied_models = [m for m in sorted_models if abs(robustness_scores[m] - max_score) < 1e-6]

        if len(tied_models) > 1:
            print(f"\nTied models with score {max_score:.3f}: {', '.join(tied_models)}")

            # Mean inference time per tied model; inf when no timing is recorded
            tied_times = {}
            for model in tied_models:
                model_times = [
                    all_results[dataset][model]['avg_inference_time']
                    for dataset in dataset_names
                    if 'avg_inference_time' in all_results[dataset][model]
                ]
                tied_times[model] = np.mean(model_times) if model_times else float('inf')

            # The equality check comes first because inf - inf is NaN, which
            # would otherwise exclude every model when none has timings
            min_time = min(tied_times.values())
            fastest_models = [
                model for model, elapsed in tied_times.items()
                if elapsed == min_time or abs(elapsed - min_time) < 1e-6
            ]

            if len(fastest_models) == 1:
                most_robust = fastest_models[0]
                print(f"Most robust model: {most_robust} (fastest among tied models: {min_time:.3f}s)")
            else:
                # If inference times are also tied, use alphabetical order
                most_robust = sorted(fastest_models)[0]
                print(f"Most robust model: {most_robust} (selected alphabetically from models tied in both robustness and speed)")
        else:
            most_robust = sorted_models[0]
            print(f"\nMost robust model: {most_robust}")

    return metrics, robustness_scores


def _robustness_score(accuracies, f1_scores):
    """Return the equal-weight mix of (1 - CV) and minimum for accuracy and F1."""
    # 1 - coefficient of variation, so higher means more consistent;
    # 1e-10 guards against division by a zero mean
    acc_cv = 1 - (np.std(accuracies) / (np.mean(accuracies) + 1e-10))
    f1_cv = 1 - (np.std(f1_scores) / (np.mean(f1_scores) + 1e-10))

    # Worst-case performance across datasets (higher is better)
    acc_min = min(accuracies)
    f1_min = min(f1_scores)

    return 0.25 * acc_cv + 0.25 * f1_cv + 0.25 * acc_min + 0.25 * f1_min