# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Applications
# ## Taking
# ### Without stacking

# %%
# Take the part of the close price that each range of `splitter` selects.
# NOTE: `splitter` and `data` are not defined in this file; they must already exist.
close_slices = splitter.take(data.close)
close_slices

# %%
# Look up one slice by its label pair (presumably split label, then set label).
close_slices[2020, "test"]

# %%
def get_total_return(sr):
    """Return the total return of the price object `sr`.

    The prices are converted to returns and then aggregated, both through the
    `vbt` accessor.
    """
    returns = sr.vbt.to_returns()
    return returns.vbt.returns.total()

# Compute the total return of every slice separately.
close_slices.apply(get_total_return)

# %% [markdown]
# #### Complex objects

# %%
# Run the "trendlb" indicator on the data with arguments 1.0 and 0.5 and plot it.
trendlb = data.run("trendlb", 1.0, 0.5)
trendlb.plot().show()

# %%
# Map label 1 to "U" and label 0 to "D", then build a splitter that groups
# the index by this trend label.
grouper = pd.Index(trendlb.labels.map({1: "U", 0: "D"}), name="trend")
trend_splitter = vbt.Splitter.from_grouper(data.index, grouper)
trend_splitter.plot().show()

# %%
# Baseline: a holding portfolio and its returns accessor.
hold_pf = vbt.Portfolio.from_holding(data)
hold_returns_acc = hold_pf.returns_acc

# Strategy: SMA crossover over combinations of windows from 5 to 49,
# simulated with direction="both".
fast_sma, slow_sma = vbt.talib("SMA").run_combs(
    data.close, np.arange(5, 50), short_names=["fast_sma", "slow_sma"])
entries = fast_sma.real_crossed_above(slow_sma)
exits = fast_sma.real_crossed_below(slow_sma)
strat_pf = vbt.Portfolio.from_signals(
    data, entries, exits, direction="both")
strat_returns_acc = strat_pf.returns_acc

# %%
# Slice both returns accessors by trend label.
hold_returns_acc_slices = trend_splitter.take(hold_returns_acc)
strat_returns_acc_slices = trend_splitter.take(strat_returns_acc)

# %%
# Sharpe ratio of holding during the "U" periods.
hold_returns_acc_slices["U"].sharpe_ratio()

# %%
# Sharpe ratio of every window combination during the "U" periods, as a heatmap.
strat_returns_acc_slices["U"].sharpe_ratio().vbt.heatmap(
    x_level="fast_sma_timeperiod",
    y_level="slow_sma_timeperiod",
    symmetric=True
).show()

# %%
# Same two views for the "D" periods.
hold_returns_acc_slices["D"].sharpe_ratio()

# %%
strat_returns_acc_slices["D"].sharpe_ratio().vbt.heatmap(
    x_level="fast_sma_timeperiod",
    y_level="slow_sma_timeperiod",
    symmetric=True
).show()

# %%
# Break the splits up "by_gap" and sort them; presumably this yields one
# split per contiguous trend period — TODO confirm against the Splitter docs.
trend_splitter = trend_splitter.break_up_splits("by_gap", sort=True)
trend_splitter.plot().show()

# %%
# Split the strategy portfolio by the trend ranges. The result is bound to
# `strat_pf_slices`, which is the name the next cell reads.
strat_pf_slices = strat_pf.split(trend_splitter)
strat_pf_slices

# %%
# Sharpe ratio of every portfolio slice, then the median along axis 1 per range.
trend_range_perf = strat_pf_slices.apply(lambda pf: pf.sharpe_ratio)
median_trend_perf = trend_range_perf.median(axis=1)
median_trend_perf

# %%
# Write each range's median performance into a series created from the symbol
# wrapper, using the range's start/end bounds, and overlay it on the close price.
trend_perf_ts = data.symbol_wrapper.fill().rename("trend_perf")
for label, sr in trend_splitter.bounds.iterrows():
    trend_perf_ts.iloc[sr["start"]:sr["end"]] = median_trend_perf[label]
data.close.vbt.overlay_with_heatmap(trend_perf_ts).show()

# %% [markdown]
# ### Column stacking

# %%
# Manually stack the slices as columns, keyed by the index of `close_slices`.
close_stacked = pd.concat(
    close_slices.values.tolist(),
    axis=1,
    keys=close_slices.index
)
close_stacked

# %%
# The helper is applied to the stacked object in a single call.
get_total_return(close_stacked)

# %%
# Let the splitter do the stacking itself.
close_stacked = splitter.take(data.close, into="stacked")
close_stacked.shape

# %%
# Further `into` variants. Judging by the option names, the index is reset,
# aligned from the end, or stacked per set — confirm in the Splitter docs.
close_stacked = splitter.take(data.close, into="reset_stacked")
close_stacked

# %%
close_stacked = splitter.take(data.close, into="from_end_stacked")
close_stacked

# %%
close_stacked = splitter.take(data.close, into="reset_stacked_by_set")
close_stacked

# %%
# With "reset_stacked_by_set" the result is indexed by set label.
close_stacked["train"]

# %%
# The splitter can also slice the index itself.
index_slices = splitter.take(data.index)
index_slices

# %%
# Same stacking as above, additionally attaching the range bounds
# (attach_bounds="index", right bound inclusive).
close_stacked_wb = splitter.take(
    data.close,
    into="reset_stacked_by_set",
    attach_bounds="index",
    right_inclusive=True
)
close_stacked_wb["train"]

# %%
@vbt.parameterized(merge_func="concat")
def set_sma_crossover_perf(close, fast_window, slow_window, freq):
    """Return the Sharpe ratio of an SMA crossover strategy simulated on `close`.

    Entries are where the fast SMA crosses above the slow SMA, exits where it
    crosses below; the portfolio is simulated with `direction="both"` and the
    given `freq`.
    """
    sma = vbt.talib("sma")
    fast = sma.run(close, fast_window, short_name="fast_sma", hide_params=True)
    slow = sma.run(close, slow_window, short_name="slow_sma", hide_params=True)
    crossed_up = fast.real_crossed_above(slow)
    crossed_down = fast.real_crossed_below(slow)
    portfolio = vbt.Portfolio.from_signals(
        close,
        crossed_up,
        crossed_down,
        freq=freq,
        direction="both",
    )
    return portfolio.sharpe_ratio

# %%
# Evaluate the window grid on the stacked "train" columns. The condition on
# the first parameter keeps only combinations with fast_window < slow_window.
train_perf = set_sma_crossover_perf(
    close_stacked["train"],
    vbt.Param(np.arange(5, 50), condition="x < slow_window"),
    vbt.Param(np.arange(5, 50)),
    data.index.freq,
    _execute_kwargs=dict(
        show_progress=True,
        clear_cache=50,
        collect_garbage=50
    )
)

# %%
train_perf

# %%
# Heatmap of fast vs. slow window, with a slider over the "split_year" level.
train_perf.vbt.heatmap(
    x_level="fast_window",
    y_level="slow_window",
    slider_level="split_year",
    symmetric=True
).show()

# %%
@njit
def prox_median_nb(arr):
    """Return the NaN-ignoring median of `arr`, or NaN if it holds fewer than 20 non-NaN values."""
    valid_mask = ~np.isnan(arr)
    if valid_mask.sum() < 20:
        return np.nan
    return np.nanmedian(arr)

# For each split: unstack the performance into a frame, apply the median
# function with a proximity window of 2 (presumably over neighbouring cells
# of the parameter grid — TODO confirm), stack back and align to the
# original index.
prox_perf_list = []
for split_label, perf_sr in train_perf.groupby("split_year"):
    perf_df = perf_sr.vbt.unstack_to_df(0, [1, 2])
    prox_perf_df = perf_df.vbt.proximity_apply(2, prox_median_nb)
    prox_perf_sr = prox_perf_df.stack([0, 1])
    prox_perf_list.append(prox_perf_sr.reindex(perf_sr.index))

train_prox_perf = pd.concat(prox_perf_list)
train_prox_perf

# %%
train_prox_perf.vbt.heatmap(
    x_level="fast_window",
    y_level="slow_window",
    slider_level="split_year",
    symmetric=True
).show()

# %%
# Pick the index entry with the highest smoothed performance in every split.
best_params = train_prox_perf.groupby("split_year").idxmax()
best_params = train_prox_perf[best_params].index
train_prox_perf[best_params]

# %%
# Run each selected parameter pair on the "test" close. The template selects a
# single column by `config_idx`, and both parameters share level 0.
test_perf = set_sma_crossover_perf(
    vbt.RepEval(
        "test_close.iloc[:, [config_idx]]",
        context=dict(test_close=close_stacked["test"])
    ),
    vbt.Param(best_params.get_level_values("fast_window"), level=0),
    vbt.Param(best_params.get_level_values("slow_window"), level=0),
    data.index.freq
)
test_perf

# %%
def get_index_sharpe(index):
    """Return the Sharpe ratio of the "from_holding" run on the rows of the global `data` selected by `index`."""
    selected = data.loc[index]
    holding = selected.run("from_holding")
    return holding.sharpe_ratio

# Holding benchmark for every "test" index slice.
index_slices.xs("test", level="set").apply(get_index_sharpe)

# %% [markdown]
# ### Row stacking

# %%
# Block length derived from the number of rows: 3.15 * n^(1/3), truncated.
block_size = int(3.15 * len(data.index) ** (1 / 3))
block_size

# %%
# Rolling windows of `block_size` rows, each offset by 1 from the previous start.
block_splitter = vbt.Splitter.from_rolling(
    data.index,
    length=block_size,
    offset=1,
    offset_anchor="prev_start"
)
block_splitter.n_splits

# %%
# Draw `size` blocks with replacement.
size = int(block_splitter.n_splits / block_size)
sample_splitter = block_splitter.shuffle_splits(size=size, replace=True)
sample_splitter.plot().show()

# %%
# Stack the sampled return blocks along axis 0 (rows).
returns = data.returns
sample_rets = sample_splitter.take(
    returns,
    into="stacked",
    stack_axis=0
)
sample_rets

# %%
# Give the resampled returns the first len(sample_rets) timestamps of the data.
sample_rets.index = data.index[:len(sample_rets)]
# Rebuild a price path from the first close. `.iloc[0]` selects it by position
# explicitly, because integer `[]` lookup on a label-indexed Series is
# deprecated in pandas.
sample_cumrets = data.close.iloc[0] * (sample_rets + 1).cumprod()
sample_cumrets.vbt.plot().show()

# %%
# Repeat the resampling 1000 times and collect each sample as one column.
samples_rets_list = []
for i in tqdm(range(1000)):
    sample_spl = block_splitter.shuffle_splits(size=size, replace=True)
    sample_rets = sample_spl.take(returns, into="stacked", stack_axis=0)
    sample_rets.index = returns.index[:len(sample_rets)]
    sample_rets.name = i
    samples_rets_list.append(sample_rets)
sample_rets_stacked = pd.concat(samples_rets_list, axis=1)

# %%
# Distribution of the Sharpe ratio across the samples.
sample_sharpe = sample_rets_stacked.vbt.returns.sharpe_ratio()
sample_sharpe.vbt.boxplot(horizontal=True).show()

# %%
# 2.5% and 97.5% quantiles of the sampled Sharpe ratios.
sample_sharpe.quantile(0.025), sample_sharpe.quantile(0.975)

# %% [markdown]
# ## Applying

# %%
# Apply the function per range; the close price is wrapped in `vbt.Takeable`
# so the splitter slices it before the call.
splitter.apply(
    get_total_return,
    vbt.Takeable(data.close),
    merge_func="concat"
)

# %%
# Same idea with the slicing done by hand in a template function that
# receives `range_`.
splitter.apply(
    get_total_return,
    vbt.RepFunc(lambda range_: data.close[range_]),
    merge_func="concat"
)

# %%
def get_total_return(range_, data):
    """Return the total return of `data.returns` restricted to `range_`."""
    range_returns = data.returns[range_]
    return range_returns.vbt.returns.total()

# Pass the range itself through the `range_` template and hand over the whole
# data object; the function does the selection.
splitter.apply(
    get_total_return,
    vbt.Rep("range_"),
    data,
    merge_func="concat"
)

# %%
def get_total_return(data):
    """Return the total return computed from `data.returns`."""
    returns = data.returns
    return returns.vbt.returns.total()

# The whole data object can be marked as takeable as well.
splitter.apply(
    get_total_return,
    vbt.Takeable(data),
    merge_func="concat"
)

# %%
# Same call with `set_group_by=True`.
splitter.apply(
    get_total_return,
    vbt.Takeable(data),
    set_group_by=True,
    merge_func="concat"
)

# %%
# Restrict the run to the splits 2020 and 2021 and to the "train" set.
splitter.apply(
    get_total_return,
    vbt.Takeable(data),
    split=[2020, 2021],
    set_="train",
    merge_func="concat"
)

# %%
# Grid search on the "train" set. NOTE: `sma_crossover_perf` is not defined in
# this file. `_execute_kwargs` (progress hidden) and `execute_kwargs`
# (progress shown) are separate arguments — presumably for the parameter loop
# and the split loop respectively; TODO confirm.
train_perf = splitter.apply(
    sma_crossover_perf,
    vbt.Takeable(data),
    vbt.Param(np.arange(5, 50), condition="x < slow_window"),
    vbt.Param(np.arange(5, 50)),
    _execute_kwargs=dict(
        show_progress=False,
        clear_cache=50,
        collect_garbage=50
    ),
    set_="train",
    merge_func="concat",
    execute_kwargs=dict(show_progress=True)
)

# %%
train_perf

# %%
# Best-performing index entry per split.
best_params = train_perf.groupby("split_year").idxmax()
best_params = train_perf[best_params].index
train_perf[best_params]

# %%
best_fast_windows = best_params.get_level_values("fast_window")
best_slow_windows = best_params.get_level_values("slow_window")

# On the "test" set, look up each split's best windows by `split_idx`.
test_perf = splitter.apply(
    sma_crossover_perf,
    vbt.Takeable(data),
    vbt.RepFunc(lambda split_idx: best_fast_windows[split_idx]),
    vbt.RepFunc(lambda split_idx: best_slow_windows[split_idx]),
    set_="test",
    merge_func="concat"
)
test_perf

# %% [markdown]
# ### Iteration schemes

# %%
def cv_sma_crossover(
    data,
    fast_windows,
    slow_windows,
    split_idx,
    set_idx,
    train_perf_list
):
    """Run one cross-validation step of the SMA crossover for a single range.

    For the first set (`set_idx == 0`) the whole window grid is evaluated, the
    grid result is appended to `train_perf_list`, and only its best entry is
    returned. For any other set, the grid stored at `split_idx` is looked up
    and its best window pair is evaluated on `data`.
    """
    if set_idx != 0:
        # Later set: reuse the grid that was stored for this split.
        stored_perf = train_perf_list[split_idx]
        best = stored_perf.idxmax()
        return sma_crossover_perf(
            data,
            vbt.Param([best[0]]),
            vbt.Param([best[1]]),
        )

    # First set: evaluate every fast/slow pair with fast < slow.
    grid_perf = sma_crossover_perf(
        data,
        vbt.Param(fast_windows, condition="x < slow_window"),
        vbt.Param(slow_windows),
        _execute_kwargs=dict(
            show_progress=False,
            clear_cache=50,
            collect_garbage=50
        )
    )
    train_perf_list.append(grid_perf)
    best = grid_perf.idxmax()
    return grid_perf[[best]]

# `train_perf_list` is filled by `cv_sma_crossover` while the first set of each
# split runs, and read again for the other set. iteration="set_major" is
# presumably what makes all first-set calls run before the rest — TODO confirm.
train_perf_list = []
cv_perf = splitter.apply(
    cv_sma_crossover,
    vbt.Takeable(data),
    np.arange(5, 50),
    np.arange(5, 50),
    vbt.Rep("split_idx"),
    vbt.Rep("set_idx"),
    train_perf_list,
    iteration="set_major",
    merge_func="concat",
    execute_kwargs=dict(show_progress=True)
)

# %%
# Combine the stored grids, keyed by the splitter's split labels.
train_perf = pd.concat(train_perf_list, keys=splitter.split_labels)
train_perf

# %%
cv_perf

# %% [markdown]
# ### Merging

# %%
def get_entries_and_exits(data, fast_window, slow_window):
    """Return `(entries, exits)` of an SMA crossover on `data`.

    Entries are where the fast SMA crosses above the slow SMA; exits are where
    it crosses below.
    """
    fast = data.run("sma", fast_window, short_name="fast_sma")
    slow = data.run("sma", slow_window, short_name="slow_sma")
    return fast.real_crossed_above(slow), fast.real_crossed_below(slow)

# The function returns two objects; each is merged on its own, here by
# stacking columns.
entries, exits = splitter.apply(
    get_entries_and_exits,
    vbt.Takeable(data),
    20,
    30,
    merge_func="column_stack"
)

entries

# %%
# Row stacking with merge_all=False; the result is then indexed by split label.
entries, exits = splitter.apply(
    get_entries_and_exits,
    vbt.Takeable(data),
    50,
    200,
    merge_all=False,
    merge_func="row_stack"
)

entries.loc[2018]

# %%
def get_signal_count(*args, **kwargs):
    """Return the total number of entry and exit signals as a pair.

    All arguments are forwarded unchanged to `get_entries_and_exits`.
    """
    signals = get_entries_and_exits(*args, **kwargs)
    entries, exits = signals
    entry_total = entries.vbt.signals.total()
    exit_total = exits.vbt.signals.total()
    return entry_total, exit_total

# Concatenate the per-range counts and attach the range bounds
# (attach_bounds="index").
entry_count, exit_count = splitter.apply(
    get_signal_count,
    vbt.Takeable(data),
    20,
    30,
    merge_func="concat",
    attach_bounds="index"
)
entry_count

# %%
def plot_entries_and_exits(results, data, keys):
    """Plot the entry and exit signals of every range on top of `data`.

    Args:
        results: Sequence of `(entries, exits)` pairs, one per range, in the
            same order as `keys`.
        data: Object providing `plot(plot_volume=...)` and `close`; the
            signals are drawn against its close price.
        keys: Index with a "set" level that labels each item of `results`.

    Returns:
        The figure returned by `data.plot`, with all signal traces added.
    """
    # Marker colors as (entries, exits). Every label other than "train" shares
    # the second palette and a single legend flag.
    palette = {
        "train": ("limegreen", "orange"),
        "other": ("skyblue", "magenta"),
    }
    legend_shown = {"train": False, "other": False}

    set_labels = keys.get_level_values("set")
    fig = data.plot(plot_volume=False)

    for (entries, exits), set_label in zip(results, set_labels):
        group = "train" if set_label == "train" else "other"
        entry_color, exit_color = palette[group]
        # Only the first range of each group gets legend entries.
        show_legend = not legend_shown[group]

        # Each trace joins the legend group named after itself, so entries and
        # exits can be toggled independently for every set.
        entries_name = f"Entries ({set_label})"
        exits_name = f"Exits ({set_label})"

        entries.vbt.signals.plot_as_entries(
            data.close,
            trace_kwargs=dict(
                marker=dict(color=entry_color),
                name=entries_name,
                legendgroup=entries_name,
                showlegend=show_legend
            ),
            fig=fig
        )
        exits.vbt.signals.plot_as_exits(
            data.close,
            trace_kwargs=dict(
                marker=dict(color=exit_color),
                name=exits_name,
                legendgroup=exits_name,
                showlegend=show_legend
            ),
            fig=fig
        )
        legend_shown[group] = True
    return fig

# Use the plotting function as the merge function; `keys` is supplied through
# a template and `data` is passed as a regular merge keyword.
splitter.apply(
    get_entries_and_exits,
    vbt.Takeable(data),
    20,
    30,
    merge_func=plot_entries_and_exits,
    merge_kwargs=dict(data=data, keys=vbt.Rep("keys")),
).show()

# %% [markdown]
# ### Decorators

# %%
@vbt.split(splitter=splitter)
def get_split_total_return(data):
    """Return the total return computed from `data.returns`; decorated to run through `splitter`."""
    returns = data.returns
    return returns.vbt.returns.total()

# Call the decorated function; the data is marked as takeable at call time.
get_split_total_return(vbt.Takeable(data))

# %%
def get_total_return(data):
    """Return the total return computed from `data.returns`."""
    returns = data.returns
    return returns.vbt.returns.total()

# The decorator can also be applied as a plain function call, leaving
# `get_total_return` itself undecorated.
get_split_total_return = vbt.split(
    get_total_return,
    splitter=splitter
)
get_split_total_return(vbt.Takeable(data))

# %%
@vbt.split
def get_split_total_return(data):
    """Return the total return computed from `data.returns`; the splitter is supplied at call time."""
    returns = data.returns
    return returns.vbt.returns.total()

# Decorator options can be given per call with a leading underscore.
get_split_total_return(vbt.Takeable(data), _splitter=splitter)

# %%
# Build the splitter on the fly from a factory method name and its keyword
# arguments, here on the rows selected by .loc["2020":"2020"].
get_split_total_return(
    vbt.Takeable(data.loc["2020":"2020"]),
    _splitter="from_rolling",
    _splitter_kwargs=dict(length="30d")
)

# %%
# Group the index by month; `takeable_args` names the argument to slice, so
# the call below passes the data without wrapping it.
get_total_return_by_month = vbt.split(
    get_total_return,
    splitter="from_grouper",
    splitter_kwargs=dict(by=vbt.RepEval("index.to_period('M')")),
    takeable_args=["data"]
)

get_total_return_by_month(data.loc["2020":"2020"])

# %%
# Single split with split=0.6 and the set labels "train" and "test".
# NOTE: `sma_crossover_perf` is not defined in this file.
cv_sma_crossover_perf = vbt.split(
    sma_crossover_perf,
    splitter="from_single",
    splitter_kwargs=dict(split=0.6, set_labels=["train", "test"]),
    takeable_args=["data"],
    merge_func="concat",
    execute_kwargs=dict(show_progress=True)
)

# %%
# Grid search on the "train" set. `p_execute_kwargs` is renamed to
# `_execute_kwargs` when forwarded, as declared in `_forward_kwargs_as`.
train_perf = cv_sma_crossover_perf(
    data.loc["2020":"2021"],
    vbt.Param(np.arange(5, 50), condition="x < slow_window"),
    vbt.Param(np.arange(5, 50)),
    p_execute_kwargs=dict(
        clear_cache=50,
        collect_garbage=50
    ),
    _forward_kwargs_as={
        "p_execute_kwargs": "_execute_kwargs"
    },
    _apply_kwargs=dict(set_="train")
)

# %%
train_perf

# %%
# Evaluate the best entry of the train grid on the "test" set; elements 0 and 1
# of the best index entry are passed as the two window arguments.
test_perf = cv_sma_crossover_perf(
    data.loc["2020":"2021"],
    train_perf.idxmax()[0],
    train_perf.idxmax()[1],
    _apply_kwargs=dict(set_="test")
)

# %%
test_perf

# %%
@njit(nogil=True)
def sma_crossover_perf_nb(close, fast_window, slow_window, ann_factor):
    """Return the Sharpe ratio of an SMA crossover simulated on `close`.

    Long entries are where the fast average crosses above the slow one, short
    entries where the slow average crosses above the fast one. `close` is
    indexed with `close.shape[1]`, so it must be two-dimensional.
    """
    fast_ma = vbt.nb.ma_nb(close, fast_window)
    slow_ma = vbt.nb.ma_nb(close, slow_window)
    long_signals = vbt.nb.crossed_above_nb(fast_ma, slow_ma)
    short_signals = vbt.nb.crossed_above_nb(slow_ma, fast_ma)
    # One group per column.
    one_per_column = np.full(close.shape[1], 1)
    sim_out = vbt.pf_nb.from_signals_nb(
        target_shape=close.shape,
        group_lens=one_per_column,
        close=close,
        long_entries=long_signals,
        short_entries=short_signals,
        save_returns=True
    )
    sim_returns = sim_out.in_outputs.returns
    return vbt.ret_nb.sharpe_ratio_nb(sim_returns, ann_factor)

# %%
# Direct call on the close price converted to a 2-dim array, with windows
# 20/30 and an annualization factor of 365.
sma_crossover_perf_nb(vbt.to_2d_array(data.close), 20, 30, 365)

# %%
# Wrap the compiled function with `vbt.cv_split`: rolling windows of length
# 360, split=0.5, sets labelled "train" and "test". The parameter combinations
# are configured to run on the "dask" engine.
cv_sma_crossover_perf = vbt.cv_split(
    sma_crossover_perf_nb,
    splitter="from_rolling",
    splitter_kwargs=dict(
        length=360,
        split=0.5,
        set_labels=["train", "test"]
    ),
    takeable_args=["close"],
    merge_func="concat",
    execute_kwargs=dict(show_progress=True),
    parameterized_kwargs=dict(
        engine="dask",
        chunk_len="auto",
        show_progress=False
    )
)

# The fourth argument (ann_factor) is the number of index periods in 365 days.
# The input is a plain array, so the index and the wrapper used for merging
# are passed explicitly.
grid_perf, best_perf = cv_sma_crossover_perf(
    vbt.to_2d_array(data.close),
    vbt.Param(np.arange(5, 50), condition="x < slow_window"),
    vbt.Param(np.arange(5, 50)),
    pd.Timedelta(days=365) // data.index.freq,
    _merge_kwargs=dict(wrapper=data.symbol_wrapper),
    _index=data.index,
    _return_grid="all"
)

# %%
grid_perf

# %%
best_perf

# %%
# Correlation between the best train and the corresponding test performance.
best_train_perf = best_perf.xs("train", level="set")
best_test_perf = best_perf.xs("test", level="set")
best_train_perf.corr(best_test_perf)

# %%
# Train/test correlation for every parameter combination.
param_cross_set_corr = grid_perf\
    .unstack("set")\
    .groupby(["fast_window", "slow_window"])\
    .apply(lambda x: x["train"].corr(x["test"]))
param_cross_set_corr.vbt.heatmap(symmetric=True).show()

# %%
# Per split and symbol: the fraction of grid combinations whose test
# performance is greater than that of the selected ("best") one.
grid_test_perf = grid_perf.xs("test", level="set")
grid_df = grid_test_perf.rename("grid").reset_index()
del grid_df["fast_window"]
del grid_df["slow_window"]
best_df = best_test_perf.rename("best").reset_index()
del best_df["fast_window"]
del best_df["slow_window"]
merged_df = pd.merge(grid_df, best_df, on=["split", "symbol"])
grid_better_mask = merged_df["grid"] > merged_df["best"]
# NOTE(review): this assumes the merge keeps the row order of `grid_df` — confirm.
grid_better_mask.index = grid_test_perf.index
grid_better_cnt = grid_better_mask.groupby(["split", "symbol"]).mean()
grid_better_cnt

# %%
# Ask the decorated function for its splitter only, then compute the holding
# Sharpe ratio on the stacked "test" slices.
cv_splitter = cv_sma_crossover_perf(
    _index=data.index,
    _return_splitter=True
)
stacked_close = cv_splitter.take(
    data.close,
    into="reset_stacked",
    set_="test"
)
hold_pf = vbt.Portfolio.from_holding(stacked_close, freq="daily")
hold_perf = hold_pf.sharpe_ratio
hold_perf

# %% [markdown]
# ## Modeling

# %%
X = data.run("talib")
X.shape

# %%
trendlb = data.run("trendlb", 1.0, 0.5, mode="binary")
y = trendlb.labels
y.shape

# %%
X = X.replace([-np.inf, np.inf], np.nan)
invalid_column_mask = X.isnull().all(axis=0) | (X.nunique() == 1)
X = X.loc[:, ~invalid_column_mask]
invalid_row_mask = X.isnull().any(axis=1) | y.isnull()
X = X.loc[~invalid_row_mask]
y = y.loc[~invalid_row_mask]
X.shape, y.shape

# %%
from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(random_state=42)

# %%
cv = vbt.SKLSplitter(
    "from_expanding",
    min_length=360,
    offset=180,
    split=-180,
    set_labels=["train", "test"]
)

cv_splitter = cv.get_splitter(X)
cv_splitter.plot().show()

# %%
from sklearn.model_selection import cross_val_score

cross_val_score(clf, X, y, cv=cv, scoring="accuracy")

# %%
X_slices = cv_splitter.take(X)
y_slices = cv_splitter.take(y)

test_labels = []
test_preds = []
for split in X_slices.index.unique(level="split"):
    X_train_slice = X_slices[(split, "train")]
    y_train_slice = y_slices[(split, "train")]
    X_test_slice = X_slices[(split, "test")]
    y_test_slice = y_slices[(split, "test")]
    slice_clf = clf.fit(X_train_slice, y_train_slice)
    test_pred = slice_clf.predict(X_test_slice)
    test_pred = pd.Series(test_pred, index=y_test_slice.index)
    test_labels.append(y_test_slice)
    test_preds.append(test_pred)

test_labels = pd.concat(test_labels).rename("labels")
test_preds = pd.concat(test_preds).rename("preds")

# %%
data.close.vbt.overlay_with_heatmap(test_labels).show()

# %%
data.close.vbt.overlay_with_heatmap(test_preds).show()

# %%
pf = vbt.Portfolio.from_signals(
    data.close[test_preds.index],
    test_preds == 1,
    test_preds == 0,
    direction="both"
)
pf.stats()

# %%