# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Analysis
# ## Helper methods

# %%
from vectorbtpro import *

# Enum-like named tuple: the integer codes written into `signal_type`.
SignalType = namedtuple('SigType', 'Entry Exit')(Entry=0, Exit=1)

def apply_func(ts, fastw, sloww, minp=None):
    """Compute two rolling means of `ts` and the signals from their crossovers.

    Args:
        ts: Input array handed over by the indicator factory.
        fastw: Window passed to the rolling mean that yields the fast MA.
        sloww: Window passed to the rolling mean that yields the slow MA.
        minp: Forwarded as `minp` to both rolling-mean calls.

    Returns:
        Tuple `(fast_ma, slow_ma, signals, signal_type)` where `signals` is
        the element-wise OR of entries and exits, and `signal_type` holds
        `SignalType.Entry`, `SignalType.Exit` or -1 where neither fired.
    """
    ma_fast = vbt.nb.rolling_mean_nb(ts, fastw, minp=minp)
    ma_slow = vbt.nb.rolling_mean_nb(ts, sloww, minp=minp)

    # An exit is the mirrored crossover: the slow MA crossing above the fast MA.
    is_entry = vbt.nb.crossed_above_nb(ma_fast, ma_slow)
    is_exit = vbt.nb.crossed_above_nb(ma_slow, ma_fast)

    # Start from -1 everywhere; exits are written last, so they take
    # precedence wherever both masks happen to be set.
    codes = np.full(ts.shape, -1, dtype=np.int_)
    codes[is_entry] = SignalType.Entry
    codes[is_exit] = SignalType.Exit

    return ma_fast, ma_slow, is_entry | is_exit, codes

# Build the CrossSig indicator class with the factory: one input (`ts`), two
# window parameters, and the four outputs returned by `apply_func`.
CrossSig = vbt.IF(
    class_name="CrossSig",
    input_names=['ts'],
    param_names=['fastw', 'sloww'],
    output_names=['fast_ma', 'slow_ma', 'signals', 'signal_type'],
    attr_settings=dict(
        # np.float64 instead of np.float_: the alias was removed in NumPy 2.0
        # and both names denote the same dtype on NumPy 1.x.
        fast_ma=dict(dtype=np.float64),
        slow_ma=dict(dtype=np.float64),
        signals=dict(dtype=np.bool_),
        # The named tuple is passed as the dtype so the integer codes can be
        # mapped back to their field names.
        signal_type=dict(dtype=SignalType),
    )
).with_apply_func(apply_func)

def generate_index(n):
    """Return a daily `DatetimeIndex` of `n` dates starting at 2020-01-01."""
    first_day = "2020-01-01"
    return pd.date_range(start=first_day, periods=n, freq="D")

# Two mirrored toy series over seven consecutive days: 'a' rises, falls and
# rises again, while 'b' does the opposite.
ts = pd.DataFrame(
    dict(
        a=[1, 2, 3, 2, 1, 2, 3],
        b=[3, 2, 1, 2, 3, 2, 1],
    ),
    index=generate_index(7),
)

# Single parameter combination: fastw=2, sloww=3.
cross_sig = CrossSig.run(ts, 2, 3)

# %%
# List every attribute and method name available on the indicator instance.
dir(cross_sig)


# %%
# Stats of the `fast_ma` output through the indicator's own shortcut method.
# The column key (2, 3, 'a') matches the parameters passed to CrossSig.run
# (fastw=2, sloww=3) followed by the input column 'a'.
cross_sig.fast_ma_stats(column=(2, 3, 'a'))

# %%
# The same request written against the `vbt` accessor of the output itself
# (presumably what the shortcut above delegates to - confirm in the vbt docs).
cross_sig.fast_ma.vbt.stats(column=(2, 3, 'a'))

# %% [markdown]
# ### Numeric

# %%
# Compare `fast_ma` against the thresholds 2 and 3 via the shortcut method.
cross_sig.fast_ma_above([2, 3])

# %%
# Accessor form: the thresholds are wrapped in a named vbt.Param
# (presumably equivalent to the shortcut above - confirm in the vbt docs).
cross_sig.fast_ma.vbt > vbt.Param([2, 3], name='crosssig_fast_ma_above')

# %%
# Crossover of `fast_ma` above `slow_ma` via the shortcut method.
cross_sig.fast_ma_crossed_above(cross_sig.slow_ma)

# %%
# Accessor form of the same crossover check.
cross_sig.fast_ma.vbt.crossed_above(cross_sig.slow_ma)

# %% [markdown]
# ### Boolean

# %%
# Boolean mask with a single True at position 4. It is built without an
# explicit index, so it carries pandas' default integer index rather than the
# dates used by `ts`.
other_signals = pd.Series([False, False, False, False, True, False, False])
# Combine the indicator's `signals` output with the mask via the shortcut
# method (name suggests a logical AND - compare with the `&` form below).
cross_sig.signals_and(other_signals)

# %%
# Accessor form: logical AND between `signals` and the mask.
cross_sig.signals.vbt & other_signals

# %% [markdown]
# ### Enumerated

# %%
# Readable view of `signal_type` (presumably the integer codes replaced by
# the SignalType field names - confirm in the vbt docs).
cross_sig.signal_type_readable

# %%
# Accessor form: attach the SignalType mapping and apply it to the raw codes.
cross_sig.signal_type.vbt(mapping=SignalType).apply_mapping()

# %% [markdown]
# ## Indexing

# %%
# Re-run with two candidate values per window; `param_product=True` asks for
# every fastw/sloww combination rather than pairing them element-wise
# (per vbt's run() options - confirm in the vbt docs).
cross_sig = CrossSig.run(ts, [2, 3], [3, 4], param_product=True)

# Pandas-style indexing on the indicator itself: rows from 2020-01-03 onwards
# and the single column (fastw=2, sloww=3, 'a').
cross_sig.loc["2020-01-03":, (2, 3, 'a')]

# %%
# The indexed result still exposes the indicator outputs, e.g. `signals`.
cross_sig.loc["2020-01-03":, (2, 3, 'a')].signals

# %%
# Select by parameter value through the per-parameter indexers, then pick
# input column 'a'.
cross_sig.fastw_loc[2].sloww_loc[3]['a']

# %%
# `signals` output of the selection made through the parameter indexers.
cross_sig.fastw_loc[2].sloww_loc[3]['a'].signals

# %% [markdown]
# ## Stats and plots

# %%
# Metric definitions handed to the indicator factory. Every `calc_func`
# receives the indicator instance as `self`.
metrics = {
    'start': {
        'title': 'Start',
        'calc_func': lambda self: self.wrapper.index[0],
        'agg_func': None,
    },
    'end': {
        'title': 'End',
        'calc_func': lambda self: self.wrapper.index[-1],
        'agg_func': None,
    },
    'period': {
        'title': 'Period',
        'calc_func': lambda self: len(self.wrapper.index),
        'apply_to_timedelta': True,
        'agg_func': None,
    },
    # Selected rows of describe() for the fast moving average.
    'fast_stats': {
        'title': "Fast Stats",
        'calc_func': lambda self: (
            self.fast_ma.describe()
            .loc[['count', 'mean', 'std', 'min', 'max']]
            .vbt.to_dict(orient='index_series')
        ),
    },
    # Same selection of describe() rows for the slow moving average.
    'slow_stats': {
        'title': "Slow Stats",
        'calc_func': lambda self: (
            self.slow_ma.describe()
            .loc[['count', 'mean', 'std', 'min', 'max']]
            .vbt.to_dict(orient='index_series')
        ),
    },
    # Counts of positions whose signal code equals Entry / Exit.
    'num_entries': {
        'title': "Entries",
        'calc_func': lambda self: np.sum(
            self.signal_type == SignalType.Entry
        ),
    },
    'num_exits': {
        'title': "Exits",
        'calc_func': lambda self: np.sum(
            self.signal_type == SignalType.Exit
        ),
    },
}

def plot_mas(self, column=None, add_trace_kwargs=None, fig=None):
    """Plot the input series together with both moving averages.

    Args:
        self: Indicator instance providing `ts`, `fast_ma`, `slow_ma` and
            `select_col_from_obj`.
        column: Column passed to `select_col_from_obj` for each object.
        add_trace_kwargs: Forwarded unchanged to every `vbt.plot` call.
        fig: Figure to draw on. When None, whatever the first plot call
            returns is reused for the remaining traces.

    Returns:
        The figure returned by the first plot call, which all three traces
        were drawn on.
    """
    ts = self.select_col_from_obj(self.ts, column).rename('TS')
    fast_ma = self.select_col_from_obj(self.fast_ma, column).rename('Fast MA')
    slow_ma = self.select_col_from_obj(self.slow_ma, column).rename('Slow MA')
    # Capture the figure from the first call: without this, calling the
    # function with fig=None would pass None to every plot call and discard
    # each result, leaving the caller with nothing to show.
    fig = ts.vbt.plot(add_trace_kwargs=add_trace_kwargs, fig=fig)
    fast_ma.vbt.plot(add_trace_kwargs=add_trace_kwargs, fig=fig)
    slow_ma.vbt.plot(add_trace_kwargs=add_trace_kwargs, fig=fig)
    return fig

def plot_signals(self, column=None, add_trace_kwargs=None, fig=None):
    """Plot the entry and exit masks derived from `signal_type`.

    Args:
        self: Indicator instance providing `signal_type` and
            `select_col_from_obj`.
        column: Column passed to `select_col_from_obj`.
        add_trace_kwargs: Forwarded unchanged to both `vbt.plot` calls.
        fig: Figure to draw on. When None, whatever the first plot call
            returns is reused for the second trace.

    Returns:
        The figure returned by the first plot call, which both traces were
        drawn on.
    """
    signal_type = self.select_col_from_obj(self.signal_type, column)
    entries = (signal_type == SignalType.Entry).rename('Entries')
    exits = (signal_type == SignalType.Exit).rename('Exits')
    # Capture the figure from the first call: without this, calling the
    # function with fig=None would pass None to both plot calls and discard
    # each result, leaving the caller with nothing to show.
    fig = entries.vbt.plot(add_trace_kwargs=add_trace_kwargs, fig=fig)
    exits.vbt.plot(add_trace_kwargs=add_trace_kwargs, fig=fig)
    return fig

# Subplot definitions handed to the indicator factory: a title plus the
# function that draws each subplot.
subplots = {
    'mas': {
        'title': "Moving averages",
        'plot_func': plot_mas,
    },
    'signals': {
        'title': "Signals",
        'plot_func': plot_signals,
    },
}

# Rebuild CrossSig with the same inputs, parameters and outputs as before,
# this time attaching the custom `metrics` and `subplots` definitions.
CrossSig = vbt.IF(
    class_name="CrossSig",
    input_names=['ts'],
    param_names=['fastw', 'sloww'],
    output_names=['fast_ma', 'slow_ma', 'signals', 'signal_type'],
    attr_settings=dict(
        # np.float64 instead of np.float_: the alias was removed in NumPy 2.0
        # and both names denote the same dtype on NumPy 1.x.
        fast_ma=dict(dtype=np.float64),
        slow_ma=dict(dtype=np.float64),
        signals=dict(dtype=np.bool_),
        # The named tuple is passed as the dtype so the integer codes can be
        # mapped back to their field names.
        signal_type=dict(dtype=SignalType),
    ),
    metrics=metrics,
    subplots=subplots
).with_apply_func(apply_func)

# Run the rebuilt class with fastw values 2 and 3 and a single sloww of 4
# (presumably the scalar is paired with each fastw value - confirm).
cross_sig = CrossSig.run(ts, [2, 3], 4)

# %%
# Compute the stats for the column fastw=2, sloww=4, input 'a'.
cross_sig.stats(column=(2, 4, 'a'))

# %%
# Build the plots for the same column and display the result.
cross_sig.plots(column=(2, 4, 'a')).show()

# %% [markdown]
# ## Extending

# %%
class SmartCrossSig(CrossSig):
    """CrossSig subclass that carries its plotting functions as methods."""

    def plot_mas(self, column=None, add_trace_kwargs=None, fig=None):
        """Draw the input series and both moving averages; return the figure."""
        pick = self.select_col_from_obj
        price = pick(self.ts, column).rename('TS')
        fast_line = pick(self.fast_ma, column).rename('Fast MA')
        slow_line = pick(self.slow_ma, column).rename('Slow MA')

        # The first call supplies the figure that the overlays are drawn on.
        fig = price.vbt.plot(add_trace_kwargs=add_trace_kwargs, fig=fig)
        for overlay in (fast_line, slow_line):
            overlay.vbt.plot(add_trace_kwargs=add_trace_kwargs, fig=fig)
        return fig

    def plot_signals(self, column=None, add_trace_kwargs=None, fig=None):
        """Draw the entry and exit masks; return the figure."""
        codes = self.select_col_from_obj(self.signal_type, column)
        entry_mask = (codes == SignalType.Entry).rename('Entries')
        exit_mask = (codes == SignalType.Exit).rename('Exits')

        # The first call supplies the figure that the exit mask is drawn on.
        fig = entry_mask.vbt.plot(add_trace_kwargs=add_trace_kwargs, fig=fig)
        exit_mask.vbt.plot(add_trace_kwargs=add_trace_kwargs, fig=fig)
        return fig

    # Subplots refer to the methods above by name instead of by function object.
    subplots = vbt.HybridConfig(
        mas={
            'title': "Moving averages",
            'plot_func': 'plot_mas',
        },
        signals={
            'title': "Signals",
            'plot_func': 'plot_signals',
        },
    )

# Run the subclass with fastw values 2 and 3 and sloww=4.
smart_cross_sig = SmartCrossSig.run(ts, [2, 3], 4)
# Call the plotting method directly and display the figure it returns.
smart_cross_sig.plot_signals(column=(2, 4, 'a')).show()

# %%