# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# # Cross-validation

# %%
from vectorbtpro import *

# Pull BTCUSDT market data from Binance, ending at 2022-11-01 UTC.
data = vbt.BinanceData.pull("BTCUSDT", end="2022-11-01 UTC")
# Display the index to inspect the date range that was fetched.
data.index

# %%
@vbt.parameterized(merge_func="concat")
def sma_crossover_perf(data, fast_window, slow_window):
    """Backtest an SMA-crossover strategy and return its Sharpe ratio.

    Entry signals fire when the fast SMA crosses above the slow SMA; exit
    signals fire when it crosses below. Both are fed to
    ``vbt.Portfolio.from_signals`` with ``direction="both"``.

    The ``vbt.parameterized`` decorator (``merge_func="concat"``) lets callers
    pass ``vbt.Param`` objects for the window arguments; the per-combination
    results are then concatenated into one object.

    Args:
        data: vectorbtpro data object used for the indicators and the backtest.
        fast_window: Window length of the fast SMA.
        slow_window: Window length of the slow SMA.

    Returns:
        The portfolio's ``sharpe_ratio``.
    """
    fast_ma = data.run("sma", fast_window, short_name="fast_sma")
    slow_ma = data.run("sma", slow_window, short_name="slow_sma")

    entry_signals = fast_ma.real_crossed_above(slow_ma)
    exit_signals = fast_ma.real_crossed_below(slow_ma)

    portfolio = vbt.Portfolio.from_signals(
        data,
        entry_signals,
        exit_signals,
        direction="both",
    )
    return portfolio.sharpe_ratio

# %%
# Grid-search both SMA windows over [5, 50) on the 2020 data; the `condition`
# on the fast-window Param restricts the grid to combinations with fast < slow.
perf = sma_crossover_perf(
    data["2020":"2020"],
    vbt.Param(np.arange(5, 50), condition="x < slow_window"),
    vbt.Param(np.arange(5, 50)),
    _execute_kwargs={
        "show_progress": True,
        "clear_cache": 50,
        "collect_garbage": 50,
    },
)
perf

# %%
# Rank the parameter combinations by 2020 Sharpe ratio, best first.
perf.sort_values(ascending=False)

# %%
# Take the (fast, slow) pair with the highest 2020 Sharpe and evaluate it on
# the following year (2021), which it was not selected on.
best_fast_window, best_slow_window = perf.idxmax()
sma_crossover_perf(
    data["2021":"2021"],
    fast_window=best_fast_window,
    slow_window=best_slow_window,
)

# %%
# Baseline for comparison: Sharpe ratio of the "from_holding" run on 2021 data.
data["2021":"2021"].run("from_holding").sharpe_ratio

# %%
# Rolling (walk-forward) cross-validation. Each split takes one `period`-long
# window as the in-sample (IS) period and grid-searches the SMA windows on it.
# It then tests the best combination on the immediately following window, the
# out-of-sample (OOS) period. Splits advance by one period, so the OOS window
# of split i has the same bounds as the IS window of split i + 1.
start_index = data.index[0]
period = pd.Timedelta(days=180)
all_is_bounds = {}  # period index -> (start, end) of the IS window
all_is_bl_perf = {}  # (split, period) -> baseline Sharpe on the IS window
all_is_perf = {}  # (split, period) -> Sharpe of every grid combination on IS
all_oos_bounds = {}  # period index -> (start, end) of the OOS window
all_oos_bl_perf = {}  # (split, period) -> baseline Sharpe on the OOS window
all_oos_perf = {}  # (split, period) -> Sharpe of the best IS combination on OOS
split_idx = 0
# Advances in lockstep with split_idx; a split's OOS period is period_idx + 1.
period_idx = 0

with vbt.get_pbar() as pbar:
    # Keep going while a full IS + OOS pair (two periods) fits up to the last timestamp.
    while start_index + 2 * period <= data.index[-1]:
        pbar.set_description(str(start_index))

        is_start_index = start_index
        # Subtract 1 ns so the window stops just before the next one starts
        # (assumes Data slicing is end-inclusive like pandas label slicing — confirm).
        is_end_index = start_index + period - pd.Timedelta(1)
        is_data = data[is_start_index : is_end_index]
        # Baseline: Sharpe of the "from_holding" run on the IS window.
        is_bl_perf = is_data.run("from_holding").sharpe_ratio
        # Full grid over fast/slow windows in [5, 50), limited to fast < slow via
        # `condition`. The cache clearing / GC every 50 runs is presumably there
        # to bound memory — confirm against the vbt execute docs.
        is_perf = sma_crossover_perf(
            is_data,
            vbt.Param(np.arange(5, 50), condition="x < slow_window"),
            vbt.Param(np.arange(5, 50)),
            _execute_kwargs=dict(
                clear_cache=50,
                collect_garbage=50
            )
        )

        # The OOS window is the next `period`, bounded the same way as IS.
        oos_start_index = start_index + period
        oos_end_index = start_index + 2 * period - pd.Timedelta(1)
        oos_data = data[oos_start_index : oos_end_index]
        oos_bl_perf = oos_data.run("from_holding").sharpe_ratio
        # (fast, slow) labels of the highest IS Sharpe.
        best_fw, best_sw = is_perf.idxmax()
        oos_perf = sma_crossover_perf(oos_data, best_fw, best_sw)
        # Wrap the single OOS result in a one-element Series that reuses the
        # matching index entry of is_perf. IS and OOS results then share the
        # same parameter levels and can be matched after concatenation.
        oos_perf_index = is_perf.index[is_perf.index == (best_fw, best_sw)]
        oos_perf = pd.Series([oos_perf], index=oos_perf_index)

        # Bounds are keyed by period only; results by (split, period).
        all_is_bounds[period_idx] = (is_start_index, is_end_index)
        all_oos_bounds[period_idx + 1] = (oos_start_index, oos_end_index)
        all_is_bl_perf[(split_idx, period_idx)] = is_bl_perf
        all_oos_bl_perf[(split_idx, period_idx + 1)] = oos_bl_perf
        all_is_perf[(split_idx, period_idx)] = is_perf
        all_oos_perf[(split_idx, period_idx + 1)] = oos_perf
        # Roll forward by one period: this OOS window becomes the next IS window.
        start_index = start_index + period
        split_idx += 1
        period_idx += 1
        pbar.update(1)

# %%
# One row of (start, end) bounds per period. Periods that are both an OOS
# window and the next split's IS window have identical bounds, so the
# duplicate rows are dropped after stacking the IS and OOS tables.
is_period_ranges = pd.DataFrame.from_dict(
    all_is_bounds, orient="index", columns=["start", "end"]
).rename_axis("period")
oos_period_ranges = pd.DataFrame.from_dict(
    all_oos_bounds, orient="index", columns=["start", "end"]
).rename_axis("period")
period_ranges = pd.concat([is_period_ranges, oos_period_ranges]).drop_duplicates()
period_ranges

# %%
# Baseline ("from_holding") Sharpe per period. Both dicts are keyed by
# (split, period). The OOS period of split i is the IS period of split i + 1,
# so after keeping only the "period" level most periods appear twice.
is_bl_perf = pd.Series(all_is_bl_perf)
is_bl_perf.index.names = ["split", "period"]
oos_bl_perf = pd.Series(all_oos_bl_perf)
oos_bl_perf.index.names = ["split", "period"]
bl_perf = pd.concat((
    is_bl_perf.vbt.select_levels("period"),
    oos_bl_perf.vbt.select_levels("period")
))
# De-duplicate by period label, not by value. `drop_duplicates()` compares the
# Sharpe values themselves, so two different periods with an equal Sharpe
# (including two NaNs, which pandas treats as equal) would be merged into one.
bl_perf = bl_perf[~bl_perf.index.duplicated(keep="first")]
bl_perf

# %%
# Stack the per-split IS grid results. The (split, period) dict keys become
# the outer "split" and "period" index levels in front of each grid's own
# index levels.
is_perf = pd.concat(all_is_perf, names=["split", "period"])
is_perf

# %%
# Stack the OOS results: one entry per split, holding the Sharpe of that
# split's best IS combination on its OOS period.
oos_perf = pd.concat(all_oos_perf, names=["split", "period"])
oos_perf

# %%
# Select the IS rows whose (split, parameter combination) is the one chosen for
# that split's OOS test. The "period" level is dropped on both sides because a
# split's IS and OOS entries carry different period labels.
is_best_mask = is_perf.index.vbt.drop_levels("period").isin(
    oos_perf.index.vbt.drop_levels("period"))
is_best_perf = is_perf[is_best_mask]
is_best_perf

# %%
# Side-by-side summary statistics: the full IS grid, the IS rows of the chosen
# combinations, the IS baseline, the OOS test results, and the OOS baseline.
pd.concat(
    [
        is_perf.describe(),
        is_best_perf.describe(),
        is_bl_perf.describe(),
        oos_perf.describe(),
        oos_bl_perf.describe(),
    ],
    keys=[
        "IS",
        "IS (Best)",
        "IS (Baseline)",
        "OOS (Test)",
        "OOS (Baseline)",
    ],
    axis=1,
)

# %%
# Per-period box plot of the IS Sharpe distribution across all combinations,
# overlaid with the chosen combinations' IS Sharpe, the baseline Sharpe, and
# the OOS test Sharpe (all indexed by period).
fig = is_perf.vbt.boxplot(
    by_level="period",
    trace_kwargs={
        "line": {"color": "lightskyblue"},
        "opacity": 0.4,
        "showlegend": False,
    },
    xaxis_title="Period",
    yaxis_title="Sharpe",
)
is_best_perf.vbt.select_levels("period").vbt.plot(
    trace_kwargs={
        "name": "Best",
        "line": {"color": "limegreen", "dash": "dash"},
    },
    fig=fig,
)
bl_perf.vbt.plot(
    trace_kwargs={
        "name": "Baseline",
        "line": {"color": "orange", "dash": "dash"},
    },
    fig=fig,
)
oos_perf.vbt.select_levels("period").vbt.plot(
    trace_kwargs={
        "name": "Test",
        "line": {"color": "orangered"},
    },
    fig=fig,
)
fig.show()

# %%
# IS grid results of split 6 (whose IS window is period 6) and their summary.
is_perf_split6 = is_perf.xs(6, level="split")
is_perf_split6.describe()

# %%
# Despite the "first_" names, these are the bounds of period 6 (the IS window
# of split 6); plot the data within that window.
first_left_bound = period_ranges.loc[6, "start"]
first_right_bound = period_ranges.loc[6, "end"]
data[first_left_bound : first_right_bound].plot().show()

# %%
# OOS test result obtained in period 6. It comes from split 5, whose best
# combination on period 5 (IS) was tested on period 6.
oos_perf.xs(6, level="period")

# %%
# 25th percentile of the period-6 IS Sharpe distribution, for comparison with
# the OOS test result obtained in period 6.
is_perf_split6.quantile(0.25)

# %%