From 191b58d11de32a35e7e03c98d2dda15d45f23f87 Mon Sep 17 00:00:00 2001
From: David Brazda
Date: Thu, 21 Nov 2024 16:08:07 +0100
Subject: [PATCH] thr conf matrix added

---
 setup.py         |   2 +-
 ttools/models.py | 101 ++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 101 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 9abcc94..682d5a3 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 
 setup(
     name='ttools',
-    version='0.7.82',
+    version='0.7.9',
     packages=find_packages(),
     install_requires=[
         # list your dependencies here
diff --git a/ttools/models.py b/ttools/models.py
index 8d1acbb..45fa312 100644
--- a/ttools/models.py
+++ b/ttools/models.py
@@ -778,6 +778,7 @@ class LibraryTradingModel:
             return results, model
 
         except Exception as e:
             print(f"Error in iteration {iteration_num}: {str(e)} - {format_exc()}")
+            raise
             return None, None
 
@@ -1020,7 +1021,7 @@ class LibraryTradingModel:
             left=[],
             middle1=[(results["predicted"],"predicted"),(results["actual"],"actual"),(prob_df,)],
         ).chart(size="s", precision=6, title=f"Iteration {iteration_num} classes:{self.config.n_classes} forward_bars:{self.config.forward_bars}")
-        
+
         num_classes = self.config.n_classes
 
         # Add probability columns to results
@@ -1062,6 +1063,104 @@ class LibraryTradingModel:
         print("Overall Accuracy:", (results['predicted'] == results['actual']).mean())
         print("Directional Accuracy:", results['direction_correct'].mean())
 
+        print("New confusion matrix")
+
+        def plot_threshold_confusion_matrices(results, predictions_proba, thresholds=[0.3, 0.5, 0.8], num_classes=5):
+            """
+            Plot confusion matrices for different probability thresholds.
+            """
+            # Calculate subplot dimensions
+            plots_per_row = 3
+            num_rows = (len(thresholds) + plots_per_row - 1) // plots_per_row  # Ceiling division
+            num_cols = min(len(thresholds), plots_per_row)
+
+            # Create figure with extra spacing between subplots
+            fig = plt.figure(figsize=(7*num_cols, 6*num_rows))
+            gs = fig.add_gridspec(num_rows, num_cols, hspace=0.4, wspace=0.3)
+
+            for idx, threshold in enumerate(thresholds):
+                # Calculate subplot position
+                row = idx // plots_per_row
+                col = idx % plots_per_row
+                ax = fig.add_subplot(gs[row, col])
+
+                # Mark low-confidence predictions as -1
+                predicted_classes = np.full(len(predictions_proba), -1)
+                max_probs = np.max(predictions_proba, axis=1)
+                confident_mask = max_probs >= threshold
+
+                # Only assign predictions where confidence meets the threshold
+                predicted_classes[confident_mask] = np.argmax(predictions_proba[confident_mask], axis=1)
+
+                # Filter results to only include confident predictions
+                valid_indices = predicted_classes != -1
+                filtered_actual = results['actual'][valid_indices]
+                filtered_predicted = predicted_classes[valid_indices]
+
+                if len(filtered_actual) > 0:
+                    # Pin labels to all classes so the matrix stays num_classes x num_classes
+                    conf_matrix = confusion_matrix(filtered_actual, filtered_predicted, labels=list(range(num_classes)))
+                    conf_matrix_pct = conf_matrix / np.maximum(conf_matrix.sum(axis=1, keepdims=True), 1)  # row-normalize, guard /0
+
+                    # Plot heatmap
+                    sns.heatmap(conf_matrix_pct, annot=conf_matrix, fmt='d', cmap='YlOrRd',
+                                xticklabels=range(num_classes), yticklabels=range(num_classes), ax=ax)
+
+                    # Directional band boundaries derived from the number of classes
+                    negative_end = num_classes // 3
+                    positive_start = num_classes - (num_classes // 3)
+
+                    # Add lines at positions that divide the classes into thirds
+                    ax.axhline(y=negative_end, color='blue', linestyle='--', alpha=0.3)
+                    ax.axhline(y=positive_start, color='blue', linestyle='--', alpha=0.3)
+                    ax.axvline(x=negative_end, color='blue', linestyle='--', alpha=0.3)
+                    ax.axvline(x=positive_start, color='blue', linestyle='--', alpha=0.3)
+
+                    # Add dynamically positioned direction labels
+                    ax.text(-0.2, negative_end/2, 'Negative', rotation=90, verticalalignment='center')
+                    ax.text(-0.2, (positive_start + negative_end)/2, 'Neutral', rotation=90, verticalalignment='center')
+                    ax.text(-0.2, (positive_start + num_classes)/2, 'Positive', rotation=90, verticalalignment='center')
+
+                    # Calculate accuracy metrics
+                    accuracy = (filtered_predicted == filtered_actual).mean()
+
+                    # Calculate directional metrics
+                    def get_direction(x):
+                        negative_end = num_classes // 3
+                        positive_start = num_classes - (num_classes // 3)
+                        if x < negative_end: return 'negative'
+                        elif x >= positive_start: return 'positive'
+                        return 'neutral'
+
+                    pred_direction = np.vectorize(get_direction)(filtered_predicted)
+                    actual_direction = np.vectorize(get_direction)(filtered_actual)
+                    directional_accuracy = (pred_direction == actual_direction).mean()
+
+                    coverage = np.mean(confident_mask)
+                    title = f'Threshold: {threshold}\n'
+                    title += f'Coverage: {coverage:.2%}\n'
+                    title += 'Color: % of True Class\nValues: Absolute Count'
+                    ax.set_title(title)
+                    ax.set_xlabel('Predicted Class')
+                    ax.set_ylabel('True Class')
+
+                    ax.text(0, -0.2, f'Acc: {accuracy:.2f}\nDir Acc: {directional_accuracy:.2f}',
+                            transform=ax.transAxes)
+                else:
+                    ax.text(0.5, 0.5, f'No predictions meet\nthreshold {threshold}',
+                            ha='center', va='center')
+
+            plt.tight_layout()
+            plt.show()
+
+        # Plot threshold-filtered confusion matrices (the unfiltered matrix below is kept):
+        plot_threshold_confusion_matrices(
+            results,
+            predictions_proba,
+            thresholds=[0.3, 0.5, 0.6, 0.7, 0.8, 0.9],  # Configurable thresholds
+            num_classes=len(model.classes_)
+        )
+
         # Create visual confusion matrix
         conf_matrix = confusion_matrix(results['actual'], results['predicted'])
         plt.figure(figsize=(10, 8))
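
Note: a minimal smoke test for the new helper, assuming it is hoisted to module
scope (in the patch it is nested inside the method, so the import path below is
hypothetical); `proba`, the Dirichlet sampling, and the fake "actual" labels are
illustrative stand-ins for real model output, not part of the patch:

    import numpy as np
    import pandas as pd
    from ttools.models import plot_threshold_confusion_matrices  # hypothetical import

    # Synthetic 5-class probabilities: each row sums to 1, like predict_proba output.
    rng = np.random.default_rng(0)
    n, k = 500, 5
    proba = rng.dirichlet(np.ones(k), size=n)
    results = pd.DataFrame({"actual": rng.integers(0, k, size=n)})  # fake true labels

    plot_threshold_confusion_matrices(results, proba,
                                      thresholds=[0.3, 0.5, 0.8], num_classes=k)

Raising the threshold trades coverage for accuracy: fewer predictions survive the
confidence mask, but the survivors are the model's most confident calls, which the
per-panel Coverage/Acc annotations make easy to compare.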