# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Projections

# %%
# Search `price` for 7-bar windows resembling the two-point pattern [1, 1.2]
# (presumably a ~20% rise; rescaling/error semantics are vbt's — confirm in docs).
# NOTE(review): `price` is defined outside this file/chunk.
pattern_ranges = price.vbt.find_pattern(
    [1, 1.2],
    window=7,
    rescale_mode="rebase",
    max_error=0.01,
    max_error_interp_mode="discrete",
    max_error_strict=True
)
# Number of matched ranges (presumably per column).
pattern_ranges.count()

# %% [markdown]
# ## Pattern projections

# %%
# Low-level route: feed the 2D price array and the per-range col / start_idx /
# end_idx / status fields to the Numba function, which returns the indices of
# the mapped ranges and the raw projection array.
range_idxs, raw_projections = vbt.nb.map_ranges_to_projections_nb(
    vbt.to_2d_array(price),
    pattern_ranges.get_field_arr("col"),
    pattern_ranges.get_field_arr("start_idx"),
    pattern_ranges.get_field_arr("end_idx"),
    pattern_ranges.get_field_arr("status")
)
range_idxs

# %%
raw_projections

# %%
# High-level route: same projections as a labeled DataFrame.
projections = pattern_ranges.get_projections()
projections

# %%
# Duration of each range, for comparison with the projection length.
pattern_ranges.duration.values

# %%
# Same projections without the bar at end_idx (assumed meaning of incl_end_idx=False).
projections = pattern_ranges.get_projections(incl_end_idx=False)
projections

# %%
# Total return of each projection: last row / first row - 1.
projections.iloc[-1] / projections.iloc[0] - 1

# %%
projections.vbt.plot().show()

# %% [markdown]
# ## Delta projections

# %%
# Derive ranges of 4 bars relative to each matched range
# (assumed: the 4 bars that follow each match — confirm `with_delta` semantics).
delta_ranges = pattern_ranges.with_delta(4)

# %%
# Overlay the pattern ranges and the delta ranges for January-March 2021.
fig = pattern_ranges.loc["2021-01":"2021-03"].plot()
delta_ranges.loc["2021-01":"2021-03"].plot(
    plot_ohlc=False,
    plot_close=False,
    plot_markers=False,
    closed_shape_kwargs=dict(fillcolor="DeepSkyBlue"),
    fig=fig
)
fig.show()

# %%
projections = delta_ranges.get_projections()
projections

# %%
# Mean total return (last row / first row - 1) across the delta projections.
np.mean(projections.iloc[-1] / projections.iloc[0] - 1)

# %%
# Repeat the search on `mult_price` (presumably several symbols) with windows
# from 7 up to 30 bars and overlapping matches allowed.
pattern_ranges = mult_price.vbt.find_pattern(
    [1, 1.2],
    window=7,
    max_window=30,
    rescale_mode="rebase",
    max_error=0.01,
    max_error_interp_mode="discrete",
    max_error_strict=True,
    overlap_mode="allow"
)
pattern_ranges.count()

# %%
delta_ranges = pattern_ranges.with_delta(4)
projections = delta_ranges.get_projections()
# Summary statistics of each projection's total return.
(projections.iloc[-1] / projections.iloc[0] - 1).describe()

# %%
# Label projection columns by each range's end index (id_level="end_idx").
projections = delta_ranges.get_projections(id_level="end_idx")
projections.columns

# %%
# Total projected return of each BTCUSDT projection, scattered and colored on
# a diverging scale centered at 0 (cmid=0).
btc_projections = projections.xs("BTCUSDT", level="symbol", axis=1)
total_proj_return = btc_projections.iloc[-1] / btc_projections.iloc[0] - 1
total_proj_return.vbt.scatterplot(
    trace_kwargs=dict(
        marker=dict(
            color=total_proj_return.values,
            colorscale="Temps_r",
            cmid=0
        )
    )
).show()

# %% [markdown]
# ## Plotting

# %%
# Plot every BTCUSDT projection, without bands.
btc_projections.vbt.plot_projections(plot_bands=False).show()

# %% [markdown]
# ### Colorization

# %%
# Select the projection column(s) labeled 2020-08-03 (columns are end indices,
# so several projections may share a label when matches overlap).
btc_projections["2020-08-03"]

# %%
btc_projections["2020-08-03"].median()

# %%
# Color each projection using `np.std` as the statistic
# (assumed: applied per projection — confirm `colorize` semantics).
btc_projections.vbt.plot_projections(
    plot_bands=False, colorize=np.std
).show()

# %% [markdown]
# ### Bands

# %%
# Median path of the ETHUSDT projections (median across columns at each row).
projections.xs("ETHUSDT", level="symbol", axis=1).median(axis=1)

# %%
# Median path per symbol. `groupby(..., axis=1)` is deprecated since pandas 2.1
# (removed in 3.0), so group the transposed frame by the "symbol" column level
# and transpose back: rows are unchanged, columns become the symbols.
projections.T.groupby(level="symbol").median().T

# %%
# Median path across all projections of all symbols.
projections.median(axis=1)

# %%
# Default projection plot (bands not disabled).
btc_projections.vbt.plot_projections().show()

# %%
# 80th percentile of the final projected values.
btc_projections.iloc[-1].quantile(0.8)

# %%
# Only a middle band, specified as the string "30%" (semantics defined by vbt).
btc_projections.vbt.plot_projections(
    plot_lower=False,
    plot_middle="30%",
    plot_upper=False,
    plot_aux_middle=False,
).show()

# %%
# Q-Q plot of the final projected values.
btc_projections.iloc[-1].vbt.qqplot().show()

# %%
# Bands given as "P=..." strings (presumably percentiles across projections —
# confirm against vbt docs).
btc_projections.vbt.plot_projections(
    plot_lower="P=20%",
    plot_middle="P=50%",
    plot_upper="P=80%",
    plot_aux_middle=False,
).show()

# %%
def finishes_at_quantile(df, q):
    """Return the column of *df* whose final value sits at rank quantile *q*.

    Columns are ranked by their last-row value (NaN final values are
    ignored), and the column at rank ``ceil(q * (n - 1))`` is returned,
    where ``n`` is the number of rankable columns. Meant to be passed
    (e.g. via ``functools.partial``) as a band callable to
    ``plot_projections`` so the band is an actual projection.

    Args:
        df: DataFrame with one projection per column.
        q: Quantile in [0, 1].

    Returns:
        The selected column as a Series.

    Raises:
        ValueError: If no column has a non-NaN final value.
    """
    # Rank on the raw array: argsort on a Series yields a Series keyed by the
    # column labels, and indexing that with an int is a *label* lookup (wrong
    # for integer labels, a deprecated positional fallback otherwise).
    last_values = df.iloc[-1].to_numpy(dtype=float)
    valid_positions = np.flatnonzero(~np.isnan(last_values))
    if valid_positions.size == 0:
        raise ValueError("No projection has a final value to rank")
    ranked_positions = valid_positions[
        np.argsort(last_values[valid_positions], kind="stable")
    ]
    # Round before ceil so float noise (0.3 * 10 == 3.0000000000000004)
    # does not push the rank up by one.
    nth_element = int(np.ceil(round(q * (ranked_positions.size - 1), 9)))
    return df.iloc[:, ranked_positions[nth_element]]

# Use actual projections as the lower/upper bands: the ones whose final value
# ranks at the 20% and 80% quantile (via finishes_at_quantile); no middle band.
btc_projections.vbt.plot_projections(
    plot_lower=partial(finishes_at_quantile, q=0.2),
    plot_middle=False,
    plot_upper=partial(finishes_at_quantile, q=0.8),
).show()

# %% [markdown]
# ## Filtering

# %%
# Keep projections whose running maximum over the first two rows reaches 1.05.
# NOTE(review): assumes projections start at 1, so this means the first
# projected bar is at least 5% up — confirm the default start value.
crossed_mask = projections.expanding().max().iloc[1] >= 1.05
filt_projections = projections.loc[:, crossed_mask]
# Distribution of the final values of the filtered projections.
filt_projections.iloc[-1].describe()

# %%
# Plot directly: `filt_projections` already holds only the columns selected by
# `crossed_mask`. Indexing it with the mask again would force pandas to realign
# a mask built on the unfiltered columns, which fails if column labels repeat.
filt_projections.vbt.plot_projections().show()

# %% [markdown]
# ## Latest projections

# %%
# Use the last 7 closes of `data` as the pattern and search `price` for similar
# windows; overlapping matches are allowed.
# NOTE(review): `data` and `price` are defined outside this chunk.
pattern_ranges = price.vbt.find_pattern(
    pattern=data.close.iloc[-7:],
    rescale_mode="rebase",
    overlap_mode="allow"
)
pattern_ranges.count()

# %%
# Keep only closed ranges (open ones presumably include the window that ends on
# the last bar, i.e. the pattern itself).
pattern_ranges = pattern_ranges.status_closed
pattern_ranges.count()

# %%
projections = pattern_ranges.get_projections()
projections.vbt.plot_projections(plot_bands=False).show()

# %%
# Project 7 bars relative to each match (presumably the bars that follow) and
# draw them on top of the last 7 bars of `data`. start_value=-1 presumably
# anchors each projection at the latest value — confirm against vbt docs.
delta_ranges = pattern_ranges.with_delta(7)
projections = delta_ranges.get_projections(start_value=-1)
fig = data.iloc[-7:].plot(plot_volume=False)
projections.vbt.plot_projections(fig=fig)
fig.show()

# %%
# Mean projected path across all matches.
projections.mean(axis=1)

# %%
# Pull actual BTCUSDT data for a fixed later period (network access required)
# to compare against the projections.
# NOTE(review): dates are hard-coded; presumably `data` ends around 2022-05-31.
next_data = vbt.BinanceData.pull(
    "BTCUSDT",
    start="2022-05-31",
    end="2022-06-08"
)
next_data.close

# %% [markdown]
# ### Quick plotting

# %%
# Shortcut: plot projections straight from the ranges object.
delta_ranges.plot_projections().show()

# %% [markdown]
# ## Non-uniform projections

# %%
# SMA crossover signals for every window pair drawn from 10..30 whose windows
# differ by at least 5. `combinations` over the ascending `windows` array always
# yields (smaller, larger), so the first element goes to the fast SMA.
windows = np.arange(10, 31)
window_tuples = combinations(windows, 2)
window_tuples = filter(lambda x: abs(x[0] - x[1]) >= 5, window_tuples)
fast_windows, slow_windows = zip(*window_tuples)
fast_sma = data.run("sma", fast_windows, short_name="fast_sma")
slow_sma = data.run("sma", slow_windows, short_name="slow_sma")
# Entries: fast SMA crosses above slow; exits: fast SMA crosses below slow.
entries = fast_sma.real_crossed_above(slow_sma.real)
exits = fast_sma.real_crossed_below(slow_sma.real)

entries.shape

# %%
# Ranges of 30 bars starting at each entry signal (assumed `delta_ranges`
# semantics); only closed ranges are kept.
entry_ranges = entries.vbt.signals.delta_ranges(30, close=data.close)
entry_ranges = entry_ranges.status_closed
# Total number of ranges across all parameter columns.
entry_ranges.count().sum()

# %%
# Same for exit signals.
exit_ranges = exits.vbt.signals.delta_ranges(30, close=data.close)
exit_ranges = exit_ranges.status_closed
exit_ranges.count().sum()

# %%
entry_projections = entry_ranges.get_projections()
entry_projections.shape

# %%
exit_projections = exit_ranges.get_projections()
exit_projections.shape

# %%
# Compare bands only: entry bands in green, exit bands in orangered, drawn on
# the same figure; individual projections and fills are hidden.
fig = entry_projections.vbt.plot_projections(
    plot_projections=False,
    plot_aux_middle=False,
    plot_fill=False,
    lower_trace_kwargs=dict(name="Lower (entry)", line_color="green"),
    middle_trace_kwargs=dict(name="Middle (entry)", line_color="green"),
    upper_trace_kwargs=dict(name="Upper (entry)", line_color="green"),
)
fig = exit_projections.vbt.plot_projections(
    plot_projections=False,
    plot_aux_middle=False,
    plot_fill=False,
    lower_trace_kwargs=dict(name="Lower (exit)", line_color="orangered"),
    middle_trace_kwargs=dict(name="Middle (exit)", line_color="orangered"),
    upper_trace_kwargs=dict(name="Upper (exit)", line_color="orangered"),
    fig=fig
)
fig.show()

# %%
# Variable-length ranges between each entry and the following exit
# (assumed `between_ranges` semantics); only closed ranges are kept.
entry_ranges = entries.vbt.signals.between_ranges(exits, close=data.close)
entry_ranges = entry_ranges.status_closed
entry_ranges.count().sum()

# %%
# And between each exit and the following entry.
exit_ranges = exits.vbt.signals.between_ranges(entries, close=data.close)
exit_ranges = exit_ranges.status_closed
exit_ranges.count().sum()

# %%
entry_projections = entry_ranges.get_projections()
entry_projections.shape

# %%
exit_projections = exit_ranges.get_projections()
exit_projections.shape

# %%
# Plot a random sample of up to 100 distinct entry projections. Sampling
# without replacement avoids plotting the same projection twice; capping the
# size keeps `choice` valid when there are fewer than 100 projections.
n_cols = entry_projections.shape[1]
rand_cols = np.random.choice(n_cols, min(100, n_cols), replace=False)
entry_projections.iloc[:, rand_cols].vbt.plot_projections(plot_bands=False).show()

# %%
# Same for exit projections.
n_cols = exit_projections.shape[1]
rand_cols = np.random.choice(n_cols, min(100, n_cols), replace=False)
exit_projections.iloc[:, rand_cols].vbt.plot_projections(plot_bands=False).show()

# %% [markdown]
# ### Shrinking

# %%
# Projections with proj_period="30d" (presumably each projection is limited to
# 30 days, so longer ranges are shrunk — confirm against vbt docs).
entry_projections = entry_ranges.get_projections(proj_period="30d")
entry_projections.shape

# %%
exit_projections = exit_ranges.get_projections(proj_period="30d")
exit_projections.shape

# %%
# Plot a random sample of up to 100 distinct entry projections. Sampling
# without replacement avoids plotting the same projection twice; capping the
# size keeps `choice` valid when there are fewer than 100 projections.
n_cols = entry_projections.shape[1]
rand_cols = np.random.choice(n_cols, min(100, n_cols), replace=False)
entry_projections.iloc[:, rand_cols].vbt.plot_projections().show()

# %%
# Same for exit projections.
n_cols = exit_projections.shape[1]
rand_cols = np.random.choice(n_cols, min(100, n_cols), replace=False)
exit_projections.iloc[:, rand_cols].vbt.plot_projections().show()

# %% [markdown]
# ### Stretching

# %%
# As above, but with extend=True (presumably shorter ranges are stretched to the
# full 30-day period instead of being left short — confirm against vbt docs).
entry_projections = entry_ranges.get_projections(
    proj_period="30d", extend=True
)
entry_projections.shape

# %%
exit_projections = exit_ranges.get_projections(
    proj_period="30d", extend=True
)
exit_projections.shape

# %%
# Plot a random sample of up to 100 distinct entry projections. Sampling
# without replacement avoids plotting the same projection twice; capping the
# size keeps `choice` valid when there are fewer than 100 projections.
n_cols = entry_projections.shape[1]
rand_cols = np.random.choice(n_cols, min(100, n_cols), replace=False)
entry_projections.iloc[:, rand_cols].vbt.plot_projections().show()

# %%
# Same for exit projections.
n_cols = exit_projections.shape[1]
rand_cols = np.random.choice(n_cols, min(100, n_cols), replace=False)
exit_projections.iloc[:, rand_cols].vbt.plot_projections().show()

# %% [markdown]
# ### Quick plotting

# %%
# Parameter columns available on the entry ranges (presumably
# (fast_window, slow_window) tuples).
entry_ranges.wrapper.columns

# %%
# Quick plot of the last 10 projections of the (25, 30) column, 30-day
# stretched projections, middle band only, with semi-transparent traces.
entry_ranges.plot_projections(
    column=(25, 30),
    last_n=10,
    proj_period="30d",
    extend=True,
    plot_lower=False,
    plot_upper=False,
    plot_aux_middle=False,
    projection_trace_kwargs=dict(opacity=0.3)
).show()

# %% [markdown]
# ## Open projections

# %%
# Exit-to-entry ranges, this time keeping ranges that are still open
# (incl_open=True).
exit_ranges = exits.vbt.signals.between_ranges(
    entries,
    incl_open=True,
    close=data.close
)
exit_ranges.count().sum()

# %%
# Column labels of the open ranges.
exit_ranges.wrapper.columns[exit_ranges.status_open.col_arr]

# %%
# Projections for column (20, 30) from closed ranges only.
exit_ranges.status_closed.plot_projections(
    column=(20, 30), plot_bands=False
).show()

# %%
# Same column, with the open range included.
exit_ranges.plot_projections(
    column=(20, 30), plot_bands=False
).show()

# %%
# Draw the closed projections of column (20, 30) starting from its last exit
# signal, over the close price from 10 days before that signal onward.
column = (20, 30)
# Timestamp of the last exit signal in this column (IndexError if there is none).
signal_index = data.wrapper.index[np.flatnonzero(exits[column])[-1]]
plot_start_index = signal_index - pd.Timedelta(days=10)
sub_close = data.close[plot_start_index:]
sub_exits = exits.loc[plot_start_index:, column]

fig = sub_close.vbt.plot()
sub_exits.vbt.signals.plot_as_exits(sub_close, fig=fig)
# Anchor every projection at the close value and timestamp of that last signal.
projections = exit_ranges[column].status_closed.get_projections(
    start_value=sub_close.loc[signal_index],
    start_index=signal_index
)
projections.vbt.plot_projections(plot_bands=False, fig=fig)
fig.show()

# %%