# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Patterns

# %%
from vectorbtpro import *

# Download BTCUSDT candles from Binance for the two-year span used throughout.
data = vbt.BinanceData.pull(
    "BTCUSDT",
    start="2020-06-01 UTC",
    end="2022-06-01 UTC"
)
data.plot().show()

# %%
# Zoom into a two-month slice that the hand-made pattern is fitted to below.
data_window = data.loc["2021-09-25":"2021-11-25"]
data_window.plot(plot_volume=False).show()

# %%
# Use the window's HLC3 price (presumably (high + low + close) / 3) as the
# series that patterns are compared against.
price_window = data_window.hlc3
price_window.vbt.plot().show()

# %%
# Six-point template: rise to a peak, dip, a second equal peak, then a drop.
# This is a double-top shape; the same array is named "double_top" later on.
pattern = np.array([1, 2, 3, 2, 3, 2])
pd.Series(pattern).vbt.plot().show()

# %% [markdown]
# ## Interpolation
# ### Linear

# %%
# Stretch the 6-point pattern to 10 points with linear interpolation; the
# trailing bare name makes the notebook display the result.
resized_pattern = vbt.nb.interp_resize_1d_nb(
    pattern, 10, vbt.enums.InterpMode.Linear
)
resized_pattern

# %%
def plot_linear(n):
    """Plot the global ``pattern`` resized to ``n`` points (linear mode).

    Returns the object produced by the ``.vbt.plot()`` accessor; callers
    invoke ``.show()`` on it.
    """
    mode = vbt.enums.InterpMode.Linear
    series = pd.Series(vbt.nb.interp_resize_1d_nb(pattern, n, mode))
    return series.vbt.plot()

# %%
plot_linear(7).show()

# %%
plot_linear(11).show()

# %%
plot_linear(30).show()

# %%
# Show where the resized samples fall on the original pattern. With both
# endpoints aligned, resized sample i sits at x = i * (len(pattern) - 1) /
# (len(resized) - 1), so scatter them on top of the original line.
resized_pattern = vbt.nb.interp_resize_1d_nb(
    pattern, 7, vbt.enums.InterpMode.Linear
)
ratio = (len(pattern) - 1) / (len(resized_pattern) - 1)
new_points = np.arange(len(resized_pattern)) * ratio
fig = pd.Series(pattern).vbt.plot()
pd.Series(resized_pattern, index=new_points).vbt.scatterplot(fig=fig)
fig.show()

# %% [markdown]
# ### Nearest

# %%
# Same 10-point resize, but with nearest-neighbour interpolation.
resized_pattern = vbt.nb.interp_resize_1d_nb(
    pattern, 10, vbt.enums.InterpMode.Nearest
)
resized_pattern

# %%
def plot_nearest(n):
    """Plot the global ``pattern`` resized to ``n`` points (nearest mode).

    Returns the object produced by the ``.vbt.plot()`` accessor; callers
    invoke ``.show()`` on it.
    """
    mode = vbt.enums.InterpMode.Nearest
    series = pd.Series(vbt.nb.interp_resize_1d_nb(pattern, n, mode))
    return series.vbt.plot()

# %%
plot_nearest(7).show()

# %%
plot_nearest(11).show()

# %%
plot_nearest(30).show()

# %% [markdown]
# ### Discrete

# %%
# 10-point resize in discrete mode. Intermediate points are presumably left
# empty (NaN) rather than filled in; inspect the displayed array to confirm.
resized_pattern = vbt.nb.interp_resize_1d_nb(
    pattern, 10, vbt.enums.InterpMode.Discrete
)
resized_pattern

# %%
def plot_discrete(n):
    """Plot the global ``pattern`` resized to ``n`` points (discrete mode).

    The trace is drawn dotted with ``connectgaps=True``, so any empty
    samples left by the discrete resize (presumably NaNs) are bridged
    visually. Returns the object produced by ``.vbt.plot()``.
    """
    values = vbt.nb.interp_resize_1d_nb(
        pattern, n, vbt.enums.InterpMode.Discrete
    )
    dotted_trace = {"line": {"dash": "dot"}, "connectgaps": True}
    return pd.Series(values).vbt.plot(trace_kwargs=dotted_trace)

# %%
plot_discrete(7).show()

# %%
plot_discrete(11).show()

# %%
plot_discrete(30).show()

# %% [markdown]
# ### Mixed

# %%
# 10-point resize in mixed mode (compared against linear mode below).
resized_pattern = vbt.nb.interp_resize_1d_nb(
    pattern, 10, vbt.enums.InterpMode.Mixed
)
resized_pattern

# %%
def plot_mixed(n):
    """Overlay the linear and mixed resizes of ``pattern`` at length ``n``.

    Both series are drawn on one figure (named "Linear" and "Mixed"), and
    that figure is returned.
    """
    resize = vbt.nb.interp_resize_1d_nb
    modes = vbt.enums.InterpMode
    linear = resize(pattern, n, modes.Linear)
    mixed = resize(pattern, n, modes.Mixed)
    figure = pd.Series(linear, name="Linear").vbt.plot()
    pd.Series(mixed, name="Mixed").vbt.plot(fig=figure)
    return figure

# %%
plot_mixed(7).show()

# %%
plot_mixed(11).show()

# %%
plot_mixed(30).show()

# %%
# Resize the pattern to exactly as many points as the price window, so the
# two can be compared element-wise; `.shape` confirms the length.
resized_pattern = vbt.nb.interp_resize_1d_nb(
    pattern, len(price_window), vbt.enums.InterpMode.Mixed
)
resized_pattern.shape

# %% [markdown]
# ## Rescaling

# %%
# Min-max rescaling: map the pattern's [min, max] range onto the price
# window's [min, max] range, then align it with the window's index.
pattern_scale = (resized_pattern.min(), resized_pattern.max())
price_window_scale = (price_window.min(), price_window.max())
rescaled_pattern = vbt.utils.array_.rescale_nb(
    resized_pattern, pattern_scale, price_window_scale
)
rescaled_pattern = pd.Series(rescaled_pattern, index=price_window.index)

# %%
# Plot price, then the pattern with fill="tonexty" (Plotly fills the area
# between this trace and the previous one) to visualize the distance.
fig = price_window.vbt.plot()
rescaled_pattern.vbt.plot(
    trace_kwargs=dict(
        fill="tonexty",
        fillcolor="rgba(255, 0, 0, 0.25)"
    ),
    fig=fig
)
fig.show()

# %% [markdown]
# ### Rebasing

# %%
# Rebasing: pct_pattern expresses moves relative to its first point (1.3 =
# +30%). Normalize by the first resized value, then scale so the pattern
# starts at the window's first price.
pct_pattern = np.array([1, 1.3, 1.6, 1.3, 1.6, 1.3])
resized_pct_pattern = vbt.nb.interp_resize_1d_nb(
    pct_pattern, len(price_window), vbt.enums.InterpMode.Mixed
)
rebased_pattern = resized_pct_pattern / resized_pct_pattern[0]
rebased_pattern *= price_window.values[0]
rebased_pattern = pd.Series(rebased_pattern, index=price_window.index)
fig = price_window.vbt.plot()
rebased_pattern.vbt.plot(
    trace_kwargs=dict(
        fill="tonexty",
        fillcolor="rgba(255, 0, 0, 0.25)"
    ),
    fig=fig
)
fig.show()

# %% [markdown]
# ## Fitting

# %%
# One-call equivalent of the resize + rebase above. The function returns a
# pair; only the first element (the fitted pattern) is used here.
new_pattern, _ = vbt.nb.fit_pattern_nb(
    price_window.values,
    pct_pattern,
    interp_mode=vbt.enums.InterpMode.Mixed,
    rescale_mode=vbt.enums.RescaleMode.Rebase
)

# %%
# The one-call fit should reproduce the manual resize + rebase above. Both
# sides are floating-point results that may be computed in a different
# operation order, so compare within a tolerance instead of bit-for-bit.
np.testing.assert_allclose(new_pattern, rebased_pattern)

# %% [markdown]
# ## Similarity

# %%
# Manual MAE-based similarity. For each point, the worst possible distance
# is to whichever window bound (max or min price) lies farther from the
# pattern value. Dividing the summed distance by the summed worst case and
# subtracting from 1 gives 1 for a perfect match and 0 for the worst case.
abs_distances = np.abs(rescaled_pattern - price_window.values)
mae = abs_distances.sum()
max_abs_distances = np.column_stack((
    (price_window.max() - rescaled_pattern),
    (rescaled_pattern - price_window.min())
)).max(axis=1)
max_mae = max_abs_distances.sum()
similarity = 1 - mae / max_mae
similarity

# %%
# Same idea with squared distances. This "rmse" is the square root of the
# *sum* of squares, not the mean; a 1/n factor would cancel in the ratio.
quad_distances = (rescaled_pattern - price_window.values) ** 2
rmse = np.sqrt(quad_distances.sum())
max_quad_distances = np.column_stack((
    (price_window.max() - rescaled_pattern),
    (rescaled_pattern - price_window.min())
)).max(axis=1) ** 2
max_rmse = np.sqrt(max_quad_distances.sum())
similarity = 1 - rmse / max_rmse
similarity

# %%
# Squared distances without the square root (MSE-style ratio).
quad_distances = (rescaled_pattern - price_window.values) ** 2
mse = quad_distances.sum()
max_quad_distances = np.column_stack((
    (price_window.max() - rescaled_pattern),
    (rescaled_pattern - price_window.min())
)).max(axis=1) ** 2
max_mse = max_quad_distances.sum()
similarity = 1 - mse / max_mse
similarity

# %%
# Built-in similarity on the raw (un-resized) pattern; resizing and
# rescaling presumably happen internally.
# NOTE(review): confirm the defaults match the manual MAE computation above.
vbt.nb.pattern_similarity_nb(price_window.values, pattern)

# %%
# Percentage pattern, rebased onto the window's first price.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    pct_pattern,
    rescale_mode=vbt.enums.RescaleMode.Rebase
)

# %%
# Same, but with nearest-neighbour resizing and an RMSE distance.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    pct_pattern,
    interp_mode=vbt.enums.InterpMode.Nearest,
    rescale_mode=vbt.enums.RescaleMode.Rebase,
    distance_measure=vbt.enums.DistanceMeasure.RMSE
)

# %%
# Accessor variant taking string options; fill_distance=True presumably
# shades the gap between price and pattern.
price_window.vbt.plot_pattern(
    pct_pattern,
    interp_mode="nearest",
    rescale_mode="rebase",
    fill_distance=True
).show()

# %%
# Raise the dip between the two peaks (1.3 -> 1.45) to better match the
# window, then re-score.
adj_pct_pattern = np.array([1, 1.3, 1.6, 1.45, 1.6, 1.3])
vbt.nb.pattern_similarity_nb(
    price_window.values,
    adj_pct_pattern,
    interp_mode=vbt.enums.InterpMode.Nearest,
    rescale_mode=vbt.enums.RescaleMode.Rebase,
    distance_measure=vbt.enums.DistanceMeasure.RMSE
)

# %%
price_window.vbt.plot_pattern(
    adj_pct_pattern,
    interp_mode="discrete",
    rescale_mode="rebase",
).show()

# %%
# Score the adjusted pattern with discrete resizing.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    adj_pct_pattern,
    interp_mode=vbt.enums.InterpMode.Discrete,
    rescale_mode=vbt.enums.RescaleMode.Rebase,
    distance_measure=vbt.enums.DistanceMeasure.RMSE
)

# %% [markdown]
# ### Relative

# %%
# Relative variant of the manual MAE similarity. Each actual and worst-case
# distance is divided by the pattern value at that point, so errors are
# measured as a fraction of the pattern level. This reuses abs_distances and
# max_abs_distances from the MAE cell.
abs_pct_distances = abs_distances / rescaled_pattern
pct_mae = abs_pct_distances.sum()
max_abs_pct_distances = max_abs_distances / rescaled_pattern
max_pct_mae = max_abs_pct_distances.sum()
similarity = 1 - pct_mae / max_pct_mae
similarity

# %%
# pct_pattern equals 0.7 + 0.3 * pattern (an affine transform), so under
# min-max rescaling it fits identically to `pattern`. Assuming the default
# rescale mode is min-max, this should agree with the manual value above.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    pct_pattern,
    error_type=vbt.enums.ErrorType.Relative
)

# %%
# Toy inputs contrasting absolute and relative error on the same data.
vbt.nb.pattern_similarity_nb(
    np.array([10, 30, 100]),
    np.array([1, 2, 3]),
    error_type=vbt.enums.ErrorType.Absolute
)

# %%
vbt.nb.pattern_similarity_nb(
    np.array([10, 30, 100]),
    np.array([1, 2, 3]),
    error_type=vbt.enums.ErrorType.Relative
)

# %% [markdown]
# ### Inverse

# %%
# Score the price window against the inverted (mirrored) pattern.
vbt.nb.pattern_similarity_nb(price_window.values, pattern, invert=True)

# %%
price_window.vbt.plot_pattern(pattern, invert=True).show()

# %%
# Manual inversion: mirror each value within the pattern's own range
# (max + min - x), which turns peaks into troughs.
pattern.max() + pattern.min() - pattern

# %% [markdown]
# ### Max error

# %%
# Baseline similarity without a max-error constraint, for comparison.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    pattern,
)

# %%
# One max-error value per pattern point (a uniform 0.5 here).
vbt.nb.pattern_similarity_nb(
    price_window.values,
    pattern,
    max_error=np.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
)

# %%
# A single-element max_error presumably broadcasts to every point; it is
# expected to give the same result as the cell above.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    pattern,
    max_error=np.array([0.5]),
)

# %%
price_window.vbt.plot_pattern(
    pattern,
    max_error=0.5
).show()

# %%
# Strict mode: presumably any point outside its max_error rejects the whole
# window (NaN result) rather than just lowering the score.
# NOTE(review): confirm in the vbt docs.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    pattern,
    max_error=np.array([0.5]),
    max_error_strict=True
)

# %%
# With relative errors, max_error is presumably a fraction of the pattern
# value (0.1 ~ 10%).
vbt.nb.pattern_similarity_nb(
    price_window.values,
    pattern,
    max_error=np.array([0.1]),
    error_type=vbt.enums.ErrorType.Relative
)

# %%
price_window.vbt.plot_pattern(
    pattern,
    max_error=0.1,
    error_type="relative"
).show()

# %%
# Per-point tolerances on the rebased adjusted pattern, in strict mode.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    adj_pct_pattern,
    rescale_mode=vbt.enums.RescaleMode.Rebase,
    max_error=np.array([0.2, 0.1, 0.05, 0.1, 0.05, 0.1]),
    max_error_strict=True
)

# %%
price_window.vbt.plot_pattern(
    adj_pct_pattern,
    rescale_mode="rebase",
    max_error=np.array([0.2, 0.1, 0.05, 0.1, 0.05, 0.1])
).show()

# %%
# Loosen every per-point tolerance by 0.05 and retry in strict mode.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    adj_pct_pattern,
    rescale_mode=vbt.enums.RescaleMode.Rebase,
    max_error=np.array([0.2, 0.1, 0.05, 0.1, 0.05, 0.1]) + 0.05,
    max_error_strict=True
)

# %% [markdown]
# #### Interpolation

# %%
# Constrain only the two peaks (indices 2 and 4). NaN entries presumably
# mean "no limit" at that point, and discrete interpolation keeps the limits
# from being spread onto neighbouring resized points.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    adj_pct_pattern,
    rescale_mode=vbt.enums.RescaleMode.Rebase,
    max_error=np.array([np.nan, np.nan, 0.1, np.nan, 0.1, np.nan]),
    max_error_interp_mode=vbt.enums.InterpMode.Discrete,
    max_error_strict=True
)

# %%
price_window.vbt.plot_pattern(
    adj_pct_pattern,
    rescale_mode="rebase",
    max_error=np.array([np.nan, np.nan, 0.1, np.nan, 0.1, np.nan]),
    max_error_interp_mode="discrete"
).show()

# %% [markdown]
# #### Max distance

# %%
# Baseline again, for comparison with the max-distance variant below.
vbt.nb.pattern_similarity_nb(price_window.values, pattern)

# %%
# max_error_as_maxdist presumably uses max_error as the worst-case distance
# in the similarity denominator. NOTE(review): confirm in the vbt docs.
vbt.nb.pattern_similarity_nb(
    price_window.values,
    pattern,
    max_error=np.array([0.5]),
    max_error_as_maxdist=True
)

# %% [markdown]
# ### Further filters

# %%
# Presumably rejects windows whose price change exceeds 30%.
vbt.nb.pattern_similarity_nb(
    price_window.values, pattern, max_pct_change=0.3
)

# %%
# Presumably returns NaN when the similarity falls below 0.9.
vbt.nb.pattern_similarity_nb(
    price_window.values, pattern, min_similarity=0.9
)

# %% [markdown]
# ## Rolling similarity

# %%
# Score every 30-row rolling window of the full-history HLC3 price against
# the pattern. Each score presumably lands on the window's last row, which
# the slicing below relies on.
price = data.hlc3

similarity = price.vbt.rolling_pattern_similarity(
    pattern,
    window=30,
    error_type="relative",
    max_error=0.05,
    max_error_interp_mode="discrete"
)
similarity.describe()

# %%
# Best match: argmax is the window's last row, +1 makes it an exclusive
# slice end, and the window spans the 30 rows before it. plot_obj=False
# presumably skips redrawing price, since candles are already on `fig`.
end_row = similarity.argmax() + 1
start_row = end_row - 30
fig = data.iloc[start_row:end_row].plot(plot_volume=False)
price.iloc[start_row:end_row].vbt.plot_pattern(
    pattern,
    error_type="relative",
    max_error=0.05,
    max_error_interp_mode="discrete",
    plot_obj=False,
    fig=fig
)
fig.show()

# %%
# Worst match, overlaid with the inverted pattern for comparison.
end_row = similarity.argmin() + 1
start_row = end_row - 30
fig = data.iloc[start_row:end_row].plot(plot_volume=False)
price.iloc[start_row:end_row].vbt.plot_pattern(
    pattern,
    invert=True,
    error_type="relative",
    max_error=0.05,
    max_error_interp_mode="discrete",
    plot_obj=False,
    fig=fig
)
fig.show()

# %%
# Score against the inverted pattern directly and plot its best match.
inv_similarity = price.vbt.rolling_pattern_similarity(
    pattern,
    window=30,
    invert=True,
    error_type="relative",
    max_error=0.05,
    max_error_interp_mode="discrete"
)
end_row = inv_similarity.argmax() + 1
start_row = end_row - 30
fig = data.iloc[start_row:end_row].plot(plot_volume=False)
price.iloc[start_row:end_row].vbt.plot_pattern(
    pattern,
    invert=True,
    error_type="relative",
    max_error=0.05,
    max_error_interp_mode="discrete",
    plot_obj=False,
    fig=fig
)
fig.show()

# %% [markdown]
# ### Indicator

# %%
# Run the pattern-similarity indicator over several window sizes. Arguments
# wrapped in vbt.Default(...) presumably stay out of the output column
# levels; the next cell inspects the resulting columns.
patsim = vbt.PATSIM.run(
    price,
    vbt.Default(pattern),
    error_type=vbt.Default("relative"),
    max_error=vbt.Default(0.05),
    max_error_interp_mode=vbt.Default("discrete"),
    window=[30, 45, 60, 75, 90]
)

# %%
patsim.wrapper.columns

# %%
patsim.plot(column=60).show()

# %%
patsim.overlay_with_heatmap(column=60).show()

# %%
# Signal wherever similarity reaches 0.8 (NaN comparisons evaluate False).
exits = patsim.similarity >= 0.8
exits.sum()

# %%
# param_product=True presumably builds every window x invert x
# min_similarity combination. Windows below min_similarity are presumably
# NaN, so a non-NaN value marks a match.
patsim = vbt.PATSIM.run(
    price,
    vbt.Default(pattern),
    error_type=vbt.Default("relative"),
    max_error=vbt.Default(0.05),
    max_error_interp_mode=vbt.Default("discrete"),
    window=[30, 45, 60, 75, 90],
    invert=[False, True],
    min_similarity=[0.7, 0.8],
    param_product=True
)
exits = ~patsim.similarity.isnull()
exits.sum()

# %%
# Every column level except the window size, e.g. presumably
# (patsim_invert, patsim_min_similarity).
groupby = [
    name for name in patsim.wrapper.columns.names
    if name != "patsim_window"
]
# Collapse the window dimension by keeping the best similarity per remaining
# level combination. `DataFrame.groupby(..., axis=1)` is deprecated since
# pandas 2.1, so group the transposed frame by those level names and
# transpose back. The result is the same, without the FutureWarning.
max_sim = patsim.similarity.T.groupby(level=groupby).max().T
# Matches of the inverted pattern become entries; matches of the regular
# pattern become exits. A non-NaN similarity presumably marks a match that
# passed min_similarity.
entries = ~max_sim.xs(True, level="patsim_invert", axis=1).isnull()
exits = ~max_sim.xs(False, level="patsim_invert", axis=1).isnull()

# %%
# Draw signals for the min_similarity == 0.8 column over semi-transparent
# candles.
fig = data.plot(ohlc_trace_kwargs=dict(opacity=0.5))
entries[0.8].vbt.signals.plot_as_entries(price, fig=fig)
exits[0.8].vbt.signals.plot_as_exits(price, fig=fig)
fig.show()

# %% [markdown]
# ## Search

# %%
# Low-level search: scan windows of 30..90 rows and return records (with
# start_idx / end_idx fields, used below) for matches scoring >= 0.85.
pattern_range_records = vbt.nb.find_pattern_1d_nb(
    price.values,
    pattern,
    window=30,
    max_window=90,
    error_type=vbt.enums.ErrorType.Relative,
    max_error=np.array([0.05]),
    max_error_interp_mode=vbt.enums.InterpMode.Discrete,
    min_similarity=0.85
)
pattern_range_records

# %%
# Plot the second record (assumes at least two matches were found). Show 30
# extra candles after the match for context, but overlay the pattern only on
# the matched range.
start_row = pattern_range_records[1]["start_idx"]
end_row = pattern_range_records[1]["end_idx"]
fig = data.iloc[start_row:end_row + 30].plot(plot_volume=False)
price.iloc[start_row:end_row].vbt.plot_pattern(
    pattern,
    error_type="relative",
    max_error=0.05,
    max_error_interp_mode="discrete",
    plot_obj=False,
    fig=fig
)
fig.show()

# %%
# High-level search that wraps the matches in a PatternRanges object.
pattern_ranges = vbt.PatternRanges.from_pattern_search(
    price,
    pattern,
    window=30,
    max_window=90,
    error_type="relative",
    max_error=0.05,
    max_error_interp_mode="discrete",
    min_similarity=0.85
)
pattern_ranges

# %%
# Accessor shortcut with the same arguments; presumably equivalent to the
# class method above.
pattern_ranges = price.vbt.find_pattern(
    pattern,
    window=30,
    max_window=90,
    error_type="relative",
    max_error=0.05,
    max_error_interp_mode="discrete",
    min_similarity=0.85
)

# %%
pattern_ranges.records_readable

# %%
pattern_ranges.wrapper.columns

# %%
pattern_ranges.search_configs

# %%
pattern_ranges.plot().show()

# %%
# Restrict to a date range before plotting.
pattern_ranges.loc["2021-09-01":"2022-01-01"].plot().show()

# %%
pattern_ranges.stats()

# %% [markdown]
# ### Overlapping

# %%
# Allow matched ranges to overlap, and widen the maximum window to 120 rows.
pattern_ranges = price.vbt.find_pattern(
    pattern,
    window=30,
    max_window=120,
    error_type="relative",
    max_error=0.05,
    max_error_interp_mode="discrete",
    min_similarity=0.85,
    overlap_mode="allow"
)

pattern_ranges.count()

# %%
pattern_ranges.overlap_coverage

# %%
# Draw only the ranges, without zones or pattern overlays, to keep the
# overlapping matches readable.
pattern_ranges.plot(plot_zones=False, plot_patterns=False).show()

# %% [markdown]
# ### Random selection

# %%
def run_prob_search(row_select_prob, window_select_prob):
    """Search ``price`` for the global ``pattern`` with random row/window selection.

    Args:
        row_select_prob: Probability forwarded to ``find_pattern`` for
            selecting each row.
        window_select_prob: Probability forwarded to ``find_pattern`` for
            selecting each window size.

    Returns:
        Whatever ``price.vbt.find_pattern`` returns (a pattern-ranges object
        elsewhere in this notebook).
    """
    # Keyword order mirrors the original call exactly.
    search_kwargs = {
        "window": 30,
        "max_window": 120,
        "row_select_prob": row_select_prob,
        "window_select_prob": window_select_prob,
        "error_type": "relative",
        "max_error": 0.05,
        "max_error_interp_mode": "discrete",
        "min_similarity": 0.8,
    }
    return price.vbt.find_pattern(pattern, **search_kwargs)

# IPython magic: time the search with selection probabilities of 1.0, i.e.
# no random skipping.
%timeit run_prob_search(1.0, 1.0)

# %%
# Time a randomized search that considers fewer rows and windows.
%timeit run_prob_search(0.5, 0.25)

# %%
run_prob_search(1.0, 1.0).count()

# %%
# Run the randomized search 100 times and plot the match count of each run.
# No seed is passed, so counts presumably vary between runs. This cell runs
# 100 full searches and may be slow.
pd.Series([
    run_prob_search(0.5, 0.25).count()
    for i in range(100)
]).vbt.plot().show()

# %% [markdown]
# ### Params

# %%
# Search several patterns at once; vbt.Param presumably yields one output
# column per pattern.
pattern_ranges = price.vbt.find_pattern(
    vbt.Param([
        [1, 2, 1],
        [2, 1, 2],
        [1, 2, 3],
        [3, 2, 1]
    ]),
    window=30,
    max_window=120,
)

# %%
pattern_ranges.count()

# %%
# Same search with readable column labels for each pattern.
pattern_ranges = price.vbt.find_pattern(
    vbt.Param([
        [1, 2, 1],
        [2, 1, 2],
        [1, 2, 3],
        [3, 2, 1]
    ], keys=["v-top", "v-bottom", "rising", "falling"]),
    window=30,
    max_window=120,
)
pattern_ranges.count()

# %%
pattern_ranges.plot(column="falling").show()

# %%
# Add a second parameter; columns become (pattern key, min_similarity) pairs,
# as the column tuple in the next cell shows.
pattern_ranges = price.vbt.find_pattern(
    vbt.Param([
        [1, 2, 1],
        [2, 1, 2],
        [1, 2, 3],
        [3, 2, 1]
    ], keys=["v-top", "v-bottom", "rising", "falling"]),
    window=30,
    max_window=120,
    min_similarity=vbt.Param([0.8, 0.85])
)
pattern_ranges.count()

# %%
pattern_ranges.plot(column=("v-bottom", 0.8)).show()

# %%
# Params sharing level=0 are presumably zipped element-wise: each pattern
# gets its own window/max_window. min_similarity (level=1) is then crossed
# with them.
pattern_ranges = price.vbt.find_pattern(
    vbt.Param([
        [1, 2, 1],
        [2, 1, 2],
        [1, 2, 3],
        [3, 2, 1]
    ], keys=["v-top", "v-bottom", "rising", "falling"], level=0),
    window=vbt.Param([30, 30, 7, 7], level=0),
    max_window=vbt.Param([120, 120, 30, 30], level=0),
    min_similarity=vbt.Param([0.8, 0.85], level=1)
)
pattern_ranges.count()

# %% [markdown]
# ### Configs

# %%
# Two-symbol dataset over the same date range.
mult_data = vbt.BinanceData.pull(
    ["BTCUSDT", "ETHUSDT"],
    start="2020-06-01 UTC",
    end="2022-06-01 UTC"
)
mult_price = mult_data.hlc3

# Search configs: a flat PSC presumably applies to every column, while the
# nested list presumably supplies one config per column. Here, each symbol
# is searched for its own most recent 30 rows.
pattern_ranges = mult_price.vbt.find_pattern(
    search_configs=[
        vbt.PSC(pattern=[1, 2, 3, 2, 3, 2], window=30),
        [
            vbt.PSC(pattern=mult_price.iloc[-30:, 0]),
            vbt.PSC(pattern=mult_price.iloc[-30:, 1]),
        ]
    ],
    min_similarity=0.8
)
pattern_ranges.count()

# %%
# Name the configs so they appear as readable column labels.
pattern_ranges = mult_price.vbt.find_pattern(
    search_configs=[
        vbt.PSC(pattern=[1, 2, 3, 2, 3, 2], window=30, name="double_top"),
        [
            vbt.PSC(pattern=mult_price.iloc[-30:, 0], name="last"),
            vbt.PSC(pattern=mult_price.iloc[-30:, 1], name="last"),
        ]
    ],
    min_similarity=0.8
)
pattern_ranges.count()

# %%
# Cross the configs with two rescale modes. OHLC columns are passed,
# presumably so that plots can draw candles.
pattern_ranges = mult_price.vbt.find_pattern(
    search_configs=[
        vbt.PSC(pattern=[1, 2, 3, 2, 3, 2], window=30, name="double_top"),
        [
            vbt.PSC(pattern=mult_price.iloc[-30:, 0], name="last"),
            vbt.PSC(pattern=mult_price.iloc[-30:, 1], name="last"),
        ]
    ],
    rescale_mode=vbt.Param(["minmax", "rebase"]),
    min_similarity=0.8,
    open=mult_data.open,
    high=mult_data.high,
    low=mult_data.low,
    close=mult_data.close,
)
pattern_ranges.count()

# %%
pattern_ranges.plot(column=("rebase", "last", "ETHUSDT")).show()

# %% [markdown]
# ### Mask

# %%
mask = pattern_ranges.last_pd_mask
mask.sum()

# %% [markdown]
# ### Indicator

# %%
# Seeded randomized search with overlaps allowed. Each range's similarity is
# then placed at its last index, giving a price-aligned series.
pattern_ranges = price.vbt.find_pattern(
    pattern,
    window=30,
    max_window=120,
    row_select_prob=0.5,
    window_select_prob=0.5,
    overlap_mode="allow",
    seed=42
)
pr_mask = pattern_ranges.map_field(
    "similarity",
    idx_arr=pattern_ranges.last_idx.values
).to_pd()
pr_mask[~pr_mask.isnull()]

# %%
# The same search through the indicator, for comparison with pr_mask. Note
# that min_similarity=0.85 is set explicitly here, and overlap_mode="allow"
# was set above. NOTE(review): this presumably aligns the two APIs'
# defaults; confirm before relying on the masks matching.
patsim = vbt.PATSIM.run(
    price,
    vbt.Default(pattern),
    window=vbt.Default(30),
    max_window=vbt.Default(120),
    row_select_prob=vbt.Default(0.5),
    window_select_prob=vbt.Default(0.5),
    min_similarity=vbt.Default(0.85),
    seed=42
)
ind_mask = patsim.similarity
ind_mask[~ind_mask.isnull()]

# %% [markdown]
# ## Combination

# %%
# Divergence setup: highs trace a rising zigzag ([1, 3, 2, 4]) while MACD
# traces a falling one ([4, 2, 3, 1]). Both are searched over 40..50-row
# windows. The MACD run requires TA-Lib.
price_highs = vbt.PATSIM.run(
    data.high,
    pattern=np.array([1, 3, 2, 4]),
    window=40,
    max_window=50
)
macd = data.run("talib_macd").macd
macd_lows = vbt.PATSIM.run(
    macd,
    pattern=np.array([4, 2, 3, 1]),
    window=40,
    max_window=50
)

# Three stacked panels sharing the x-axis: price, MACD, and both
# similarity series.
fig = vbt.make_subplots(
    rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.02
)
fig.update_layout(height=500)
data.high.rename("Price").vbt.plot(
    add_trace_kwargs=dict(row=1, col=1), fig=fig
)
macd.rename("MACD").vbt.plot(
    add_trace_kwargs=dict(row=2, col=1), fig=fig
)
price_highs.similarity.rename("Price Sim").vbt.plot(
    add_trace_kwargs=dict(row=3, col=1), fig=fig
)
macd_lows.similarity.rename("MACD Sim").vbt.plot(
    add_trace_kwargs=dict(row=3, col=1), fig=fig
)
fig.show()

# %%
# Exit when both similarities reached 0.8 recently. rolling_any(10)
# presumably means "True on any of the last 10 rows", which lets the two
# signals be slightly offset in time.
cond1 = (price_highs.similarity >= 0.8).vbt.rolling_any(10)
cond2 = (macd_lows.similarity >= 0.8).vbt.rolling_any(10)
exits = cond1 & cond2
fig = data.plot(ohlc_trace_kwargs=dict(opacity=0.5))
exits.vbt.signals.plot_as_exits(data.close, fig=fig)
fig.show()

# %%