# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Stop signals
# ## Parameters

# %%
from vectorbtpro import *
import ipywidgets

# Seed handed to the shuffle that produces the "Random" exits further down.
seed = 42
# Yahoo Finance tickers that are downloaded and backtested.
symbols = [
    "BTC-USD", "ETH-USD", "XRP-USD", "BCH-USD", "LTC-USD",
    "BNB-USD", "EOS-USD", "XLM-USD", "XMR-USD", "ADA-USD"
]
# UTC bounds of the download; `time_delta` is only used in the summary below.
start_date = pd.Timestamp("2018-01-01", tz="utc")
end_date = pd.Timestamp("2021-01-01", tz="utc")
time_delta = end_date - start_date
# Length of one rolling window and the number of windows to cut.
window_len = pd.Timedelta(days=180)
window_cnt = 400
# Labels of the exit strategies; reused as the "exit_type" column level
# and as the column order of the result tables.
exit_types = ["SL", "TS", "TP", "Random", "Holding"]
# Stop values to test: from `step` up to about 1.0 in increments of `step`.
step = 0.01
stops = np.arange(step, 1 + step, step)

# Global vectorbt defaults applied to everything below: daily index
# frequency, dark plot template, starting cash, fees and slippage.
# NOTE(review): 0.0025 presumably means 0.25% per order - confirm against
# the vectorbt settings documentation.
vbt.settings.wrapping["freq"] = "d"
vbt.settings.plotting["layout"]["template"] = "vbt_dark"
vbt.settings.portfolio["init_cash"] = 100.
vbt.settings.portfolio["fees"] = 0.0025
vbt.settings.portfolio["slippage"] = 0.0025

# Overview of the experiment size (shown as the cell output).
pd.Series({
    "Start date": start_date,
    "End date": end_date,
    "Time period (days)": time_delta.days,
    "Assets": len(symbols),
    "Window length": window_len,
    "Windows": window_cnt,
    "Exit types": len(exit_types),
    "Stop values": len(stops),
    "Tests per asset": window_cnt * len(stops) * len(exit_types),
    "Tests per window": len(symbols) * len(stops) * len(exit_types),
    "Tests per exit type": len(symbols) * window_cnt * len(stops),
    "Tests per stop type and value": len(symbols) * window_cnt,
    "Tests total": len(symbols) * window_cnt * len(stops) * len(exit_types)
})

# %%
# NOTE(review): `cols` is not referenced anywhere else in the visible part
# of this file - confirm whether it is still needed.
cols = ["Open", "Low", "High", "Close", "Volume"]
# Download the history of every symbol between the configured dates.
yfdata = vbt.YFData.pull(symbols, start=start_date, end=end_date)

# %%
# Keys of the downloaded data mapping.
yfdata.data.keys()

# %%
# Shape of the BTC-USD entry.
yfdata.data["BTC-USD"].shape

# %%
yfdata.plot(symbol="BTC-USD").show()

# %%
# Regroup the download into a mapping that is indexed by column name below
# (e.g. "Open") and iterated with `.items()`.
ohlcv = yfdata.concat()

ohlcv.keys()

# %%
ohlcv['Open'].shape

# %% [markdown]
# ## Time windows

# %%
# Build a splitter that cuts the shared date index into `window_cnt`
# rolling windows, each `window_len.days` rows long.
splitter = vbt.Splitter.from_n_rolling(
    ohlcv["Open"].index,
    n=window_cnt,
    length=window_len.days
)

# Apply the same split to every OHLCV feature, keyed by feature name.
split_ohlcv = {
    feature: splitter.take(frame, into="reset_stacked")
    for feature, frame in ohlcv.items()
}
print(split_ohlcv["Open"].shape)

# %%
# The date index of each window; needed later to map windows back to dates.
split_indexes = splitter.take(ohlcv["Open"].index)
print(split_indexes)

# %%
print(split_indexes[10])

# %%
split_ohlcv["Open"].columns

# %% [markdown]
# ## Entry signals

# %%
# Signal frame shaped like the split "Open" data, with the first row of
# every column set to True: each window/symbol enters at its first bar.
entries = pd.DataFrame.vbt.signals.empty_like(split_ohlcv["Open"])
entries.iloc[0, :] = True

entries.shape

# %% [markdown]
# ## Exit signals

# %%
def _run_stop_exits(**stop_kwargs):
    """Run OHLCSTX for one kind of stop and return ``(exits, price)``.

    ``stop_kwargs`` carries the single stop argument that differs between
    the runs (``sl_stop``, ``tsl_stop`` or ``tp_stop``); every other input
    is shared. ``exits`` is a copy of the indicator's exit signals and
    ``price`` is a copy of its close prices in which the exit positions
    are overwritten with the indicator's stop price. The indicator object
    is local to the call, so it is released once the function returns.
    """
    ohlcstx = vbt.OHLCSTX.run(
        entries,
        entry_price=split_ohlcv["Close"],
        open=split_ohlcv["Open"],
        high=split_ohlcv["High"],
        low=split_ohlcv["Low"],
        close=split_ohlcv["Close"],
        stop_type=None,
        **stop_kwargs
    )
    exits = ohlcstx.exits.copy()
    price = ohlcstx.close.copy()
    price[exits] = ohlcstx.stop_price
    return exits, price


# Stop loss
sl_exits, sl_price = _run_stop_exits(sl_stop=list(stops))

sl_exits.shape

# %%
# Trailing stop
tsl_exits, tsl_price = _run_stop_exits(tsl_stop=list(stops))

tsl_exits.shape

# %%
# Take profit
tp_exits, tp_price = _run_stop_exits(tp_stop=list(stops))

tp_exits.shape

# %%
def rename_stop_level(df):
    """Return ``df`` with its stop-specific column level renamed to ``stop_value``.

    Each OHLCSTX run names its parameter level after the stop it tested;
    giving all of them one common name lets the frames be concatenated and
    grouped together later. ``strict=False`` is passed because any single
    frame carries only one of the three level names.
    """
    stop_levels = ("ohlcstx_sl_stop", "ohlcstx_tsl_stop", "ohlcstx_tp_stop")
    level_mapping = dict.fromkeys(stop_levels, "stop_value")
    return df.vbt.rename_levels(level_mapping, strict=False)

sl_exits, tsl_exits, tp_exits = map(
    rename_stop_level, (sl_exits, tsl_exits, tp_exits)
)
sl_price, tsl_price, tp_price = map(
    rename_stop_level, (sl_price, tsl_price, tp_price)
)

sl_exits.columns

# %%
# Mean of the per-column signal totals for each stop kind, before the exits
# are cleaned up (shown as the cell output).
pd.Series({
    "SL": sl_exits.vbt.signals.total().mean(),
    "TS": tsl_exits.vbt.signals.total().mean(),
    "TP": tp_exits.vbt.signals.total().mean()
}, name="avg_num_signals")

# %%
def groupby_stop_value(df):
    """Average the per-column signal totals of ``df`` over each ``stop_value``."""
    signal_totals = df.vbt.signals.total()
    return signal_totals.groupby("stop_value").mean()

# Compare how the average number of exit signals changes with the stop value.
pd.DataFrame({
    "Stop Loss": groupby_stop_value(sl_exits),
    "Trailing Stop": groupby_stop_value(tsl_exits),
    "Take Profit": groupby_stop_value(tp_exits)
}).vbt.plot(
    xaxis_title="Stop value",
    yaxis_title="Avg number of signals"
).show()

# %%
# Mark the last row of every column as an exit, so a column whose stop
# never triggered still has an exit signal.
sl_exits.iloc[-1, :] = True
tsl_exits.iloc[-1, :] = True
tp_exits.iloc[-1, :] = True

# NOTE(review): per its name, `first_after` presumably keeps only the first
# exit following each entry - confirm against the vectorbt documentation.
sl_exits = sl_exits.vbt.signals.first_after(entries)
tsl_exits = tsl_exits.vbt.signals.first_after(entries)
tp_exits = tp_exits.vbt.signals.first_after(entries)

# Same summary as before, recomputed on the cleaned-up exits.
pd.Series({
    "SL": sl_exits.vbt.signals.total().mean(),
    "TS": tsl_exits.vbt.signals.total().mean(),
    "TP": tp_exits.vbt.signals.total().mean()
}, name="avg_num_signals")

# %%
# "Holding" baseline: a signal frame shaped like the stop exits whose only
# True row is the last one, priced at the close broadcast to the same
# shape as the stop price frames.
hold_exits = pd.DataFrame.vbt.signals.empty_like(sl_exits)
hold_exits.iloc[-1, :] = True
hold_price = vbt.broadcast_to(split_ohlcv["Close"], sl_price)

hold_exits.shape

# %%
# "Random" baseline: the holding exits shuffled with a fixed seed.
# NOTE(review): presumably the shuffle moves each column's single exit to a
# random row - confirm against the vectorbt documentation.
rand_exits = hold_exits.vbt.shuffle(seed=seed)
# Same object as `hold_price` (an alias, not a copy).
rand_price = hold_price

rand_exits.shape

# %%
# Stack the exit frames of all five strategies side by side under a new
# "exit_type" column level, in the order given by `exit_types`.
exits = pd.DataFrame.vbt.concat(
    sl_exits, tsl_exits, tp_exits, rand_exits, hold_exits,
    keys=pd.Index(exit_types, name="exit_type")
)
# The individual frames are no longer needed once they are merged.
del sl_exits, tsl_exits, tp_exits, rand_exits, hold_exits

exits.shape

# %%
# Same stacking for the matching price frames.
price = pd.DataFrame.vbt.concat(
    sl_price, tsl_price, tp_price, rand_price, hold_price,
    keys=pd.Index(exit_types, name="exit_type")
)
del sl_price, tsl_price, tp_price, rand_price, hold_price

price.shape

# %%
exits.columns

# %%
# Size of the merged exit frame as reported by vectorbt.
print(exits.vbt.getsize())

# %%
print(price.vbt.getsize())

# %%
# Mean duration of the ranges between entries and exits, averaged per
# (exit type, stop value) and laid out with one column per exit type.
avg_distance = (
    entries.vbt.signals.between_ranges(target=exits)
    .duration.mean()
    .groupby(["exit_type", "stop_value"])
    .mean()
    .unstack(level="exit_type")
)

avg_distance.mean()

# %%
avg_distance[exit_types].vbt.plot(
    xaxis_title="Stop value",
    yaxis_title="Avg distance to entry"
).show()

# %% [markdown]
# ## Simulation

# %%
%%time
# Simulate all columns (every exit type, stop value, window and symbol) in
# a single call, passing the custom `price` frame as the order price.
pf = vbt.Portfolio.from_signals(
    split_ohlcv["Close"],
    entries,
    exits,
    price=price
)

# Number of order records produced by the simulation.
len(pf.orders)

# %%
# Only the total return per column is needed from here on, so the
# portfolio object itself is dropped.
total_return = pf.total_return
del pf

total_return.shape

# %%
# Alternative to simulating everything at once: build one portfolio per
# exit type, keep only its total return, and release the portfolio before
# starting the next one.

# The "exit_type" label of every column does not change between
# iterations, so it is looked up once instead of inside the loop.
exit_type_columns = exits.columns.get_level_values("exit_type")

total_returns = []
for exit_type in vbt.get_pbar(exit_types):
    # Boolean mask selecting the columns that belong to this exit type.
    chunk_mask = exit_type_columns == exit_type
    chunk_pf = vbt.Portfolio.from_signals(
        split_ohlcv["Close"],
        entries,
        exits.loc[:, chunk_mask],
        price=price.loc[:, chunk_mask]
    )
    total_returns.append(chunk_pf.total_return)

    # Drop the chunk's portfolio and call vectorbt's flush before the next
    # chunk is built.
    del chunk_pf
    vbt.flush()

# Stitch the per-exit-type results back into one series.
total_return = pd.concat(total_returns)

total_return.shape

# %% [markdown]
# ## Performance

# %%
# Pivot the "exit_type" level into columns, ordered as in `exit_types`.
return_by_type = total_return.unstack(level="exit_type")[exit_types]

# Summary statistics of the holding baseline; `percentiles=[]` is passed so
# no extra percentiles are requested.
return_by_type["Holding"].describe(percentiles=[])

# %%
# Distribution of the holding returns, drawn in the theme's purple color.
purple_color = vbt.settings["plotting"]["color_schema"]["purple"]
return_by_type["Holding"].vbt.histplot(
    xaxis_title="Total return",
    xaxis_tickformat=".2%",
    yaxis_title="Count",
    trace_kwargs=dict(marker_color=purple_color)
).show()

# %%
# Mean, median and standard deviation of the total return per exit type.
pd.DataFrame({
    "Mean": return_by_type.mean(),
    "Median": return_by_type.median(),
    "Std": return_by_type.std(),
})

# %%
# Box plot per exit type; individual points are switched off.
return_by_type.vbt.boxplot(
    trace_kwargs=dict(boxpoints=False),
    yaxis_title="Total return",
    yaxis_tickformat=".2%"
).show()

# %%
# Share of tests with a strictly positive total return, per exit type.
(return_by_type > 0).mean().rename("win_rate")

# %%
# Starting cash configured in the vectorbt settings; get_expectancy() reads
# this module-level name to scale returns.
init_cash = vbt.settings.portfolio["init_cash"]

def get_expectancy(return_by_type, level_name):
    """Compute the expectancy of every column within each group of an index level.

    For each group and column the result is
    ``win_rate * avg_win - (1 - win_rate) * abs(avg_loss)``, where
    ``win_rate`` is the share of strictly positive returns, ``avg_win`` is
    the mean positive return and ``avg_loss`` the mean negative return,
    both multiplied by the module-level ``init_cash``.

    Args:
        return_by_type: DataFrame of returns whose index contains the
            level ``level_name``.
        level_name: Name of the index level to group the rows by.

    Returns:
        DataFrame indexed by the group keys, with the same columns as
        ``return_by_type``.
    """
    # Rows are grouped by default; the explicit ``axis=0`` keyword is
    # deprecated since pandas 2.1 and removed in pandas 3.0.
    grouped = return_by_type.groupby(level_name)
    win_rate = grouped.apply(lambda x: (x > 0).mean())
    avg_win = grouped.apply(lambda x: init_cash * x[x > 0].mean())
    # A group without any winning return has a NaN mean; count it as 0.
    avg_win = avg_win.fillna(0)
    avg_loss = grouped.apply(lambda x: init_cash * x[x < 0].mean())
    # Likewise for a group without any losing return.
    avg_loss = avg_loss.fillna(0)
    return win_rate * avg_win - (1 - win_rate) * np.abs(avg_loss)

# Expectancy of every exit type for each tested stop value.
expectancy_by_stop = get_expectancy(return_by_type, "stop_value")

# Average over all stop values, per exit type (shown as the cell output).
expectancy_by_stop.mean()

# %%
expectancy_by_stop.vbt.plot(
    xaxis_title="Stop value",
    yaxis_title="Expectancy"
).show()

# %%
# Bin edges are taken from the sorted holding returns at 21 evenly spaced
# positions (rounded up to whole indices); the last edge is then dropped.
return_values = np.sort(return_by_type["Holding"].values)
idxs = np.linspace(0, len(return_values) - 1, 21)
idxs = np.ceil(idxs).astype(int)
bins = return_values[idxs][:-1]

def bin_return(return_by_type):
    """Add a ``bin_right`` index level holding each row's holding-return bin.

    The "Holding" column is cut into the module-level ``bins``
    (right-inclusive intervals) and the right edge of each row's interval
    is stacked onto the index of ``return_by_type`` as a new level.
    """
    holding_bins = pd.cut(return_by_type["Holding"], bins=bins, right=True)
    right_edges = holding_bins.apply(lambda interval: interval.right)
    bin_level = pd.Index(np.array(right_edges), name="bin_right")
    return return_by_type.vbt.stack_index(bin_level, axis=0)

binned_return_by_type = bin_return(return_by_type)

# Expectancy of every exit type as a function of the holding return's bin.
expectancy_by_bin = get_expectancy(binned_return_by_type, "bin_right")

expectancy_by_bin.vbt.plot(
    trace_kwargs=dict(mode="lines"),
    xaxis_title="Total return of holding",
    xaxis_tickformat=".2%",
    yaxis_title="Expectancy"
).show()

# %% [markdown]
# ## Bonus: Dashboard

# %%
# First and last timestamp of every window.
range_starts = pd.DatetimeIndex([window[0] for window in split_indexes])
range_ends = pd.DatetimeIndex([window[-1] for window in split_indexes])

# Per-row lookups aligned with the index of `return_by_type`: the row's
# symbol, its window number, and that window's start and end timestamps.
split_lvl = return_by_type.index.get_level_values("split")
symbol_lvl = return_by_type.index.get_level_values("symbol")
range_start_lvl = range_starts[split_lvl]
range_end_lvl = range_ends[split_lvl]

# Multi-select list of symbols; all symbols are selected initially.
asset_multi_select = ipywidgets.SelectMultiple(
    options=symbols,
    value=symbols,
    rows=len(symbols),
    description="Symbols"
)
# Range slider over the sorted unique dates of the downloaded data,
# initially spanning the first to the last date. With
# `continuous_update=False` and `readout=False` the slider shows no
# readout of its own; two separate labels display the chosen bounds.
dates = np.unique(yfdata.wrapper.index)
date_range_slider = ipywidgets.SelectionRangeSlider(
    options=dates,
    index=(0, len(dates)-1),
    orientation="horizontal",
    readout=False,
    continuous_update=False
)
range_start_label = ipywidgets.Label()
range_end_label = ipywidgets.Label()
# Metric to plot; the option strings are matched literally in
# update_scatter().
metric_dropdown = ipywidgets.Dropdown(
    options=["Mean", "Median", "Win Rate", "Expectancy"],
    value="Expectancy"
)
# Metric per stop value, one trace per exit type. The figure is shown
# through an image widget sized like the figure layout.
stop_scatter = vbt.Scatter(
    trace_names=exit_types,
    x_labels=stops,
    xaxis_title="Stop value",
    yaxis_title="Expectancy"
)
stop_scatter_img = ipywidgets.Image(
    format="png",
    width=stop_scatter.fig.layout.width,
    height=stop_scatter.fig.layout.height
)
# Metric per holding-return bin, one line per exit type.
bin_scatter = vbt.Scatter(
    trace_names=exit_types,
    x_labels=expectancy_by_bin.index,
    trace_kwargs=dict(mode="lines"),
    xaxis_title="Total return of holding",
    xaxis_tickformat="%",
    yaxis_title="Expectancy"
)
bin_scatter_img = ipywidgets.Image(
    format="png",
    width=bin_scatter.fig.layout.width,
    height=bin_scatter.fig.layout.height
)

# %%
def update_scatter(*args, **kwargs):
    """Redraw both scatter images from the current widget selections.

    Registered as the observer of the symbol list, the date range slider
    and the metric dropdown; positional and keyword arguments supplied by
    the widget callback machinery are accepted and ignored.
    """
    chosen_symbols = asset_multi_select.value
    range_from, range_to = date_range_slider.value
    metric = metric_dropdown.value

    # Keep only windows that lie fully inside the selected date range and
    # belong to one of the selected symbols.
    in_range = (range_start_lvl >= range_from) & (range_end_lvl <= range_to)
    in_symbols = symbol_lvl.isin(chosen_symbols)
    selected = return_by_type[in_range & in_symbols]
    selected_binned = bin_return(selected)

    if metric == "Mean":
        by_stop = selected.groupby("stop_value").mean()
        by_bin = selected_binned.groupby("bin_right").mean()
    elif metric == "Median":
        by_stop = selected.groupby("stop_value").median()
        by_bin = selected_binned.groupby("bin_right").median()
    elif metric == "Win Rate":
        by_stop = (selected > 0).groupby("stop_value").mean()
        by_bin = (selected_binned > 0).groupby("bin_right").mean()
    elif metric == "Expectancy":
        by_stop = get_expectancy(selected, "stop_value")
        by_bin = get_expectancy(selected_binned, "bin_right")

    # Refresh each figure and re-render it into its image widget.
    stop_scatter.fig.update_layout(yaxis_title=metric)
    stop_scatter.update(by_stop)
    stop_scatter_img.value = stop_scatter.fig.to_image(format="png")

    bin_scatter.fig.update_layout(yaxis_title=metric)
    bin_scatter.update(by_bin)
    bin_scatter_img.value = bin_scatter.fig.to_image(format="png")

    # Show the selected bounds with day precision next to the slider.
    range_start_label.value = np.datetime_as_string(
        range_from.to_datetime64(), unit="D")
    range_end_label.value = np.datetime_as_string(
        range_to.to_datetime64(), unit="D")

# Redraw whenever any control changes, and once up front for the initial state.
asset_multi_select.observe(update_scatter, names="value")
date_range_slider.observe(update_scatter, names="value")
metric_dropdown.observe(update_scatter, names="value")
update_scatter()

# %%
# Vertical layout: symbol list, then the date slider between its two
# labels, then the metric dropdown, then the two rendered plots.
dashboard = ipywidgets.VBox([
    asset_multi_select,
    ipywidgets.HBox([
        range_start_label,
        date_range_slider,
        range_end_label
    ]),
    metric_dropdown,
    stop_scatter_img,
    bin_scatter_img
])
# Last expression of the cell, so the dashboard is displayed.
dashboard

# %%