stumpy bt update

This commit is contained in:
David Brazda
2024-09-29 20:22:20 +02:00
parent c48b11ea76
commit fa125e1c8f

View File

@ -11715,7 +11715,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Univariate and simple backtesting\n",
"# Univariate and simple backtesting (returns)\n",
"\n",
"- Load a sample time series dataset.\n",
"- Discover motifs using STUMPY's stump and motifs functions.\n",
@ -11725,7 +11725,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 37,
"metadata": {},
"outputs": [
{
@ -11733,39 +11733,34 @@
"output_type": "stream",
"text": [
"Motif DataFrame with profits:\n",
" motif_set_id motif_id match_id start_idx end_idx profit\n",
"0 0 0 0 0 20 -0.980125\n",
"1 0 0 1 1 21 -2.333143\n",
"2 0 0 2 1 21 -2.333143\n",
"3 0 0 3 2 22 -19.432151\n",
"4 0 1 0 0 20 -0.980125\n",
"5 0 1 1 1 21 -2.333143\n",
"6 0 1 2 3 23 -0.727909\n",
"7 0 1 3 3 23 -0.727909\n",
"8 0 2 0 0 20 -0.980125\n",
"9 0 2 1 2 22 -19.432151\n",
"10 0 2 2 3 23 -0.727909\n",
"11 0 2 3 3 23 -0.727909\n",
"12 1 0 0 370 390 0.206577\n",
"13 1 0 1 1653 1673 -0.967765\n",
"14 1 0 2 1540 1560 -1.380473\n",
"15 1 0 3 266 286 -1.502156\n",
"16 1 1 0 760 780 0.654485\n",
"17 1 1 1 1150 1170 -0.805488\n",
"18 1 1 2 1275 1295 -0.167817\n",
"19 1 1 3 1920 1940 NaN\n",
"20 1 2 0 1487 1507 -1.861615\n",
"21 1 2 1 1740 1760 0.009996\n",
"22 1 2 2 1585 1605 -0.992155\n",
"23 1 2 3 42 62 -3.815360\n",
" motif_id match_id distance start_idx end_idx profit\n",
"0 0 0 0.000000 370 390 0.206577\n",
"1 0 1 1.250590 1653 1673 -0.967765\n",
"2 0 2 1.982121 1540 1560 -1.380473\n",
"3 0 3 2.503828 266 286 -1.502156\n",
"4 1 0 0.000000 760 780 0.654485\n",
"5 1 1 1.261131 1150 1170 -0.805488\n",
"6 1 2 3.515795 1275 1295 -0.167817\n",
"7 1 3 3.590752 1920 1940 NaN\n",
"8 2 0 0.000000 1487 1507 -1.861615\n",
"9 2 1 2.309025 1740 1760 0.009996\n",
"10 2 2 3.323742 1585 1605 -0.992155\n",
"11 2 3 3.456676 42 62 -3.815360\n",
"\n",
"Average profit per motif category:\n",
" motif_id profit\n",
"0 0 -3.590297\n",
"1 1 -0.726844\n",
"2 2 -3.565903\n",
"0 0 -0.910954\n",
"1 1 -0.106273\n",
"2 2 -1.664783\n",
"\n",
"The most profitable motif category is: 1.0\n"
"The most profitable motif avg category is: 1.0\n",
"\\Sum profit per motif category:\n",
" motif_id profit\n",
"0 0 -3.643816\n",
"1 1 -0.318820\n",
"2 2 -6.659133\n",
"\n",
"The most profitable motif sum category is: 1.0\n"
]
}
],
@ -11774,6 +11769,10 @@
"import pandas as pd\n",
"import stumpy\n",
"\n",
"\n",
"# Convert the 'close' column to a 1D NumPy array\n",
"T = timeseries_sr.values\n",
"\n",
"# Parameters\n",
"N = 20 # Window length (motif length)\n",
"max_motifs = 3 # Number of main motifs to discover\n",
@ -11783,18 +11782,18 @@
"# Step 1: Calculate the Matrix Profile\n",
"matrix_profile = stumpy.stump(timeseries_sr, m=N)\n",
"\n",
"# Step 2: Discover the motifs\n",
"motif_indices = stumpy.motifs(timeseries_sr, matrix_profile[:, 0], max_motifs=max_motifs, max_matches=motifs_per_category)\n",
"# Step 2: Discover the motifs return distances and indices in the shape of (max_motifs, motifs_per_category)\n",
"motif_distances, motif_indices = stumpy.motifs(timeseries_sr, matrix_profile[:, 0], max_motifs=max_motifs, max_matches=motifs_per_category)\n",
"\n",
"# Create a DataFrame to store the motifs information\n",
"motif_results = []\n",
"\n",
"# Since motif_indices is a list containing arrays of shape (3, 4), iterate accordingly\n",
"for motif_set_id, motif_array in enumerate(motif_indices): # Iterate over each motif set (array1 and array2)\n",
" for motif_id in range(motif_array.shape[0]): # Iterate over each row, i.e., each motif\n",
" for match_id in range(motif_array.shape[1]): # Iterate over each column, i.e., each match\n",
" start_idx = motif_array[motif_id, match_id]\n",
" \n",
"# Since motif_indices is an array of shape (3, 4), iterate accordingly\n",
"for motif_id in range(motif_indices.shape[0]): # Iterate over each row, i.e., each motif\n",
" for match_id in range(motif_indices.shape[1]): # Iterate over each column, i.e., each match\n",
" start_idx = motif_indices[motif_id, match_id]\n",
" distance = motif_distances[motif_id, match_id]\n",
"\n",
" # Check if the entry is valid (not NaN)\n",
" if np.isnan(start_idx):\n",
" continue\n",
@ -11812,9 +11811,9 @@
" \n",
" # Append the motif result\n",
" motif_results.append({\n",
" 'motif_set_id': motif_set_id, # Identifies which motif set this belongs to\n",
" 'motif_id': motif_id, # Identifies the motif category within the set\n",
" 'match_id': match_id, # Rank/order of the match within the motif category\n",
" 'distance': distance,\n",
" 'start_idx': start_idx,\n",
" 'end_idx': end_idx,\n",
" 'profit': future_return\n",
@ -11827,10 +11826,15 @@
"motif_profits = motif_df.groupby('motif_id')['profit'].mean().reset_index()\n",
"most_profitable_motif_id = motif_profits.sort_values(by='profit', ascending=False).iloc[0]['motif_id']\n",
"\n",
"motif_profits_sum = motif_df.groupby('motif_id')['profit'].sum().reset_index()\n",
"most_profitable_motif_sum_id = motif_profits_sum.sort_values(by='profit', ascending=False).iloc[0]['motif_id']\n",
"\n",
"# Display results\n",
"print(\"Motif DataFrame with profits:\\n\", motif_df)\n",
"print(\"\\nAverage profit per motif category:\\n\", motif_profits)\n",
"print(f\"\\nThe most profitable motif category is: {most_profitable_motif_id}\")\n"
"print(f\"\\nThe most profitable motif avg category is: {most_profitable_motif_id}\")\n",
"print(\"\\Sum profit per motif category:\\n\", motif_profits_sum)\n",
"print(f\"\\nThe most profitable motif sum category is: {most_profitable_motif_sum_id}\")"
]
}
],