# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Basic RSI strategy
# ## Single backtest

# %%
from vectorbtpro import *

# Download BTCUSDT data through vectorbtpro's Binance data class
# (presumably a network call to the exchange - TODO confirm it needs connectivity).
data = vbt.BinanceData.pull('BTCUSDT')
data

# %%
# Quick visual sanity check of what was downloaded.
data.plot().show()

# %%
# Inspect the object stored under the 'BTCUSDT' key
# (expected to be a pandas DataFrame, given the .info() call - TODO confirm).
data.data['BTCUSDT'].info()

# %%
# Extract the two price features used by the rest of the notebook:
# the RSI is computed on 'Open', the portfolio is simulated on 'Close'.
open_price = data.get('Open')
close_price = data.get('Close')

# %%
# Ask the indicator factory for the indicators whose name matches "RSI*".
vbt.IF.list_indicators("RSI*")

# %%
# Look an indicator up by a "<source>:<name>" style identifier.
vbt.indicator("talib:RSI")

# %%
# The RSI indicator exposed directly on the vbt namespace; this is the one
# the remaining cells of the notebook run.
vbt.RSI

# %%
# The next four cells request an RSI through source-specific shortcuts
# (presumably each needs the matching third-party package installed - TODO confirm).
vbt.talib('RSI')

# %%
vbt.ta('RSIIndicator')

# %%
vbt.pandas_ta('RSI')

# %%
vbt.technical('RSI')

# %%
# Print vectorbtpro's help for RSI.run before calling it.
vbt.phelp(vbt.RSI.run)

# %%
# Run the RSI on the open price without overriding any indicator parameter.
rsi = vbt.RSI.run(open_price)
rsi

# %%
# The indicator's output values.
rsi.rsi

# %%
# Entry signal: the RSI crosses below 30 (via the generic .vbt accessor).
entries = rsi.rsi.vbt.crossed_below(30)
entries

# %%
# Exit signal: the RSI crosses above 70 (via the generic .vbt accessor).
exits = rsi.rsi.vbt.crossed_above(70)
exits

# %%
# Same thresholds through the indicator's own shortcut methods. This rebinds
# `entries` and `exits` (presumably to identical masks - TODO confirm).
entries = rsi.rsi_crossed_below(30)
exits = rsi.rsi_crossed_above(70)

# %%
def plot_rsi(rsi, entries, exits):
    """Plot an RSI indicator and mark entry/exit signals on the RSI line.

    Args:
        rsi: Indicator object exposing ``plot()`` and an ``rsi`` attribute.
        entries: Signal mask whose ``.vbt.signals`` accessor provides
            ``plot_as_entries``.
        exits: Signal mask whose ``.vbt.signals`` accessor provides
            ``plot_as_exits``.

    Returns:
        The figure returned by ``rsi.plot()``, after it has been handed to
        both marker-plotting calls.
    """
    figure = rsi.plot()
    # Markers are positioned against the RSI values, not against a price.
    rsi_line = rsi.rsi
    entries.vbt.signals.plot_as_entries(rsi_line, fig=figure)
    exits.vbt.signals.plot_as_exits(rsi_line, fig=figure)
    return figure

# Raw signals: every threshold crossing found above is drawn.
plot_rsi(rsi, entries, exits).show()

# %%
# Clean the two masks against each other (presumably so entries and exits
# strictly alternate - TODO confirm against the signals accessor docs).
clean_entries, clean_exits = entries.vbt.signals.clean(exits)

plot_rsi(rsi, clean_entries, clean_exits).show()

# %%
# Number of entry signals remaining after cleaning.
clean_entries.vbt.signals.total()

# %%
# Number of exit signals remaining after cleaning.
clean_exits.vbt.signals.total()

# %%
# Ranges measured from the cleaned entries to the cleaned exits (the target),
# then their mean duration, requested as a timedelta via wrap_kwargs.
ranges = clean_entries.vbt.signals.between_ranges(target=clean_exits)
ranges.duration.mean(wrap_kwargs=dict(to_timedelta=True))

# %%
# Simulate the cleaned signals on the close price. Orders use size=100 with
# size_type='value' (presumably 100 units of cash per order - TODO confirm)
# and init_cash='auto' lets vectorbtpro choose the starting cash.
pf = vbt.Portfolio.from_signals(
    close=close_price,
    entries=clean_entries,
    exits=clean_exits,
    size=100,
    size_type='value',
    init_cash='auto',
)
pf

# %%
# Full default statistics report for this single backtest.
pf.stats()

# %%
# Default portfolio plot with the bm_returns setting switched off.
pf.plot(settings={'bm_returns': False}).show()

# %% [markdown]
# ## Multiple backtests
# ### Using for-loop

# %%
def test_rsi(window=14, wtype="wilder", lower_th=30, upper_th=70):
    """Backtest one RSI threshold strategy and return four summary metrics.

    Reads the notebook-level ``open_price`` (indicator input) and
    ``close_price`` (simulation price), so it reflects whatever data those
    names currently hold.

    Args:
        window: RSI window passed to ``vbt.RSI.run``.
        wtype: RSI smoothing type passed to ``vbt.RSI.run``.
        lower_th: Enter when the RSI crosses below this level.
        upper_th: Exit when the RSI crosses above this level.

    Returns:
        The result of ``Portfolio.stats`` restricted to total return,
        total trades, win rate and expectancy.
    """
    indicator = vbt.RSI.run(open_price, window=window, wtype=wtype)
    portfolio = vbt.Portfolio.from_signals(
        close=close_price,
        entries=indicator.rsi_crossed_below(lower_th),
        exits=indicator.rsi_crossed_above(upper_th),
        size=100,
        size_type='value',
        init_cash='auto',
    )
    metrics = ['total_return', 'total_trades', 'win_rate', 'expectancy']
    return portfolio.stats(metrics)

# Baseline run with the function's default window, smoothing and thresholds.
test_rsi()

# %%
# Same strategy with wider thresholds.
test_rsi(upper_th=80, lower_th=20)

# %%
# Every (lower, upper) threshold pair: 11 x 11 = 121 combinations.
lower_ths, upper_ths = range(20, 31), range(70, 81)
th_combs = [*product(lower_ths, upper_ths)]
len(th_combs)

# %%
# One independent backtest per threshold pair.
comb_stats = [test_rsi(lower_th=lo, upper_th=hi) for lo, hi in th_combs]

# %%
# One row per backtest, one column per metric.
comb_stats_df = pd.DataFrame(comb_stats)
comb_stats_df

# %%
# Label each row with the threshold pair that produced it.
comb_stats_df.index = pd.MultiIndex.from_tuples(th_combs, names=['lower_th', 'upper_th'])
comb_stats_df

# %%
# Expectancy across the two threshold levels of the index.
comb_stats_df['Expectancy'].vbt.heatmap().show()

# %% [markdown]
# ### Using columns

# %%
# Parameter grid for the column-based approach: 13 windows, 3 smoothing
# types, 11 lower and 11 upper thresholds. `lower_ths` and `upper_ths` are
# rebound here as lists (they were ranges in the for-loop section).
windows = list(range(8, 21))
wtypes = ["simple", "exp", "wilder"]
lower_ths = list(range(20, 31))
upper_ths = list(range(70, 81))

# %%
# Run the RSI for several parameter values in a single call. With
# param_product=True the windows and wtypes are presumably combined as a
# Cartesian product (13 x 3 = 39 columns per input column - TODO confirm).
rsi = vbt.RSI.run(
    open_price,
    window=windows,
    wtype=wtypes,
    param_product=True)
rsi.rsi.columns

# %%
# Build the 121 threshold pairs, then transpose them into two aligned tuples
# so that position i of both tuples describes the same pair.
lower_ths_prod, upper_ths_prod = zip(*product(lower_ths, upper_ths))
len(lower_ths_prod)

# %%
len(upper_ths_prod)

# %%
# Wrap the lower thresholds as a named parameter so the crossing test is
# evaluated per threshold value (presumably adding a 'lower_th' column
# level - the displayed columns confirm or refute this).
lower_th_index = vbt.Param(lower_ths_prod, name='lower_th')
entries = rsi.rsi_crossed_below(lower_th_index)
entries.columns

# %%
# Same for the upper thresholds; the tuples are aligned, so entries and exits
# are expected to line up column by column - TODO confirm via the columns.
upper_th_index = vbt.Param(upper_ths_prod, name='upper_th')
exits = rsi.rsi_crossed_above(upper_th_index)
exits.columns

# %%
# Simulate all columns in one call, with the same order sizing as the single
# backtest. This rebinds `pf`, replacing the single-backtest portfolio.
pf = vbt.Portfolio.from_signals(
    close=close_price,
    entries=entries,
    exits=exits,
    size=100,
    size_type='value',
    init_cash='auto'
)
pf

# %%
# agg_func=None is passed so the metrics are presumably kept per column
# instead of being aggregated across columns - TODO confirm.
stats_df = pf.stats([
    'total_return',
    'total_trades',
    'win_rate',
    'expectancy'
], agg_func=None)
stats_df

# %%
# Size of the portfolio object as reported by vectorbtpro.
print(pf.getsize())

# %%
# Size in MiB of a single array with the portfolio's shape at 8 bytes per
# element (e.g. float64). Uses np.prod because the np.product alias was
# deprecated in NumPy 1.25 and removed in NumPy 2.0.
np.prod(pf.wrapper.shape) * 8 / 1024 / 1024

# %%
# Mean expectancy per RSI window, averaged over all other parameters.
stats_df['Expectancy'].groupby('rsi_window').mean()

# %%
# The five parameter combinations with the highest expectancy.
stats_df.sort_values(by='Expectancy', ascending=False).head()

# %%
# Plot the value of one specific column. The tuple presumably follows the
# column level order (lower_th, upper_th, rsi_window, rsi_wtype) - TODO
# confirm against pf.wrapper.columns before changing it.
pf[(22, 80, 20, "wilder")].plot_value().show()

# %%
# Pull two symbols this time; this rebinds `data`.
data = vbt.BinanceData.pull(['BTCUSDT', 'ETHUSDT'])

# %%
# The `stats_df` built earlier comes from single-symbol data and has no
# 'symbol' index level, so the comparison below would fail on a fresh
# Restart & Run All. Re-run the column-based pipeline on the two-symbol data
# first, rebinding the same names the earlier cells used.
open_price = data.get('Open')
close_price = data.get('Close')
rsi = vbt.RSI.run(
    open_price,
    window=windows,
    wtype=wtypes,
    param_product=True)
entries = rsi.rsi_crossed_below(vbt.Param(lower_ths_prod, name='lower_th'))
exits = rsi.rsi_crossed_above(vbt.Param(upper_ths_prod, name='upper_th'))
pf = vbt.Portfolio.from_signals(
    close=close_price,
    entries=entries,
    exits=exits,
    size=100,
    size_type='value',
    init_cash='auto'
)
stats_df = pf.stats([
    'total_return',
    'total_trades',
    'win_rate',
    'expectancy'
], agg_func=None)

# %%
# Split the per-column expectancy by symbol and compare both distributions.
# Each symbol is run over the same parameter grid, so the two arrays have
# equal length and can share one DataFrame.
eth_mask = stats_df.index.get_level_values('symbol') == 'ETHUSDT'
btc_mask = stats_df.index.get_level_values('symbol') == 'BTCUSDT'
pd.DataFrame({
    'ETHUSDT': stats_df.loc[eth_mask, 'Expectancy'].values,
    'BTCUSDT': stats_df.loc[btc_mask, 'Expectancy'].values
}).vbt.histplot(xaxis=dict(title="Expectancy")).show()

# %%