# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  SuperFast SuperTrend
# ## Data

# %%
from vectorbtpro import *

# Pull hourly ('1h') BTCUSDT and ETHUSDT bars from Binance for the
# 2020-01-01 .. 2022-01-01 (UTC) window.
data = vbt.BinanceData.pull(
    ['BTCUSDT', 'ETHUSDT'],
    start='2020-01-01 UTC',
    end='2022-01-01 UTC',
    timeframe='1h'
)

# %%
# Persist the pulled data to a local HDF5 file.
data.to_hdf('my_data.h5')

# %%
# Re-create `data` from the HDF5 file written in the previous cell
# (this rebinds `data` to the object returned by vbt.HDFData.pull).
data = vbt.HDFData.pull('my_data.h5')

# %%
# Summary of the BTCUSDT entry stored in `data.data`.
data.data['BTCUSDT'].info()

# %%
data.stats()

# %%
# Extract the three price features used by the indicator below.
# They are later indexed by symbol (e.g. high['BTCUSDT']).
high = data.get('High')
low = data.get('Low')
close = data.get('Close')

close

# %% [markdown]
# ## Design

# %% [markdown]
# ### Pandas

# %%
def get_med_price(high, low):
    """Return the midpoint between `high` and `low`, i.e. ``(high + low) / 2``."""
    total = high + low
    return total / 2

def get_atr(high, low, close, period):
    """Compute the average true range (ATR) with pandas.

    The true range of each bar is the row-wise maximum of
    ``|high - low|``, ``|high - previous close|`` and ``|low - previous close|``.
    It is then smoothed with an exponentially weighted mean using
    ``alpha = 1 / period`` and ``adjust=False``; the first ``period - 1``
    values are NaN because of ``min_periods=period``.
    """
    prev_close = close.shift()
    candidates = [
        (high - low).abs(),
        (high - prev_close).abs(),
        (low - prev_close).abs(),
    ]
    true_range = pd.concat(candidates, axis=1).max(axis=1)
    smoother = true_range.ewm(alpha=1 / period, adjust=False, min_periods=period)
    return smoother.mean()

def get_basic_bands(med_price, atr, multiplier):
    """Return the basic ``(upper, lower)`` bands.

    The bands sit ``multiplier * atr`` above and below ``med_price``.
    """
    offset = multiplier * atr
    return med_price + offset, med_price - offset

def get_final_bands(close, upper, lower):
    """Turn the basic bands into the final SuperTrend lines (pandas version).

    Parameters
    ----------
    close, upper, lower : pd.Series
        Close price and the basic upper/lower bands. They are accessed by
        position, so they must have the same length.

    Returns
    -------
    tuple of pd.Series
        ``(trend, dir_, long, short)``, all indexed like ``close``:

        * ``dir_`` is 1 while the trend is up and -1 while it is down,
        * ``long`` holds the lower band on up-trend bars (NaN otherwise),
        * ``short`` holds the upper band on down-trend bars (NaN otherwise),
        * ``trend`` holds whichever of the two is active.

        The first row is never visited by the loop, so it keeps its initial
        values (NaN for the bands, 1 for ``dir_``).

    Notes
    -----
    The bands are adjusted ("ratcheted") during the pass. This is done on
    copies, so the Series passed in by the caller are left unmodified.
    """
    # Copy first: the loop below overwrites band values, and doing that on the
    # caller's objects would silently change them.
    upper = upper.copy()
    lower = lower.copy()

    trend = pd.Series(np.full(close.shape, np.nan), index=close.index)
    dir_ = pd.Series(np.full(close.shape, 1), index=close.index)
    long = pd.Series(np.full(close.shape, np.nan), index=close.index)
    short = pd.Series(np.full(close.shape, np.nan), index=close.index)

    for i in range(1, close.shape[0]):
        if close.iloc[i] > upper.iloc[i - 1]:
            # Close broke above the previous upper band: trend turns up.
            dir_.iloc[i] = 1
        elif close.iloc[i] < lower.iloc[i - 1]:
            # Close broke below the previous lower band: trend turns down.
            dir_.iloc[i] = -1
        else:
            # No breakout: keep the direction and stop the active band from
            # moving against the trend (lower band never falls in an up-trend,
            # upper band never rises in a down-trend).
            dir_.iloc[i] = dir_.iloc[i - 1]
            if dir_.iloc[i] > 0 and lower.iloc[i] < lower.iloc[i - 1]:
                lower.iloc[i] = lower.iloc[i - 1]
            if dir_.iloc[i] < 0 and upper.iloc[i] > upper.iloc[i - 1]:
                upper.iloc[i] = upper.iloc[i - 1]

        if dir_.iloc[i] > 0:
            trend.iloc[i] = long.iloc[i] = lower.iloc[i]
        else:
            trend.iloc[i] = short.iloc[i] = upper.iloc[i]

    return trend, dir_, long, short

def supertrend(high, low, close, period=7, multiplier=3):
    """Compute SuperTrend with the pandas helpers defined in this cell.

    Builds the basic bands from the median price and the ATR, then hands them
    to `get_final_bands`, whose ``(trend, dir_, long, short)`` tuple is
    returned as is.
    """
    basic_upper, basic_lower = get_basic_bands(
        get_med_price(high, low),
        get_atr(high, low, close, period),
        multiplier,
    )
    return get_final_bands(close, basic_upper, basic_lower)

# %%
# Run the pandas implementation on BTCUSDT with the default
# period=7 and multiplier=3.
supert, superd, superl, supers = supertrend(
    high['BTCUSDT'],
    low['BTCUSDT'],
    close['BTCUSDT']
)

supert

# %%
superd

# %%
superl

# %%
supers

# %%
# Plot one month of close prices together with the short (upper) and
# long (lower) SuperTrend lines on the same figure.
date_range = slice('2020-01-01', '2020-02-01')
fig = close.loc[date_range, 'BTCUSDT'].rename('Close').vbt.plot()
supers.loc[date_range].rename('Short').vbt.plot(fig=fig)
superl.loc[date_range].rename('Long').vbt.plot(fig=fig)
fig.show()

# %%
%%timeit
supertrend(high['BTCUSDT'], low['BTCUSDT'], close['BTCUSDT'])

# %%
SUPERTREND = vbt.pandas_ta('SUPERTREND')

%%timeit
SUPERTREND.run(high['BTCUSDT'], low['BTCUSDT'], close['BTCUSDT'])

# %% [markdown]
# ### NumPy + Numba

# %%
def get_atr_np(high, low, close, period):
    """Compute the ATR on NumPy arrays using vectorbt's Numba helpers.

    The true range is the row-wise maximum of ``|high - low|``,
    ``|high - shifted close|`` and ``|low - shifted close|``; it is then
    smoothed by ``vbt.nb.wwm_mean_1d_nb`` over `period`.
    """
    # NOTE(review): presumably shifts `close` forward by one element so each
    # bar sees the previous close - confirm against the vectorbt docs.
    prev_close = vbt.nb.fshift_1d_nb(close)
    candidates = (
        np.abs(high - low),
        np.abs(high - prev_close),
        np.abs(low - prev_close),
    )
    true_range = np.column_stack(candidates).max(axis=1)
    return vbt.nb.wwm_mean_1d_nb(true_range, period)

# %%
@njit
def get_final_bands_nb(close, upper, lower):
    """Turn the basic bands into the final SuperTrend lines (Numba version).

    Takes one-dimensional arrays of equal length and returns the tuple
    ``(trend, dir_, long, short)``:

    * ``dir_`` is 1 while the trend is up and -1 while it is down,
    * ``long`` holds the lower band on up-trend bars (NaN otherwise),
    * ``short`` holds the upper band on down-trend bars (NaN otherwise),
    * ``trend`` holds whichever of the two is active.

    Element 0 is never visited by the loop and keeps its initial values.
    The input arrays are not modified.
    """
    # The loop below overwrites band values, so work on private copies instead
    # of the caller's arrays. Fresh local names (rather than rebinding the
    # arguments) avoid mixing the argument's array type with the copy's.
    final_upper = upper.copy()
    final_lower = lower.copy()

    trend = np.full(close.shape, np.nan)
    dir_ = np.full(close.shape, 1)
    long = np.full(close.shape, np.nan)
    short = np.full(close.shape, np.nan)

    for i in range(1, close.shape[0]):
        if close[i] > final_upper[i - 1]:
            # Close broke above the previous upper band: trend turns up.
            dir_[i] = 1
        elif close[i] < final_lower[i - 1]:
            # Close broke below the previous lower band: trend turns down.
            dir_[i] = -1
        else:
            # No breakout: keep the direction and stop the active band from
            # moving against the trend.
            dir_[i] = dir_[i - 1]
            if dir_[i] > 0 and final_lower[i] < final_lower[i - 1]:
                final_lower[i] = final_lower[i - 1]
            if dir_[i] < 0 and final_upper[i] > final_upper[i - 1]:
                final_upper[i] = final_upper[i - 1]

        if dir_[i] > 0:
            trend[i] = long[i] = final_lower[i]
        else:
            trend[i] = short[i] = final_upper[i]

    return trend, dir_, long, short

# %%
def faster_supertrend(high, low, close, period=7, multiplier=3):
    """Compute SuperTrend using `get_atr_np` and the Numba-compiled band pass.

    Returns the ``(trend, dir_, long, short)`` tuple produced by
    `get_final_bands_nb`.
    """
    basic_upper, basic_lower = get_basic_bands(
        get_med_price(high, low),
        get_atr_np(high, low, close, period),
        multiplier,
    )
    return get_final_bands_nb(close, basic_upper, basic_lower)

# The NumPy/Numba implementation is fed raw arrays (.values) rather than
# pandas Series, and returns NumPy arrays.
supert, superd, superl, supers = faster_supertrend(
    high['BTCUSDT'].values,
    low['BTCUSDT'].values,
    close['BTCUSDT'].values
)

supert

# %%
superd

# %%
superl

# %%
supers

# %%
# Re-attach the original index to get a labeled Series back.
pd.Series(supert, index=close.index)

# %%
faster_supertrend(
    high['BTCUSDT'].values,
    low['BTCUSDT'].values,
    close['BTCUSDT'].values
)

# %% [markdown]
# ### NumPy + Numba + TA-Lib

# %%
import talib

def faster_supertrend_talib(high, low, close, period=7, multiplier=3):
    """Compute SuperTrend with TA-Lib supplying the median price and the ATR.

    ``talib.MEDPRICE`` and ``talib.ATR`` provide the inputs for the basic
    bands; the final bands come from `get_final_bands_nb`, whose
    ``(trend, dir_, long, short)`` tuple is returned.
    """
    basic_upper, basic_lower = get_basic_bands(
        talib.MEDPRICE(high, low),
        talib.ATR(high, low, close, period),
        multiplier,
    )
    return get_final_bands_nb(close, basic_upper, basic_lower)

# Run the TA-Lib-backed variant once on the BTCUSDT arrays before timing it
# in the next cell.
faster_supertrend_talib(
    high['BTCUSDT'].values,
    low['BTCUSDT'].values,
    close['BTCUSDT'].values
)

# %%
%%timeit
faster_supertrend_talib(
    high['BTCUSDT'].values,
    low['BTCUSDT'].values,
    close['BTCUSDT'].values
)

# %% [markdown]
# ## Indicator factory

# %%
# Wrap `faster_supertrend_talib` into an indicator class via the indicator
# factory. Inputs, parameters and outputs are declared in the same order as
# the function's arguments and return values; period=7 and multiplier=3 are
# registered as defaults.
SuperTrend = vbt.IF(
    class_name='SuperTrend',
    short_name='st',
    input_names=['high', 'low', 'close'],
    param_names=['period', 'multiplier'],
    output_names=['supert', 'superd', 'superl', 'supers']
).with_apply_func(
    faster_supertrend_talib,
    takes_1d=True,
    period=7,
    multiplier=3
)

# %%
vbt.phelp(SuperTrend.run)

# %%
# Unlike the plain functions above, the indicator is run on the full
# multi-symbol `high`/`low`/`close` objects.
st = SuperTrend.run(high, low, close)
st.supert

# %%
%%timeit
SuperTrend.run(high, low, close)

# %% [markdown]
# ### Expressions

# %%
# Indicator expression: the first line names the class and its short name,
# the last line lists the outputs. `get_basic_bands` and `get_final_bands`
# are resolved from the keyword arguments passed to `from_expr` below.
expr = """
SuperTrend[st]:
medprice = @talib_medprice(high, low)
atr = @talib_atr(high, low, close, @p_period)
upper, lower = get_basic_bands(medprice, atr, @p_multiplier)
supert, superd, superl, supers = get_final_bands(close, upper, lower)
supert, superd, superl, supers
"""

# %%
# Rebuild the indicator from the expression. Note that `get_final_bands` is
# bound to the Numba version (`get_final_bands_nb`), and that this rebinds the
# `SuperTrend` name defined earlier.
SuperTrend = vbt.IF.from_expr(
    expr,
    takes_1d=True,
    get_basic_bands=get_basic_bands,
    get_final_bands=get_final_bands_nb,
    period=7,
    multiplier=3
)

st = SuperTrend.run(high, low, close)
st.supert

# %%
%%timeit
SuperTrend.run(high, low, close)

# %% [markdown]
# ### Plotting

# %%
class SuperTrend(SuperTrend):
    """Subclass of the previously defined ``SuperTrend`` that adds `plot`."""

    def plot(self,
             column=None,
             close_kwargs=None,
             superl_kwargs=None,
             supers_kwargs=None,
             fig=None,
             **layout_kwargs):
        """Plot the close price with the short and long SuperTrend lines.

        Args:
            column: Column to select from each of ``close``, ``supers`` and
                ``superl`` (forwarded to ``select_col_from_obj``).
            close_kwargs: Keyword arguments for plotting the close line.
            superl_kwargs: Keyword arguments for plotting the long line.
            supers_kwargs: Keyword arguments for plotting the short line.
            fig: Existing figure to draw on; forwarded to the first plot call.
            **layout_kwargs: Extra keyword arguments passed along with
                `close_kwargs` to the first plot call.

        Returns:
            The figure returned by plotting the close line, with both
            SuperTrend lines added to it.
        """
        close_kwargs = close_kwargs or {}
        superl_kwargs = superl_kwargs or {}
        supers_kwargs = supers_kwargs or {}

        pick = self.select_col_from_obj
        close_line = pick(self.close, column).rename('Close')
        short_line = pick(self.supers, column).rename('Short')
        long_line = pick(self.superl, column).rename('Long')

        # The close line creates (or reuses) the figure; the bands are drawn
        # on top of it, short first and long second.
        fig = close_line.vbt.plot(fig=fig, **close_kwargs, **layout_kwargs)
        for line, line_kwargs in ((short_line, supers_kwargs),
                                  (long_line, superl_kwargs)):
            line.vbt.plot(fig=fig, **line_kwargs)

        return fig

# %%
# Re-run with the subclass so that `st` has the new `plot` method, then plot
# the BTCUSDT column over `date_range` with a green long line and a red
# short line.
st = SuperTrend.run(high, low, close)
st.loc[date_range, 'BTCUSDT'].plot(
    superl_kwargs=dict(trace_kwargs=dict(line_color='limegreen')),
    supers_kwargs=dict(trace_kwargs=dict(line_color='red'))
).show()

# %% [markdown]
# ## Backtesting

# %%
# Entry signal wherever the long line is defined, exit signal wherever the
# short line is defined.
# NOTE(review): fshift() presumably moves signals forward by one bar so that
# orders act on the bar after the signal - confirm against the vectorbt docs.
entries = (~st.superl.isnull()).vbt.signals.fshift()
exits = (~st.supers.isnull()).vbt.signals.fshift()

# %%
# Simulate a portfolio from the signals with fees=0.001 on hourly data.
pf = vbt.Portfolio.from_signals(
    close=close,
    entries=entries,
    exits=exits,
    fees=0.001,
    freq='1h'
)

# %%
pf['ETHUSDT'].stats()

# %% [markdown]
# ### Optimization

# %%
# Parameter grid: periods 4..19 and multipliers 2.0..4.0 in steps of 0.1.
periods = np.arange(4, 20)
multipliers = np.arange(20, 41) / 10

# param_product=True requests every (period, multiplier) combination.
st = SuperTrend.run(
    high, low, close,
    period=periods,
    multiplier=multipliers,
    param_product=True,
    execute_kwargs=dict(show_progress=True)
)

# %%
st.wrapper.columns

# %%
# Select a single (period, multiplier, symbol) column and plot it.
st.loc[date_range, (19, 4, 'ETHUSDT')].plot().show()

# %%
print(st.getsize())

# %%
# Back-of-the-envelope size of the indicator outputs in MiB:
# (rows * columns) elements per output, 4 outputs, 8 bytes per element.
output_size = st.wrapper.shape[0] * st.wrapper.shape[1]
n_outputs = 4
data_type_size = 8
output_size * n_outputs * data_type_size / 1024 / 1024

# %%
# Same signal construction and simulation as for the single run above, now
# applied to the indicator computed over the whole parameter grid.
entries = (~st.superl.isnull()).vbt.signals.fshift()
exits = (~st.supers.isnull()).vbt.signals.fshift()

pf = vbt.Portfolio.from_signals(
    close=close,
    entries=entries,
    exits=exits,
    fees=0.001,
    freq='1h'
)

# %%
# Sharpe ratio heatmap: period on the x axis, multiplier on the y axis,
# with a slider to switch between symbols.
pf.sharpe_ratio.vbt.heatmap(
    x_level='st_period',
    y_level='st_multiplier',
    slider_level='symbol'
)

# %%
# Sharpe ratio of a portfolio built with from_holding, for comparison.
vbt.Portfolio.from_holding(close, freq='1h').sharpe_ratio

# %%