# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Development
# ## Parameters

# %%
from vectorbtpro import *

def broadcast_params(*params):
    """Broadcast parameters against each other and zip them into combinations.

    Each argument is converted to a 1-dim array with ``vbt.to_1d_array``.
    All arrays are then broadcast together with ``vbt.broadcast``.
    The result is a list of tuples, one per position after broadcasting.
    """
    arrays = [vbt.to_1d_array(param) for param in params]
    broadcasted = vbt.broadcast(*arrays)
    return list(zip(*broadcasted))

# Two scalars: each becomes a 1-element array, so there is a single combination.
broadcast_params(2, 3)

# %%
# A list next to a scalar: the scalar is presumably repeated for every list element.
broadcast_params([2, 3], 4)

# %%
broadcast_params(2, [3, 4])

# %%
# Equal-length lists are zipped position by position, not crossed.
broadcast_params([2, 3], [4, 5])

# %%
# Lengths 2 and 3 cannot be broadcast together.
# NOTE(review): presumably raises, following NumPy broadcasting rules.
broadcast_params([2, 3], [4, 5, 6])

# %%
def apply_func(ts, window, lower, upper):
    """Classify the rolling mean of ``ts`` against a lower/upper band.

    Parameters
    ----------
    ts : array
        Input passed to ``vbt.nb.rolling_mean_nb``.
    window : int
        Rolling window length.
    lower, upper : scalar
        Band boundaries.

    Returns
    -------
    float64 array shaped like ``ts``:
    1 where the rolling mean is >= ``upper``,
    -1 where it is <= ``lower``,
    0 strictly in between,
    NaN where the rolling mean itself is NaN (all comparisons are False there).
    """
    # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype on all versions.
    out = np.full_like(ts, np.nan, dtype=np.float64)
    ts_mean = vbt.nb.rolling_mean_nb(ts, window)
    out[ts_mean >= upper] = 1
    # Assigned after the upper test, so -1 wins if lower >= upper and both match.
    out[ts_mean <= lower] = -1
    out[(ts_mean > lower) & (ts_mean < upper)] = 0
    return out

# Build an indicator class from `apply_func`: one input array (`ts`),
# three parameters (`window`, `lower`, `upper`) and one output (`out`).
Bounded = vbt.IF(
    class_name="Bounded",
    input_names=['ts'],
    param_names=['window', 'lower', 'upper'],
    output_names=['out']
).with_apply_func(apply_func)

def generate_index(n):
    """Return a ``DatetimeIndex`` of ``n`` consecutive dates beginning 2020-01-01."""
    first_day = "2020-01-01"
    return pd.date_range(first_day, periods=n)

# Sample input: two columns over 7 consecutive dates.
ts = pd.DataFrame({
    'a': [5, 4, 3, 2, 3, 4, 5],
    'b': [2, 3, 4, 5, 4, 3, 2]
}, index=generate_index(7))
# Positional parameters bind in `param_names` order: window=2, lower=3, upper=5.
bounded = Bounded.run(ts, 2, 3, 5)

# %%
bounded.param_names

# %%
# NOTE(review): presumably the window value used for each output column.
bounded.window_list

# %%
# Same run as above, with parameters passed by keyword.
Bounded.run(
    ts,
    window=2,
    lower=3,
    upper=5
).out

# %%
# Two window values; the scalar bounds are presumably broadcast to both.
Bounded.run(
    ts,
    window=[2, 3],
    lower=3,
    upper=5
).out

# %%
# param_product=True presumably crosses the parameter lists (2 windows x 2 lowers)
# instead of zipping them.
Bounded.run(
    ts,
    window=[2, 3],
    lower=[3, 4],
    upper=5,
    param_product=True
).out

# %%
# Nested combination spec: every window in [2, 3] crossed with every 2-element
# combination of [3, 4, 5], used as (lower, upper).
# NOTE(review): `combinations`/`product` presumably come from itertools via the star import.
bound_combs_op = (combinations, [3, 4, 5], 2)
product_op = (product, [2, 3], bound_combs_op)
windows, lowers, uppers = vbt.generate_param_combs(product_op)

Bounded.run(
    ts,
    window=windows,
    lower=lowers,
    upper=uppers
).out

# %%
# per_column=True: presumably each parameter combination is applied to one input
# column instead of to every column.
Bounded.run(
    ts,
    window=[2, 3],
    lower=[3, 4],
    upper=5,
    per_column=True
).out

# %% [markdown]
# ### Defaults

# %%
# Same indicator, but with default values for all three parameters.
Bounded = vbt.IF(
    class_name="Bounded",
    input_names=['ts'],
    param_names=['window', 'lower', 'upper'],
    output_names=['out']
).with_apply_func(apply_func, window=2, lower=3, upper=4)

# Uses window=2, lower=3, upper=4.
Bounded.run(ts).out

# %%
# Overrides only `upper`; window and lower keep their defaults.
Bounded.run(ts, upper=[5, 6]).out

# %%
# NOTE(review): presumably also shows the default-valued parameters in the column levels.
Bounded.run(ts, hide_default=False).out

# %% [markdown]
# ### Array-like

# %%
def apply_func(ts, window, lower, upper):
    """Band classifier with array-like parameters.

    ``window`` is indexed as ``window[col]``, so it must hold one window length
    per column of ``ts``. ``lower`` and ``upper`` are compared element-wise with
    the rolling mean.
    NOTE(review): they are presumably already broadcast to the shape of ``ts``
    by the factory's ``bc_to_input=True`` setting.

    Returns a float64 array shaped like ``ts``:
    1 where the mean is >= ``upper``,
    -1 where it is <= ``lower``,
    0 strictly in between,
    NaN where the mean is NaN.
    """
    # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype on all versions.
    out = np.full_like(ts, np.nan, dtype=np.float64)
    # Each column gets its own rolling window length.
    ts_mean = np.column_stack([
        vbt.nb.rolling_mean_1d_nb(ts[:, col], window[col])
        for col in range(ts.shape[1])
    ])
    out[ts_mean >= upper] = 1
    # Assigned after the upper test, so -1 wins if the bands overlap.
    out[ts_mean <= lower] = -1
    out[(ts_mean > lower) & (ts_mean < upper)] = 0
    return out

# All three parameters accept arrays.
# NOTE(review): the semantics below are inferred from the setting names.
# - `window` is presumably broadcast along the column axis (bc_to_input=1)
#   and split per column.
# - `lower`/`upper` are presumably broadcast to the full input shape.
Bounded = vbt.IF(
    class_name="Bounded",
    input_names=['ts'],
    param_names=['window', 'lower', 'upper'],
    output_names=['out']
).with_apply_func(
    apply_func,
    param_settings=dict(
        window=dict(is_array_like=True, bc_to_input=1, per_column=True),
        lower=dict(is_array_like=True, bc_to_input=True),
        upper=dict(is_array_like=True, bc_to_input=True)
    )
)

# %%
# `ts` has 7 rows and 2 columns.
# `window` holds two parameter values: an array with one window per column, and a scalar 4.
# `lower` is a 1x2 row, i.e. one value per column.
# `upper` has 7 values, i.e. one per row.
Bounded.run(
    ts,
    window=[np.array([2, 3]), 4],
    lower=np.array([[1, 2]]),
    upper=np.array([6, 5, 4, 3, 4, 5, 6]),
).out

# %% [markdown]
# ### Lazy broadcasting

# %%
def apply_func(ts, window, lower, upper):
    """Band classifier that broadcasts its compact parameters itself.

    ``window`` is broadcast to one value per column of ``ts``. ``lower`` and
    ``upper`` are broadcast to the full shape of ``ts`` with ``np.broadcast_to``
    (read-only views, no copies). The factory can therefore pass them in their
    smallest form.

    Returns a float64 array shaped like ``ts``:
    1 where the rolling mean is >= ``upper``,
    -1 where it is <= ``lower``,
    0 strictly in between,
    NaN where the mean is NaN.
    """
    window = np.broadcast_to(window, ts.shape[1])
    lower = np.broadcast_to(lower, ts.shape)
    upper = np.broadcast_to(upper, ts.shape)

    # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype on all versions.
    out = np.full_like(ts, np.nan, dtype=np.float64)
    ts_mean = np.column_stack([
        vbt.nb.rolling_mean_1d_nb(ts[:, col], window[col])
        for col in range(ts.shape[1])
    ])
    out[ts_mean >= upper] = 1
    # Assigned after the upper test, so -1 wins if the bands overlap.
    out[ts_mean <= lower] = -1
    out[(ts_mean > lower) & (ts_mean < upper)] = 0
    return out

# Use vectorbtpro's predefined flexible parameter configs: a per-column config
# for `window` and element-wise configs for the bounds.
# The apply function above broadcasts them itself.
Bounded = vbt.IF(
    class_name="Bounded",
    input_names=['ts'],
    param_names=['window', 'lower', 'upper'],
    output_names=['out']
).with_apply_func(
    apply_func,
    param_settings=dict(
        window=vbt.flex_col_param_config,
        lower=vbt.flex_elem_param_config,
        upper=vbt.flex_elem_param_config
    )
)

# %% [markdown]
# #### With Numba

# %%
@njit
def apply_func_nb(ts, window, lower, upper):
    """Numba band classifier that reads parameters with flexible indexing.

    The window length is read per column with ``vbt.flex_select_1d_pc_nb``.
    The bounds are read per element with ``vbt.flex_select_nb``.
    NOTE(review): this presumably lets the parameters stay in their compact,
    un-broadcast form.

    Returns a float64 array shaped like ``ts``:
    1 where the window mean is >= upper,
    -1 where it is <= lower,
    0 strictly in between,
    NaN for the first ``window - 1`` rows of each column (window not yet full)
    and wherever the mean is NaN.
    """
    # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype on all versions.
    out = np.full_like(ts, np.nan, dtype=np.float64)

    for col in range(ts.shape[1]):
        _window = vbt.flex_select_1d_pc_nb(window, col)

        for row in range(ts.shape[0]):
            window_start = max(0, row + 1 - _window)
            window_end = row + 1
            # Only emit a value once a full window of rows is available.
            if window_end - window_start >= _window:
                _lower = vbt.flex_select_nb(lower, row, col)
                _upper = vbt.flex_select_nb(upper, row, col)

                mean = np.nanmean(ts[window_start:window_end, col])
                # The upper test comes first, so 1 wins if the bands overlap.
                if mean >= _upper:
                    out[row, col] = 1
                elif mean <= _lower:
                    out[row, col] = -1
                elif _lower < mean < _upper:
                    out[row, col] = 0
    return out

# %% [markdown]
# ### Parameterless
# ## Inputs

# %%
def apply_func(high, low, close):
    """Position of ``close`` inside the [low, high] range: 0 at ``low``, 1 at ``high``."""
    price_range = high - low
    offset = close - low
    return offset / price_range

# Parameterless indicator: three inputs, one output.
RelClose = vbt.IF(
    input_names=['high', 'low', 'close'],
    output_names=['out']
).with_apply_func(apply_func)

close = pd.Series([1, 2, 3, 4, 5], index=generate_index(5))
high = close * 1.2
low = close * 0.8

rel_close = RelClose.run(high, low, close)
rel_close.out

# %%
rel_close.input_names

# %%
# Inputs are exposed as attributes on the indicator instance.
rel_close.high

# %%
# Mixed input types: a scalar, a Series and a two-column DataFrame.
# NOTE(review): presumably all broadcast to the DataFrame's shape.
high = 10
low = pd.Series([1, 2, 3, 4, 5], index=generate_index(5))
close = pd.DataFrame({
    'a': [3, 2, 1, 2, 3],
    'b': [5, 4, 3, 4, 5]
}, index=generate_index(5))
RelClose.run(high, low, close).out

# %%
# Request float16 for all inputs via the global broadcasting kwargs, then inspect the output dtypes.
RelClose.run(
    high, low, close,
    broadcast_kwargs=dict(require_kwargs=dict(dtype=np.float16))
).out.dtypes

# %%
# Same request, expressed per input with broadcasting options (vbt.BCO).
RelClose.run(
    vbt.BCO(high, require_kwargs=dict(dtype=np.float16)),
    vbt.BCO(low, require_kwargs=dict(dtype=np.float16)),
    vbt.BCO(close, require_kwargs=dict(dtype=np.float16))
).out.dtypes

# %% [markdown]
# ### One dim

# %%
import talib

def apply_func_1d(close, timeperiod):
    """Simple moving average of a single column via ``talib.SMA``.

    The input is cast to ``np.double`` first.
    NOTE(review): the TA-Lib wrapper presumably requires double input.
    """
    close_f64 = close.astype(np.double)
    sma_values = talib.SMA(close_f64, timeperiod)
    return sma_values

# takes_1d=True: the apply function is presumably called once per column with a 1-dim array.
SMA = vbt.IF(
    input_names=['ts'],
    param_names=['timeperiod'],
    output_names=['sma']
).with_apply_func(apply_func_1d, takes_1d=True)

# Two time periods, 3 and 4.
sma = SMA.run(ts, [3, 4])
sma.sma

# %% [markdown]
# ### Defaults

# %%
# Inputs with defaults: `high` defaults to 0 and `low` to 10.
# Only `close` has to be passed, so `out` is (close - 10) / (0 - 10).
RelClose = vbt.IF(
    input_names=['high', 'low', 'close'],
    output_names=['out']
).with_apply_func(
    apply_func,
    high=0,
    low=10
)

RelClose.run(close).out

# %%
# Defaults can also point to another input: vbt.Ref('close') presumably
# resolves to whatever is passed as `close`.
RelClose = vbt.IF(
    input_names=['high', 'low', 'close'],
    output_names=['out']
).with_apply_func(
    apply_func,
    high=vbt.Ref('close'),
    low=vbt.Ref('close')
)

# `low` falls back to `close`, so `out` is (close - close) / (high - close),
# i.e. 0 wherever high != close.
RelClose.run(high=high, close=close).out

# %% [markdown]
# ### Using Pandas

# %%
def apply_func(ts, group_by):
    """Demean the pandas object ``ts`` with the ``.vbt`` accessor, grouping columns by ``group_by``."""
    demeaned = ts.vbt.demean(group_by=group_by)
    return demeaned

# keep_pd=True: the apply function presumably receives pandas objects instead of NumPy arrays.
Demeaner = vbt.IF(
    input_names=['ts'],
    param_names=['group_by'],
    output_names=['out']
).with_apply_func(apply_func, keep_pd=True)

ts_wide = pd.DataFrame({
    'a': [1, 2, 3, 4, 5],
    'b': [5, 4, 3, 2, 1],
    'c': [3, 2, 1, 2, 3],
    'd': [1, 2, 3, 2, 1]
}, index=generate_index(5))
# Two `group_by` values:
# - (0, 0, 1, 1) puts columns a,b in one group and c,d in another.
# - True presumably groups all columns together.
demeaner = Demeaner.run(ts_wide, group_by=[(0, 0, 1, 1), True])
demeaner.out

# %%
def apply_func(ts, group_by, wrapper):
    """Demean ``ts`` with the Numba kernel, resolving ``group_by`` into a group map via ``wrapper``."""
    grouper = wrapper.grouper
    group_map = grouper.get_group_map(group_by=group_by)
    demeaned = vbt.nb.demean_nb(ts, group_map)
    return demeaned

# pass_wrapper=True: the indicator's array wrapper is passed to the apply function
# as `wrapper`, so it can build the group map itself.
Demeaner = vbt.IF(
    input_names=['ts'],
    param_names=['group_by'],
    output_names=['out']
).with_apply_func(apply_func, pass_wrapper=True)

# %% [markdown]
# ### Inputless

# %%
def apply_func(input_shape, start, mean, std):
    """Random price paths: ``start`` compounded by normally distributed returns.

    Returns with shape ``input_shape`` are drawn from ``np.random.normal(mean, std)``
    and compounded cumulatively along axis 0.
    """
    returns = np.random.normal(loc=mean, scale=std, size=input_shape)
    growth = np.cumprod(1 + returns, axis=0)
    return start * growth

# Inputless indicator: require_input_shape=True makes the output shape the
# first argument of `run`. The seed is fixed at 42 by default.
RandPrice = vbt.IF(
    class_name="RandPrice",
    param_names=['start', 'mean', 'std'],
    output_names=['out']
).with_apply_func(
    apply_func,
    require_input_shape=True,
    start=100,
    mean=0,
    std=0.01,
    seed=42
)

# 5 rows x 2 columns.
RandPrice.run((5, 2)).out

# %%
# Label the output with an explicit index/columns and test two drift values.
RandPrice.run(
    (5, 2),
    input_index=generate_index(5),
    input_columns=['a', 'b'],
    mean=[-0.1, 0.1]
).out

# %%
def custom_func(min_rows=1, max_rows=5, min_cols=1, max_cols=3):
    """Return a uniform random matrix whose shape is itself random.

    The row count is drawn from ``[min_rows, max_rows)`` and the column count
    from ``[min_cols, max_cols)``. ``np.random.randint`` excludes its upper
    bound, so ``max_rows``/``max_cols`` themselves are never produced.
    """
    shape = (
        np.random.randint(min_rows, max_rows),
        np.random.randint(min_cols, max_cols),
    )
    return np.random.uniform(size=shape)

# with_custom_func: the function produces the whole output itself (here with no inputs or params).
RandShaped = vbt.IF(
    output_names=['out']
).with_custom_func(custom_func)

# Different seeds presumably yield differently shaped outputs.
RandShaped.run(seed=42).out

# %%
RandShaped.run(seed=43).out

# %%
RandShaped.run(seed=44).out

# %% [markdown]
# ## Outputs
# ### Regular

# %%
def apply_func(ts, fastw, sloww, minp=None):
    """Fast/slow rolling means of ``ts`` and their crossover signals.

    Returns ``(fast_ma, slow_ma, entries, exits)``:
    - ``entries`` marks the fast MA crossing above the slow MA;
    - ``exits`` marks the slow MA crossing above the fast MA.
    """
    rolling_mean = vbt.nb.rolling_mean_nb
    crossed_above = vbt.nb.crossed_above_nb
    fast_ma, slow_ma = (rolling_mean(ts, w, minp=minp) for w in (fastw, sloww))
    entries, exits = crossed_above(fast_ma, slow_ma), crossed_above(slow_ma, fast_ma)
    return fast_ma, slow_ma, entries, exits

# Four regular outputs, matching the 4-tuple returned by `apply_func`.
CrossSig = vbt.IF(
    class_name="CrossSig",
    input_names=['ts'],
    param_names=['fastw', 'sloww'],
    output_names=['fast_ma', 'slow_ma', 'entries', 'exits']
).with_apply_func(apply_func)

ts2 = pd.DataFrame({
    'a': [1, 2, 3, 2, 1, 2, 3],
    'b': [3, 2, 1, 2, 3, 2, 1]
}, index=generate_index(7))
# fastw=2, sloww=4.
cross_sig = CrossSig.run(ts2, 2, 4)

# %%
cross_sig.output_names

# %%
cross_sig.entries

# %% [markdown]
# ### In-place

# %%
def apply_func(ts, entries, exits, fastw, sloww, minp=None):
    """Crossover indicator that writes its signals into the in-output arrays.

    ``entries`` and ``exits`` are filled in place (slice assignment, not rebinding).
    Only the two moving averages are returned as regular outputs.
    """
    ma_fast = vbt.nb.rolling_mean_nb(ts, fastw, minp=minp)
    ma_slow = vbt.nb.rolling_mean_nb(ts, sloww, minp=minp)
    entries[:] = vbt.nb.crossed_above_nb(ma_fast, ma_slow)
    exits[:] = vbt.nb.crossed_above_nb(ma_slow, ma_fast)
    return ma_fast, ma_slow

# `entries`/`exits` become in-outputs (arrays the apply function fills in place)
# allocated as booleans. Only the MAs remain regular outputs.
CrossSig = vbt.IF(
    class_name="CrossSig",
    input_names=['ts'],
    in_output_names=['entries', 'exits'],
    param_names=['fastw', 'sloww'],
    output_names=['fast_ma', 'slow_ma']
).with_apply_func(
    apply_func,
    in_output_settings=dict(
        entries=dict(dtype=np.bool_),
        exits=dict(dtype=np.bool_)
    )
)
cross_sig = CrossSig.run(ts2, 2, 4)

# %%
cross_sig.output_names

# %%
cross_sig.in_output_names

# %%
cross_sig.entries

# %% [markdown]
# #### Defaults

# %%
@njit
def apply_func_nb(signals, n):
    """Keep only the first ``n`` True signals per column, clearing the rest in place.

    ``signals`` is treated as 2-dim (rows x columns). Nothing is returned.
    """
    n_rows, n_cols = signals.shape
    for c in range(n_cols):
        seen = 0
        for r in range(n_rows):
            if not signals[r, c]:
                continue
            if seen < n:
                seen += 1
            else:
                signals[r, c] = False

# Only an in-output and a parameter: the apply function modifies `signals` and returns nothing.
FirstNSig = vbt.IF(
    class_name="FirstNSig",
    in_output_names=['signals'],
    param_names=['n']
).with_apply_func(apply_func_nb)

signals = pd.Series([False, True, True, True, False])
# Three values of `n`. `signals` is passed by keyword since `n` comes first positionally.
first_n_sig = FirstNSig.run([1, 2, 3], signals=signals)
first_n_sig.signals

# %%
# Inspect the original Series to check whether it was modified in place.
# NOTE(review): presumably the indicator works on copies.
signals

# %% [markdown]
# ### Extra

# %%
def custom_func(ts, window):
    """Rolling means for every window, plus each resulting column's maximum.

    Returns a tuple:
    - the rolling means stacked column-wise (one block of columns per window);
    - a flat array of the per-column NaN-ignoring maxima, in the same order.
    """
    ma_blocks = [vbt.nb.rolling_mean_nb(ts, w) for w in window]
    column_maxima = [np.nanmax(block, axis=0) for block in ma_blocks]
    return np.column_stack(ma_blocks), np.concatenate(column_maxima)

# The custom function returns one more value than `output_names` declares.
# `run` then returns it next to the indicator (see the tuple unpacking below).
MAMax = vbt.IF(
    class_name='MAMax',
    input_names=['ts'],
    param_names=['window'],
    output_names=['ma'],
).with_custom_func(custom_func)

ma_ind, ma_max = MAMax.run(ts2, [2, 3])
ma_ind

# %%
# Wrap the raw per-column maxima with the indicator's column metadata.
ma_ind.wrapper.wrap_reduced(ma_max)

# %% [markdown]
# ### Lazy

# %%
# Lazy output: `ma_max` is a cached property computed from `ma` on first access,
# so the apply function can be the plain rolling-mean kernel.
MAMax = vbt.IF(
    class_name='MAMax',
    input_names=['ts'],
    param_names=['window'],
    output_names=['ma'],
    lazy_outputs=dict(
        ma_max=vbt.cached_property(lambda self: self.ma.max())
    )
).with_apply_func(vbt.nb.rolling_mean_nb)

ma_ind = MAMax.run(ts2, [2, 3])
ma_ind.ma_max

# %% [markdown]
# ## Custom arguments
# ### Optional
# ### Variable

# %%
def custom_func(*arrs):
    """Sum any number of arrays element-wise.

    Returns the running sum of ``arrs``, or ``None`` if no arrays are given.

    A new array is built at each step (``out + arr``) rather than using ``+=``.
    Two reasons:
    - ``+=`` on the first argument would overwrite the caller's array in place;
    - ``+=`` raises a casting error when an integer array meets a float one.
    """
    out = None
    for arr in arrs:
        out = arr if out is None else out + arr
    return out

# var_args=True: any number of positional arrays are forwarded to the custom function.
VarArgAdder = vbt.IF(
    output_names=['out']
).with_custom_func(custom_func, var_args=True)

VarArgAdder.run(
    pd.Series([1, 2, 3]),
    pd.Series([10, 20, 30]),
    pd.Series([100, 200, 300])
).out

# %% [markdown]
# ### Positional
# ### Keyword-only

# %%
def apply_func(high, low, close):
    """Scale ``close`` into the high/low range: 0 at ``low``, 1 at ``high``."""
    distance_from_low = close - low
    return distance_from_low / (high - low)

RelClose = vbt.IF(
    input_names=['high', 'low', 'close'],
    output_names=['out']
).with_apply_func(apply_func)

# Pitfall: positional arguments bind in `input_names` order, so this is
# high=close, low=high, close=low.
RelClose.run(close, high, low).out

# %%
# keyword_only_args=True: inputs must be passed by name.
RelClose = vbt.IF(
    input_names=['high', 'low', 'close'],
    output_names=['out']
).with_apply_func(apply_func, keyword_only_args=True)

# NOTE(review): presumably raises now that positional inputs are rejected.
RelClose.run(close, high, low).out

# %%
# Order no longer matters when every input is named.
RelClose.run(close=close, high=high, low=low).out

# %% [markdown]
# ## Built-in caching

# %%
# return_raw=True returns the raw results instead of an indicator instance.
# NOTE(review): presumably raw output arrays plus parameter info.
raw = vbt.MA.run(
    ts2,
    window=[2, 2, 3],
    wtype=["simple", "simple", "exp"],
    return_raw=True)
raw

# %%
# run_unique=True: the duplicated (2, "simple") combination is presumably
# computed only once. silence_warnings presumably hides the accompanying warning.
raw = vbt.MA.run(
    ts2,
    window=[2, 2, 3],
    wtype=["simple", "simple", "exp"],
    return_raw=True,
    run_unique=True,
    silence_warnings=True)
raw

# %%
a = np.random.uniform(size=(1000,))

# 1000 identical windows, computed without deduplication...
%timeit vbt.MA.run(a, np.full(1000, 2), run_unique=False)

# %%
# ...and with deduplication of identical parameter combinations.
%timeit vbt.MA.run(a, np.full(1000, 2), run_unique=True)

# %% [markdown]
# ### Reusing cache

# %%
# Precompute the combinations (2, "simple") and (3, "exp") once.
raw = vbt.MA.run(
    ts2,
    window=[2, 3],
    wtype=["simple", "exp"],
    return_raw=True)
# (2, "simple") is among the precomputed combinations, so it is served from `raw`.
vbt.MA.run(ts2, 2, "simple", use_raw=raw).ma

# %%
# (2, "exp") was NOT precomputed.
# NOTE(review): presumably this fails or cannot be served from `raw` — verify.
vbt.MA.run(ts2, 2, "exp", use_raw=raw).ma

# %% [markdown]
# ## Manual caching

# %%
def roll_mean_expensive_nb(ts, w):
    """Simulate an expensive indicator by repeating the same rolling mean 100 times.

    Returns the rolling mean of ``ts`` over window ``w`` from the last repetition.
    Note: despite the ``_nb`` suffix this function is not decorated with ``@njit``.
    """
    result = None
    for _ in range(100):
        result = vbt.nb.rolling_mean_nb(ts, w)
    return result

def apply_func(ts, w1, w2):
    """Relative distance of the ``w2`` rolling mean from the ``w1`` rolling mean.

    Both means are recomputed on every call with the expensive helper.
    Returns ``(ma_w2 - ma_w1) / ma_w1``.
    """
    ma_w1 = roll_mean_expensive_nb(ts, w1)
    ma_w2 = roll_mean_expensive_nb(ts, w2)
    distance = ma_w2 - ma_w1
    return distance / ma_w1

RelMADist = vbt.IF(
    class_name="RelMADist",
    input_names=['ts'],
    param_names=['w1', 'w2'],
    output_names=['out'],
).with_apply_func(apply_func)

RelMADist.run(ts2, 2, 3).out

# %%
# Baseline without caching: each of the 998 combinations presumably calls
# `apply_func`, which recomputes both expensive rolling means (including w1=2 every time).
%timeit RelMADist.run(ts2, 2, np.arange(2, 1000))

# %%
def cache_func(ts, w1, w2):
    """Precompute the expensive rolling mean once for every distinct window.

    ``w1`` and ``w2`` are the collections of window values for all parameter combinations.

    Returns a dict that maps each distinct window to its rolling mean over ``ts``.
    """
    cache_dict = dict()
    # Unpack instead of ``w1 + w2``. That expression concatenates two lists, but
    # it adds element-wise for NumPy arrays and raises for a list plus a tuple.
    for w in (*w1, *w2):
        if w not in cache_dict:
            cache_dict[w] = roll_mean_expensive_nb(ts, w)
    return cache_dict

def apply_func(ts, w1, w2, cache_dict):
    """Relative MA distance using rolling means looked up in ``cache_dict``.

    ``ts`` is not used here. Returns ``(ma_w2 - ma_w1) / ma_w1``.
    """
    ma_w2 = cache_dict[w2]
    ma_w1 = cache_dict[w1]
    return (ma_w2 - ma_w1) / ma_w1

# cache_func runs once before the combinations. Its result is passed to
# `apply_func` as the extra `cache_dict` argument.
RelMADist = vbt.IF(
    class_name="RelMADist",
    input_names=['ts'],
    param_names=['w1', 'w2'],
    output_names=['out'],
).with_apply_func(apply_func, cache_func=cache_func)

RelMADist.run(ts2, 2, 3).out

# %%
# Same benchmark as before, but each distinct window is now computed only once.
%timeit RelMADist.run(ts2, 2, np.arange(2, 1000))

# %% [markdown]
# ### Per column

# %%
def cache_func(ts, w1, w2, per_column=False):
    """Precompute rolling means per distinct window, unless running per column.

    Returns ``None`` when ``per_column`` is true, so the apply function computes
    its means directly. Otherwise returns a dict that maps each distinct window
    in ``w1``/``w2`` to its rolling mean over ``ts``.
    """
    if per_column:
        return None
    cache_dict = dict()
    # Unpack instead of ``w1 + w2``. That expression concatenates two lists, but
    # it adds element-wise for NumPy arrays and raises for a list plus a tuple.
    for w in (*w1, *w2):
        if w not in cache_dict:
            cache_dict[w] = roll_mean_expensive_nb(ts, w)
    return cache_dict

def apply_func(ts, w1, w2, cache_dict=None):
    """Relative MA distance that uses the cache when available.

    If ``cache_dict`` is given, both rolling means are read from it.
    Otherwise both are computed from ``ts`` with the expensive helper.
    Returns ``(ma_w2 - ma_w1) / ma_w1``.
    """
    if cache_dict is not None:
        roll_mean1, roll_mean2 = cache_dict[w1], cache_dict[w2]
    else:
        roll_mean1 = roll_mean_expensive_nb(ts, w1)
        roll_mean2 = roll_mean_expensive_nb(ts, w2)
    return (roll_mean2 - roll_mean1) / roll_mean1


RelMADist = vbt.IF(
    class_name="RelMADist",
    input_names=['ts'],
    param_names=['w1', 'w2'],
    output_names=['out'],
).with_apply_func(apply_func, cache_func=cache_func)

RelMADist.run(ts2, 2, 3).out

# %%
# With per_column=True, `cache_func` returns None, so `apply_func` computes its
# means directly. NOTE(review): assumes the factory forwards `per_column` to cache_func.
RelMADist.run(ts2, [2, 2], [3, 4], per_column=True).out

# %% [markdown]
# ### Reusing cache

# %%
# return_cache=True returns only the cache. Windows 2..999 are all present,
# because w2 spans that whole range.
cache = RelMADist.run(
    ts2,
    w1=2,
    w2=np.arange(2, 1000),
    return_cache=True)

# Reuse the cache: every window this run needs (2..999) is already in it.
%timeit RelMADist.run( \
    ts2, \
    w1=np.arange(2, 1000), \
    w2=np.arange(2, 1000), \
    use_cache=cache)

# %% [markdown]
# ## Stacking

# %%
# Print the signature/help of the TA-Lib SMA indicator's `run` method.
vbt.phelp(vbt.talib('SMA').run)

# %%
def apply_func(close, timeperiod1, timeperiod2):
    """Stack two TA-Lib SMA indicators and derive their crossover signals.

    Returns ``(fast_ma, slow_ma, entries, exits)``:
    - the two SMA outputs (``.real``);
    - the signals where the first SMA crosses above the second;
    - the signals where it crosses below.
    """
    sma_1 = vbt.talib('SMA').run(close, timeperiod1)
    sma_2 = vbt.talib('SMA').run(close, timeperiod2)
    crossings_up = sma_1.real_crossed_above(sma_2)
    crossings_down = sma_1.real_crossed_below(sma_2)
    return sma_1.real, sma_2.real, crossings_up, crossings_down

# Indicator built on top of another indicator (TA-Lib SMA).
MACrossover = vbt.IF(
    # The class name matches the variable. "CrossSig" was a copy-paste leftover
    # that duplicated the name of an unrelated earlier indicator.
    class_name="MACrossover",
    input_names=['close'],
    param_names=['timeperiod1', 'timeperiod2'],
    output_names=['fast_ma', 'slow_ma', 'entries', 'exits'],
).with_apply_func(apply_func)

MACrossover.run(ts2, 2, 3).entries

# %%
def sma(close, timeperiod):
    """Return the raw SMA array from the TA-Lib indicator, without building an indicator instance.

    NOTE(review): ``raw[0]`` is presumably the list of output arrays, and its
    first entry the ``real`` output.
    """
    raw = vbt.talib('SMA').run(close, timeperiod, return_raw=True)
    outputs = raw[0]
    return outputs[0]

def apply_func(close, timeperiod1, timeperiod2):
    """Crossover signals computed on raw SMA arrays with the Numba kernel.

    Returns ``(fast_ma, slow_ma, entries, exits)``:
    - ``entries`` marks the fast SMA crossing above the slow one;
    - ``exits`` marks the reverse.
    """
    crossed_above = vbt.nb.crossed_above_nb
    fast_ma = sma(close, timeperiod1)
    slow_ma = sma(close, timeperiod2)
    return fast_ma, slow_ma, crossed_above(fast_ma, slow_ma), crossed_above(slow_ma, fast_ma)

# %%
# Print the signature/help of the SMA indicator's underlying apply function.
vbt.phelp(vbt.talib('SMA').apply_func)

# %%
def sma(close, timeperiod):
    """Call the SMA indicator's apply function directly, bypassing ``run``.

    Arguments are passed as three tuples.
    NOTE(review): presumably inputs, in-outputs (none here) and parameters,
    per the ``phelp`` output above.
    """
    sma_apply = vbt.talib('SMA').apply_func
    inputs, in_outputs, params = (close,), (), (timeperiod,)
    return sma_apply(inputs, in_outputs, params)

# %%