# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# # Optimization
# ## Lazy parameter grids

# %%
@vbt.parameterized(merge_func="concat")
def test_combination(data, n, sl_stop, tsl_stop, tp_stop):
    """Backtest random signals with the given stop levels and return the total return.

    ``vbt.parameterized`` calls this once per parameter combination and
    concatenates the individual results (``merge_func="concat"``).
    """
    stops = dict(sl_stop=sl_stop, tsl_stop=tsl_stop, tp_stop=tp_stop)
    portfolio = data.run("from_random_signals", n=n, **stops)
    return portfolio.total_return

# Search space: number of random signals plus three stop levels (0.001 .. 0.999).
n = np.arange(10, 100)
sl_stop, tsl_stop, tp_stop = (np.arange(1, 1000) / 1000 for _ in range(3))
# Size of the full Cartesian product of the grid.
len(n) * len(sl_stop) * len(tsl_stop) * len(tp_stop)

# %%
# `_random_subset=10` asks vbt.parameterized to evaluate only a random
# sample of 10 combinations instead of the whole grid.
grid = dict(n=n, sl_stop=sl_stop, tsl_stop=tsl_stop, tp_stop=tp_stop)
test_combination(
    vbt.YFData.pull("BTC-USD"),
    **{name: vbt.Param(values) for name, values in grid.items()},
    _random_subset=10,
)

# %% [markdown]
# ## Mono-chunks

# %%
@vbt.parameterized(
    merge_func="concat",
    mono_chunk_len=100,
    chunk_len="auto",
    engine="threadpool",
    show_progress=True,
    warmup=True
)
@njit(nogil=True)
def test_stops_nb(close, entries, exits, sl_stop, tp_stop):
    """Simulate one column per stop combination and return each column's final cumulative return.

    ``sl_stop`` / ``tp_stop`` are 2D here: their column count sets the number
    of simulated columns (presumably one column per parameter value in the
    mono-chunk, stacked by ``mono_merge_func=np.column_stack`` at the call site).
    """
    n_rows = close.shape[0]
    n_cols = sl_stop.shape[1]
    sim_out = vbt.pf_nb.from_signals_nb(
        target_shape=(n_rows, n_cols),
        # Every column forms its own group of length 1.
        group_lens=np.full(n_cols, 1),
        close=close,
        long_entries=entries,
        # `exits` are passed as short entries, not as long exits.
        short_entries=exits,
        sl_stop=sl_stop,
        tp_stop=tp_stop,
        save_returns=True
    )
    per_bar_returns = sim_out.in_outputs.returns
    final_returns = vbt.ret_nb.cum_returns_final_nb(per_bar_returns, 0)
    return final_returns

data = vbt.YFData.pull("BTC-USD", start="2020")
entries, exits = data.run("randnx", n=10, hide_params=True, unpack=True)
# Build the stop grid from an integer range: a float-step `np.arange` accumulates
# rounding error into the parameter values (and thus the heatmap axis labels),
# while dividing integers keeps each value as the closest double to k/100.
stop_grid = np.arange(1, 100) / 100  # 0.01, 0.02, ..., 0.99
# `test_stops_nb` returns final cumulative (total) returns, not Sharpe ratios.
total_returns = test_stops_nb(
    vbt.to_2d_array(data.close),
    vbt.to_2d_array(entries),
    vbt.to_2d_array(exits),
    sl_stop=vbt.Param(stop_grid, mono_merge_func=np.column_stack),
    tp_stop=vbt.Param(stop_grid, mono_merge_func=np.column_stack)
)
total_returns.vbt.heatmap().show()

# %% [markdown]
# ## CV decorator

# %%
@vbt.cv_split(
    splitter="from_rolling",
    splitter_kwargs=dict(length=365, split=0.5, set_labels=["train", "test"]),
    takeable_args=["data"],
    execute_kwargs=dict(show_progress=True),
    parameterized_kwargs=dict(random_subset=100),
    merge_func="concat"
)
def sma_crossover_cv(data, fast_period, slow_period, metric):
    """Backtest a long/short SMA crossover on one split and return a metric.

    ``metric`` is a dotted attribute path resolved on the portfolio with
    ``deep_getattr`` (e.g. ``"trades.expectancy"``).
    """
    fast_sma = data.run("sma", fast_period, hide_params=True)
    slow_sma = data.run("sma", slow_period, hide_params=True)
    # Compare against the slow SMA's output series explicitly rather than
    # handing the whole indicator object to the crossover methods.
    slow_line = slow_sma.real
    entries = fast_sma.real_crossed_above(slow_line)
    exits = fast_sma.real_crossed_below(slow_line)
    pf = vbt.PF.from_signals(data, entries, exits, direction="both")
    return pf.deep_getattr(metric)

btc_data = vbt.YFData.pull("BTC-USD", start="4 years ago")
# The condition keeps only combinations where the fast period is below the slow one.
fast_periods = vbt.Param(np.arange(20, 50), condition="x < slow_period")
slow_periods = vbt.Param(np.arange(20, 50))
sma_crossover_cv(btc_data, fast_periods, slow_periods, "trades.expectancy")

# %% [markdown]
# ## Split decorator

# %%
@vbt.split(
    splitter="from_grouper",
    splitter_kwargs=dict(by="Q"),
    takeable_args=["data"],
    merge_func="concat"
)
def get_quarter_return(data):
    """Return the total return of ``data`` within a single split (one quarter)."""
    split_returns = data.returns
    return split_returns.vbt.returns.total()

data = vbt.YFData.pull("BTC-USD")
# One total return per quarter of 2021, concatenated by the split decorator.
data_2021 = data.loc["2021"]
get_quarter_return(data_2021)

# %%
data_2022 = data.loc["2022"]
get_quarter_return(data_2022)

# %% [markdown]
# ## Conditional parameters

# %%
@vbt.parameterized(merge_func="column_stack")
def ma_crossover_signals(data, fast_window, slow_window):
    """Return ``(entries, exits)`` masks for a fast/slow SMA crossover.

    Entries mark the fast SMA crossing above the slow SMA, exits mark it
    crossing below; results of all combinations are column-stacked.
    """
    fast_sma = data.run("sma", fast_window, short_name="fast_sma")
    slow_sma = data.run("sma", slow_window, short_name="slow_sma")
    slow_line = slow_sma.real
    return (
        fast_sma.real_crossed_above(slow_line),
        fast_sma.real_crossed_below(slow_line),
    )

btc_data = vbt.YFData.pull("BTC-USD", start="one year ago UTC")
# Prune combinations whose slow window is not at least 5 bars longer than the fast one.
fast_param = vbt.Param(np.arange(5, 50), condition="slow_window - fast_window >= 5")
slow_param = vbt.Param(np.arange(5, 50))
entries, exits = ma_crossover_signals(btc_data, fast_param, slow_param)
entries.columns

# %% [markdown]
# ## Splitter

# %%
data = vbt.YFData.pull("BTC-USD", start="4 years ago")
# Rolling 360-day windows, each divided 50/50 into a "train" and a "test" set.
rolling_kwargs = dict(
    length="360 days",
    split=0.5,
    set_labels=["train", "test"],
    freq="daily",
)
splitter = vbt.Splitter.from_rolling(data.index, **rolling_kwargs)
splitter.plots().show()

# %% [markdown]
# ## Random search

# %%
data = vbt.YFData.pull("BTC-USD", start="2020")
stop_values = np.arange(1, 100) / 100
# Same 0.01..0.99 grid for every stop type; only 1000 random combinations are simulated.
stop_params = {name: vbt.Param(stop_values) for name in ("sl_stop", "tsl_stop", "tp_stop")}
pf = vbt.PF.from_random_signals(
    data,
    n=100,
    **stop_params,
    broadcast_kwargs=dict(random_subset=1000),
)
pf.total_return.sort_values(ascending=False)

# %% [markdown]
# ## Parameterized decorator

# %%
@vbt.parameterized(merge_func="concat", show_progress=True)
def bbands_sharpe(data, timeperiod=14, nbdevup=2, nbdevdn=2, thup=0.3, thdn=0.1):
    """Backtest a Bollinger Bands rule and return the portfolio's Sharpe ratio.

    ``thup`` / ``thdn`` are thresholds on the relative band width
    ``(upper - lower) / middle``: "wide" means above ``thup``, "narrow" below ``thdn``.
    """
    bb = data.run("talib_bbands", timeperiod=timeperiod, nbdevup=nbdevup, nbdevdn=nbdevdn)
    lower, middle, upper = bb.lowerband, bb.middleband, bb.upperband
    bandwidth = (upper - lower) / middle
    below_lower = data.low < lower
    above_upper = data.high > upper
    wide = bandwidth > thup
    narrow = bandwidth < thdn
    entries = (below_lower & wide) | (above_upper & narrow)
    exits = (below_lower & narrow) | (above_upper & wide)
    return vbt.PF.from_signals(data, entries, exits).sharpe_ratio

btc_data = vbt.YFData.pull("BTC-USD")
bbands_grid = dict(
    nbdevup=vbt.Param([1, 2]),
    nbdevdn=vbt.Param([1, 2]),
    thup=vbt.Param([0.4, 0.5]),
    thdn=vbt.Param([0.1, 0.2]),
)
bbands_sharpe(btc_data, **bbands_grid)

# %% [markdown]
# ## Riskfolio-Lib

# %%
symbols = ["SPY", "TLT", "XLF", "XLE", "XLU", "XLK", "XLB", "XLP", "XLY", "XLI", "XLV"]
data = vbt.YFData.pull(symbols, start="2020", end="2023", missing_index="drop")
# Close-to-close returns feed the optimizer; allocations are recomputed every month start ("MS").
close_returns = data.close.vbt.to_returns()
pfo = vbt.PFO.from_riskfolio(returns=close_returns, port_cls="hc", every="MS")
pfo.plot().show()

# %% [markdown]
# ## Array-like parameters

# %%
def steep_slope(close, up_th):
    """Flag bars where ``close`` rose by at least ``up_th`` (fractional change).

    ``close`` and ``up_th`` are broadcast together first, so ``up_th`` may be
    a scalar, an array, or a ``vbt.Param`` that yields one column per value.
    """
    aligned = vbt.broadcast(dict(close=close, up_th=up_th))
    change = aligned["close"].pct_change()
    return change >= aligned["up_th"]

data = vbt.YFData.pull("BTC-USD", start="2020", end="2022")
fig = data.plot(plot_volume=False)
sma_indicator = vbt.talib("SMA")
sma = sma_indicator.run(data.close, timeperiod=50).real
labeled_sma = sma.rename("SMA")
labeled_sma.vbt.plot(fig=fig)
# One mask column per slope threshold (columns are labeled by threshold value).
mask = steep_slope(sma, vbt.Param([0.005, 0.01, 0.015]))

def plot_mask_ranges(column, color):
    """Draw the ranges of the global ``mask`` column ``column`` as shapes on the global ``fig``."""
    shape_style = dict(fillcolor=color)
    mask.vbt.ranges.plot_shapes(
        column=column,
        plot_close=False,
        shape_kwargs=shape_style,
        fig=fig,
    )
# Steeper thresholds get lighter colors.
for threshold, shade in [(0.005, "orangered"), (0.010, "orange"), (0.015, "yellow")]:
    plot_mask_ranges(threshold, shade)
fig.update_xaxes(showgrid=False)
fig.update_yaxes(showgrid=False)
fig.show()

# %% [markdown]
# ## Parameters

# %%
from itertools import combinations

# Moving-average windows must be at least 1 bar long; start the search space at 1
# so a zero-length window never enters the fast/slow/signal grids.
window_space = np.arange(1, 100)
# Every (fast, slow) pair with fast < slow, unzipped into two aligned tuples.
fast_windows, slow_windows = zip(*combinations(window_space, 2))
window_type_space = list(vbt.enums.WType)
param_product = vbt.combine_params(
    dict(
        # Params sharing a `level` are presumably combined element-wise rather than
        # crossed, keeping fast/slow pairs aligned (see vbt `combine_params` docs).
        fast_window=vbt.Param(fast_windows, level=0),
        slow_window=vbt.Param(slow_windows, level=0),
        signal_window=vbt.Param(window_space, level=1),
        macd_wtype=vbt.Param(window_type_space, level=2),
        signal_wtype=vbt.Param(window_type_space, level=2),
    ),
    random_subset=10_000,
    build_index=False
)
pd.DataFrame(param_product)

# %% [markdown]
# ## Portfolio optimization

# %%
def regime_change_optimize_func(data):
    """Compute contrarian weights from each symbol's total return over ``data``.

    Each symbol's weight is its total return divided by the sum of absolute
    total returns, then sign-flipped: symbols with a positive total return get
    a negative weight and vice versa. Symbols whose total return is zero or
    NaN keep weight 0.

    Returns the weights in the container built by
    ``data.symbol_wrapper.fill_reduced``.
    """
    total_return = data.returns.vbt.returns.total()
    # Use a float fill value. The container receives fractional weights, and an
    # integer fill (``0``) may produce an int container that relies on implicit
    # dtype upcasting on assignment (deprecated in pandas >= 2.1).
    weights = data.symbol_wrapper.fill_reduced(0.0)
    # Positive and negative totals use the same normalization, so one mask
    # covers both. ``abs() > 0`` is False for NaN, so NaN totals keep weight 0.
    nonzero = total_return.abs() > 0
    if nonzero.any():
        weights[nonzero] = total_return[nonzero] / total_return.abs().sum()
    return -1 * weights

symbols = ["SPY", "TLT", "XLF", "XLE", "XLU", "XLK", "XLB", "XLP", "XLY", "XLI", "XLV"]
data = vbt.YFData.pull(symbols, start="2020", end="2023", missing_index="drop")
# Template evaluated per rebalancing period; presumably `index_slice` selects
# that period's rows so the optimize function only sees the current window.
data_slice = vbt.RepEval("data[index_slice]", context=dict(data=data))
pfo = vbt.PFO.from_optimize_func(
    data.symbol_wrapper,
    regime_change_optimize_func,
    data_slice,
    every="MS",
)
pfo.plot().show()

# %% [markdown]
# ## PyPortfolioOpt

# %%
symbols = ["SPY", "TLT", "XLF", "XLE", "XLU", "XLK", "XLB", "XLP", "XLY", "XLI", "XLV"]
data = vbt.YFData.pull(symbols, start="2020", end="2023", missing_index="drop")
# Hierarchical Risk Parity via PyPortfolioOpt, re-optimized at each month start.
hrp_kwargs = dict(optimizer="hrp", target="optimize", every="MS")
pfo = vbt.PFO.from_pypfopt(returns=data.returns, **hrp_kwargs)
pfo.plot().show()

# %% [markdown]
# ## Universal Portfolios

# %%
symbols = ["SPY", "TLT", "XLF", "XLE", "XLU", "XLK", "XLB", "XLP", "XLY", "XLI", "XLV"]
data = vbt.YFData.pull(symbols, start="2020", end="2023", missing_index="drop")
# Weekly closes; with window=52 the estimator presumably looks back about one year.
weekly_close = data.resample("W").close
pfo = vbt.PFO.from_universal_algo(
    "MPT",
    weekly_close,
    window=52,
    min_history=4,
    mu_estimator="historical",
    cov_estimator="empirical",
    method="mpt",
    q=0,
)
pfo.plot().show()

# %%