# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# # Parsers
# ## TA-Lib

# %%
from vectorbtpro import *

# List the TA-Lib indicators the indicator factory can wrap (the cell output is the listing).
vbt.IF.list_talib_indicators()

# %%
# Build an indicator class from the TA-Lib "RSI" function via the factory.
vbt.IF.from_talib('RSI')

# %%
# Shorter spelling of the same lookup; presumably equivalent to `vbt.IF.from_talib` — confirm.
vbt.talib('RSI')

# %% [markdown]
# ### Skipping NaN

# %%
# Pull hourly random prices from 2020-01-01 to 2020-06-01; seed=42 makes the
# series reproducible across runs.
# The lowercase "1h" alias matches the resampling cell below. It also avoids
# the uppercase "H" alias, which pandas >= 2.2 deprecates.
price = vbt.RandomData.pull(
    start='2020-01-01',
    end='2020-06-01',
    timeframe='1h',
    seed=42
).get()
# Work on a copy so `price` stays clean for later cells, and inject a single
# NaN at position 2 to show how the TA-Lib SMA reacts to a gap.
price_na = price.copy()
price_na.iloc[2] = np.nan

SMA = vbt.talib("SMA")
# Default run: the NaN is passed straight to TA-Lib (compare with the next cell).
sma = SMA.run(price_na, timeperiod=10)
sma.real

# %%
# Same SMA with skipna=True (this section is "Skipping NaN"); presumably the
# NaN is skipped instead of being fed to TA-Lib — compare the two outputs.
sma = SMA.run(price_na, timeperiod=10, skipna=True)
sma.real

# %% [markdown]
# ### Resampling

# %%
# Run the same SMA on three timeframes in one call. Passing a list to
# `timeframe` presumably resamples the input to each frequency before running
# TA-Lib, producing one output column per timeframe — confirm in vbt docs.
sma = SMA.run(
    price_na,
    timeperiod=10,
    skipna=True,
    timeframe=["1h", "4h", "1d"]
)
sma.real

# %% [markdown]
# ### Plotting

# %%
STOCH = vbt.talib('STOCH')
# Per-output flags reported for the wrapped TA-Lib function; presumably these
# drive how `plot` draws each output — confirm.
STOCH.output_flags

# %%
# Aggregate the hourly prices into daily bars; pandas' resample().ohlc()
# yields open/high/low/close columns.
ohlc = price.resample('1d').ohlc()
stoch = STOCH.run(ohlc['high'], ohlc['low'], ohlc['close'])
stoch.plot().show()

# %%
# Show the signature/docstring of the auto-generated plot method.
vbt.phelp(STOCH.plot)

# %%
# Two stacked panels sharing the x-axis: candlesticks (row 1) above STOCH (row 2).
fig = vbt.make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.05)
ohlc.vbt.ohlcv.plot(
    add_trace_kwargs=dict(row=1, col=1),
    fig=fig,
    # Hide the x-axis range slider under the candlestick panel.
    xaxis=dict(rangeslider_visible=False))
stoch.plot(
    # Presumably draws reference levels at 20 and 80 — confirm in vbt docs.
    limits=(20, 80),
    add_trace_kwargs=dict(row=2, col=1),
    # dash=None presumably overrides a default dash style so both lines render solid.
    slowk_trace_kwargs=dict(line=dict(dash=None)),
    slowd_trace_kwargs=dict(line=dict(dash=None)),
    fig=fig)
fig.show()

# %% [markdown]
# ## Pandas TA

# %%
# List the Pandas TA indicators the indicator factory can wrap.
vbt.IF.list_pandas_ta_indicators()

# %%
# Build an indicator class from the Pandas TA "RSI" indicator.
vbt.IF.from_pandas_ta('RSI')

# %%
# Shorter spelling; presumably equivalent to `vbt.IF.from_pandas_ta` — confirm.
vbt.pandas_ta('RSI')

# %% [markdown]
# ## TA

# %%
# List the `ta` library indicators the indicator factory can wrap.
vbt.IF.list_ta_indicators()

# %%
# Build an indicator class from the `ta` library's RSIIndicator class.
vbt.IF.from_ta('RSIIndicator')

# %%
# Shorter spelling; presumably equivalent to `vbt.IF.from_ta` — confirm.
vbt.ta('RSIIndicator')

# %% [markdown]
# ## Expressions
# ### Instance method

# %%
# True range and ATR written as an expression: the true range is the row-wise
# max of three distances, and the ATR smooths it with wwm_mean_1d over `n`.
# The last line of the expression lists the values returned as outputs.
expr = """
tr0 = abs(high - low)
tr1 = abs(high - fshift(close))
tr2 = abs(low - fshift(close))
tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
atr = wwm_mean_1d(tr, n)
tr, atr
"""
# Instance-method form: the factory is configured explicitly (class, inputs,
# parameter, outputs), so the expression can use plain names. `n=14` is
# presumably the default value of the `n` parameter.
ATR = vbt.IF(
    class_name='ATR',
    input_names=['high', 'low', 'close'],
    param_names=['n'],
    output_names=['tr', 'atr']
).from_expr(expr, n=14)

atr = ATR.run(ohlc['high'], ohlc['low'], ohlc['close'])
atr.atr

# %% [markdown]
# ### Class method

# %%
# Class-method form: the "ATR:" first line names the class, and the
# @in_/@out_/@p_ prefixes mark inputs, outputs and parameters explicitly, so
# `from_expr` needs no factory configuration.
expr = """
ATR:
tr0 = abs(@in_high - @in_low)
tr1 = abs(@in_high - fshift(@in_close))
tr2 = abs(@in_low - fshift(@in_close))
@out_tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
@out_atr = wwm_mean_1d(@out_tr, @p_n)
@out_tr, @out_atr
"""
ATR = vbt.IF.from_expr(expr, n=14)

# %%
ATR.input_names

# %%
# Same indicator with the @in_/@out_ prefixes dropped; only the parameter keeps
# its @p_ prefix. The cells below display which input and output names the
# factory picked up without the prefixes.
expr = """
ATR:
tr0 = abs(high - low)
tr1 = abs(high - fshift(close))
tr2 = abs(low - fshift(close))
tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
atr = wwm_mean_1d(tr, @p_n)
tr, atr
"""
ATR = vbt.IF.from_expr(expr, n=14)
ATR.input_names

# %%
ATR.output_names

# %%
# Same ATR expression, calling the Numba kernels by their full path
# (vbt.nb.fshift_nb, vbt.nb.wwm_mean_1d_nb) and NumPy via `np.` instead of the
# short aliases. As in the prefix-free cell above, `n` keeps its @p_ prefix.
# The "ATR:" header form has no explicit param_names, so the prefix is what
# marks `n` as a parameter. This cell only defines the expression; build it
# with `vbt.IF.from_expr(expr, n=14)` like the cells above.
expr = """
ATR:
tr0 = abs(high - low)
tr1 = abs(high - vbt.nb.fshift_nb(close))
tr2 = abs(low - vbt.nb.fshift_nb(close))
tr = np.nanmax(np.column_stack((tr0, tr1, tr2)), axis=1)
atr = vbt.nb.wwm_mean_1d_nb(tr, @p_n)
tr, atr
"""

# %% [markdown]
# ### TA-Lib

# %%
# TA-Lib functions can be called inside an expression via the @talib_ prefix.
# An EMA with timeperiod 2n-1 has alpha = 2 / ((2n-1) + 1) = 1/n, which is
# Wilder's smoothing with period n. `n` keeps its @p_ prefix so it is
# registered as a parameter, as in the other "ATR:" header expressions.
expr = """
ATR:
tr0 = abs(high - low)
tr1 = abs(high - fshift(close))
tr2 = abs(low - fshift(close))
tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
atr = @talib_ema(tr, 2 * @p_n - 1)  # Wilder's EMA
tr, atr
"""

# %% [markdown]
# ### Context

# %%
# Extra keyword arguments to `from_expr` that are not parameters are usable by
# name inside the expression. Here `shift_close` is bound to vbt.nb.fshift_nb.
expr = """
ATR:
tr0 = abs(high - low)
tr1 = abs(high - shift_close(close))
tr2 = abs(low - shift_close(close))
tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
atr = wwm_mean_1d(tr, @p_n)
tr, atr
"""
ATR = vbt.IF.from_expr(expr, n=14, shift_close=vbt.nb.fshift_nb)

# %%
def shift_close(close, context):
    """Shift `close` with vbt.nb.fshift_nb by the amount stored in `context`.

    The shift is read from context['shift'] and falls back to 1 when missing.
    `context` is presumably injected by the expression engine because of the
    parameter's name — confirm in vbt docs.
    """
    n_bars = context.get('shift', 1)
    return vbt.nb.fshift_nb(close, n_bars)

# Same ATR, with `shift_close` bound to the Python function defined above. The
# extra `shift=2` keyword is presumably placed in the context that the
# function reads via context.get('shift', 1).
expr = """
ATR:
tr0 = abs(high - low)
tr1 = abs(high - shift_close(close))
tr2 = abs(low - shift_close(close))
tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
atr = wwm_mean_1d(tr, @p_n)
tr, atr
"""
ATR = vbt.IF.from_expr(expr, n=14, shift_close=shift_close, shift=2)

# %%
def shift_close(context):
    """Return context['close'] shifted with vbt.nb.fshift_nb.

    The shift amount comes from context['shift'] and defaults to 1.
    """
    close = context['close']
    n_bars = context.get('shift', 1)
    return vbt.nb.fshift_nb(close, n_bars)

# `shift_close()` is now called with no arguments; the function pulls `close`
# from the context itself.
# NOTE(review): this expression never names `close`, so nothing here marks it
# as an input; the func_mapping cell below addresses that via magnet_inputs.
expr = """
ATR:
tr0 = abs(high - low)
tr1 = abs(high - shift_close())
tr2 = abs(low - shift_close())
tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
atr = wwm_mean_1d(tr, @p_n)
tr, atr
"""
ATR = vbt.IF.from_expr(expr, n=14, shift_close=shift_close)

# %%
# Register the function through func_mapping. magnet_inputs=['close'] presumably
# tells the factory the function needs the `close` input even though the
# expression never names it — confirm in vbt docs.
func_mapping = dict(
    shift_close=dict(
        func=shift_close,
        magnet_inputs=['close']
    )
)
ATR = vbt.IF.from_expr(expr, n=14, func_mapping=func_mapping)

# %%
# `shifted_close` appears as a plain name rather than a call. res_func_mapping
# presumably resolves that name by calling the mapped function.
expr = """
ATR:
tr0 = abs(high - low)
tr1 = abs(high - shifted_close)
tr2 = abs(low - shifted_close)
tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
atr = wwm_mean_1d(tr, @p_n)
tr, atr
"""
res_func_mapping = dict(
    shifted_close=dict(
        func=shift_close,
        magnet_inputs=['close']
    )
)
ATR = vbt.IF.from_expr(expr, n=14, res_func_mapping=res_func_mapping)

# %% [markdown]
# ### Settings

# %%
# The expression carries its own factory configuration in @settings(...), so
# `from_expr` is called without extra arguments. Because param_names=['n'] is
# declared there, `n` is used without the @p_ prefix.
expr = """
@settings(dict(
    factory_kwargs=dict(
        class_name='ATR',
        input_names=['high', 'low', 'close'],
        param_names=['n'],
        output_names=['tr', 'atr']
    ),
    n=14
))

tr0 = abs(high - low)
tr1 = abs(high - fshift(close))
tr2 = abs(low - fshift(close))
tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
atr = wwm_mean_1d(tr, n)
tr, atr
"""
ATR = vbt.IF.from_expr(expr)

# %% [markdown]
# ### Stacking

# %%
# Stack another indicator into the expression. @res_talib_atr presumably
# resolves to TA-Lib's ATR run on the same inputs. Its arguments are passed
# with an `atr_` prefix (atr_timeperiod=10). "[st]" in the header is presumably
# a short name — confirm.
expr = """
SuperTrend[st]:
avg_price = (high + low) / 2
up = avg_price + @p_mult * @res_talib_atr
down = avg_price - @p_mult * @res_talib_atr
up, down
"""
SuperTrend = vbt.IF.from_expr(expr, mult=3, atr_timeperiod=10)

# Inspect the inputs, parameters and outputs the stacked indicator exposes.
SuperTrend.input_names

# %%
SuperTrend.param_names

# %%
SuperTrend.output_names

# %%
st = SuperTrend.run(ohlc['high'], ohlc['low'], ohlc['close'])
st.up

# %%
# With atr_kwargs=dict(return_raw=False), the resolved ATR is presumably an
# indicator instance rather than a raw array. That is why the expression reads
# `.real.values` from it.
expr = """
SuperTrend[st]:
avg_price = (high + low) / 2
up = avg_price + @p_mult * @res_talib_atr.real.values
down = avg_price - @p_mult * @res_talib_atr.real.values
up, down
"""
SuperTrend = vbt.IF.from_expr(
    expr,
    mult=3,
    atr_timeperiod=10,
    atr_kwargs=dict(return_raw=False))

# %% [markdown]
# ### One-liners

# %%
# One-liner form: `@out_avg_price:<expr>` names the output inline.
AvgPrice = vbt.IF.from_expr("AvgPrice: @out_avg_price:(high + low) / 2")

AvgPrice.run(ohlc['high'], ohlc['low']).avg_price

# %%
# The two adjacent string literals concatenate into one expression with two
# comma-separated outputs. Passing `avg_price=AvgPrice` presumably lets
# @res_avg_price resolve to that indicator's output — confirm.
SuperTrend = vbt.IF.from_expr(
    "SuperTrend[st]: @out_up:@res_avg_price + @p_mult * @res_talib_atr, "
    "@out_down:@res_avg_price - @p_mult * @res_talib_atr",
    avg_price=AvgPrice,
    atr_timeperiod=10,
    mult=3)
st = SuperTrend.run(ohlc['high'], ohlc['low'], ohlc['close'])

# Overlay the upper and lower bands on the daily candlestick chart.
fig = ohlc.vbt.ohlcv.plot()
st.up.rename('Upper').vbt.plot(fig=fig)
st.down.rename('Lower').vbt.plot(fig=fig)
fig.show()

# %% [markdown]
# ### Using Pandas

# %%
# keep_pd=True presumably passes the inputs as pandas objects, which is what
# makes Series.shift/ewm and pd.concat usable below. ewm(alpha=1/n,
# adjust=False) is Wilder's smoothing. min_periods=n withholds values until at
# least n observations are available.
expr = """
ATR:
tr0 = abs(high - low)
tr1 = abs(high - close.shift())
tr2 = abs(low - close.shift())
tr = pd.concat((tr0, tr1, tr2), axis=1).max(axis=1)
atr = tr.ewm(alpha=1 / @p_n, adjust=False, min_periods=@p_n).mean()
tr, atr
"""
ATR = vbt.IF.from_expr(expr, n=14, keep_pd=True)
atr = ATR.run(ohlc['high'], ohlc['low'], ohlc['close'])
atr.atr

# %%
# use_pd_eval=True presumably evaluates the one-liner with pandas.eval — confirm.
AvgPrice = vbt.IF.from_expr(
    "AvgPrice: @out_avg_price:(high + low) / 2",
    use_pd_eval=True
)

AvgPrice.run(ohlc['high'], ohlc['low']).avg_price

# %% [markdown]
# ### Debugging

# %%
# return_clean_expr=True presumably makes `from_expr` return the expression
# after its @-prefixes are processed (printed here), instead of an indicator
# class — confirm in vbt docs.
expr = """
ATR:
tr0 = abs(@in_high - @in_low)
tr1 = abs(@in_high - fshift(@in_close))
tr2 = abs(@in_low - fshift(@in_close))
@out_tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
@out_atr = wwm_mean_1d(@out_tr, @p_n)
@out_tr, @out_atr
"""
print(vbt.IF.from_expr(expr, n=14, return_clean_expr=True))

# %%
# print() calls embedded in the expression execute when the indicator runs,
# printing the shape of every intermediate array.
expr = """
ATR:
tr0 = abs(@in_high - @in_low)
print('tr0: ', tr0.shape)
tr1 = abs(@in_high - fshift(@in_close))
print('tr1: ', tr1.shape)
tr2 = abs(@in_low - fshift(@in_close))
print('tr2: ', tr2.shape)
@out_tr = nanmax(column_stack((tr0, tr1, tr2)), axis=1)
print('tr: ', @out_tr.shape)
@out_atr = wwm_mean_1d(@out_tr, @p_n)
print('atr: ', @out_atr.shape)
@out_tr, @out_atr
"""
ATR = vbt.IF.from_expr(expr, n=14)
atr = ATR.run(ohlc['high'], ohlc['low'], ohlc['close'])

# %% [markdown]
# ## WorldQuant's Alphas

# %%
# Build alpha #53 from the built-in WorldQuant "101 Formulaic Alphas" set.
WQA53 = vbt.IF.from_wqa101(53)
wqa53 = WQA53.run(ohlc['open'], ohlc['high'], ohlc['low'], ohlc['close'])
wqa53.out

# %%
# Shorter spelling; presumably equivalent to `vbt.IF.from_wqa101` — confirm.
vbt.wqa101(53)

# %%
# The same alpha written directly as an expression (no class name or prefixes).
# NOTE(review): the expression only references close, low and high, yet four
# positional arrays (open first) are passed to run(); confirm that
# WQA53.input_names matches this argument order.
WQA53 = vbt.IF.from_expr("-delta(((close - low) - (high - close)) / (close - low), 9)")
wqa53 = WQA53.run(ohlc['open'], ohlc['high'], ohlc['low'], ohlc['close'])
wqa53.out

# %%