# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Pipelines

# %%
def pipeline(data, period=7, multiplier=3):
    """Backtest SuperTrend signals on `data` and return the Sharpe ratio.

    Args:
        data: Object whose ``get`` method returns the 'High', 'Low' and
            'Close' price features.
        period: SuperTrend period forwarded to ``SuperTrend.run``.
        multiplier: SuperTrend multiplier forwarded to ``SuperTrend.run``.

    Returns:
        The ``sharpe_ratio`` of the portfolio simulated from the signals.
    """
    high_price, low_price, close_price = (
        data.get(feature) for feature in ('High', 'Low', 'Close')
    )
    indicator = SuperTrend.run(
        high_price, low_price, close_price,
        period=period, multiplier=multiplier
    )

    # A signal is raised wherever the respective SuperTrend line is defined;
    # both signal masks are forward-shifted before simulation.
    long_line_defined = indicator.superl.notnull()
    short_line_defined = indicator.supers.notnull()

    sim_kwargs = dict(
        entries=long_line_defined.vbt.signals.fshift(),
        exits=short_line_defined.vbt.signals.fshift(),
        fees=0.001,
        save_returns=True,
        max_order_records=0,
        freq='1h'
    )
    portfolio = vbt.Portfolio.from_signals(close_price, **sim_kwargs)
    return portfolio.sharpe_ratio

# Run the pipeline once with its default period and multiplier.
pipeline(data)

# %%
%%timeit
# Time a single run with the default parameters.
pipeline(data)

# %%
op_tree = (product, periods, multipliers)
period_product, multiplier_product = vbt.generate_param_combs(op_tree)
period_product = np.asarray(period_product)
multiplier_product = np.asarray(multiplier_product)

%%timeit
pipeline(data, period_product, multiplier_product)

# %% [markdown]
# ### Chunked pipeline

# %%
# Wrap `pipeline` so that its parameter arrays are processed in chunks:
# - the number of elements is taken from the `period` argument,
# - `data` is passed to every chunk untouched (None), while `period` and
#   `multiplier` are sliced per chunk,
# - the per-chunk results are concatenated and sorted by index.
chunked_pipeline = vbt.chunked(
    size=vbt.LenSizer(arg_query='period', single_type=int),
    arg_take_spec={
        'data': None,
        'period': vbt.ChunkSlicer(),
        'multiplier': vbt.ChunkSlicer(),
    },
    merge_func=lambda x: pd.concat(x).sort_index(),
)(pipeline)

# %%
# Run the first four parameter combinations split into two chunks,
# asking the executor to show a progress bar.
chunked_pipeline(
    data,
    period_product[:4],
    multiplier_product[:4],
    _n_chunks=2,
    _execute_kwargs=dict(show_progress=True)
)

# %%
# Same call, but request the raw chunks instead of executing them: the result
# is unpacked into chunk metadata and the per-chunk functions/arguments.
chunk_meta, funcs_args = chunked_pipeline(
    data,
    period_product[:4],
    multiplier_product[:4],
    _n_chunks=2,
    _return_raw_chunks=True
)

chunk_meta

# %%
# Materialize the per-chunk calls for inspection.
list(funcs_args)

# %%
%%timeit
# Time the chunked pipeline over all combinations with default chunking.
chunked_pipeline(data, period_product, multiplier_product)

# %%
%%timeit
# Time the chunked pipeline with one parameter combination per chunk.
chunked_pipeline(data, period_product, multiplier_product, _chunk_len=1)

# %% [markdown]
# ## Numba pipeline

# %%
@njit(nogil=True)
def pipeline_nb(high, low, close,
                periods=np.asarray([7]),
                multipliers=np.asarray([3]),
                ann_factor=365):
    """Compute the Sharpe ratio of a SuperTrend strategy per parameter pair and column.

    Args:
        high, low, close: Two-dimensional price arrays (rows x columns);
            each column is processed independently.
        periods: One-dimensional array of SuperTrend periods.
        multipliers: One-dimensional array of SuperTrend multipliers,
            paired element-wise with `periods`.
        ann_factor: Annualization factor passed to the Sharpe calculation.

    Returns:
        One-dimensional float array of length ``periods.size * close.shape[1]``;
        the values of each parameter pair occupy a contiguous run of
        ``close.shape[1]`` elements, in column order.
    """
    n_cols = close.shape[1]

    # np.float64 instead of np.float_: the alias was removed in NumPy 2.0.
    sharpe = np.empty(periods.size * n_cols, dtype=np.float64)
    # Signal buffers are reused across parameter pairs; every column is fully
    # overwritten before each simulation.
    long_entries = np.empty(close.shape, dtype=np.bool_)
    long_exits = np.empty(close.shape, dtype=np.bool_)
    # One column per group.
    group_lens = np.full(n_cols, 1)
    init_cash = 100.
    fees = 0.001
    k = 0  # write offset into `sharpe`

    for i in range(periods.size):
        for col in range(n_cols):
            _, _, superl, supers = superfast_supertrend_nb(
                high[:, col],
                low[:, col],
                close[:, col],
                periods[i],
                multipliers[i]
            )
            # Signal where the SuperTrend line is defined, shifted forward
            # and padded with False.
            long_entries[:, col] = vbt.nb.fshift_1d_nb(
                ~np.isnan(superl),
                fill_value=False
            )
            long_exits[:, col] = vbt.nb.fshift_1d_nb(
                ~np.isnan(supers),
                fill_value=False
            )

        sim_out = vbt.pf_nb.from_signals_nb(
            target_shape=close.shape,
            group_lens=group_lens,
            init_cash=init_cash,
            high=high,
            low=low,
            close=close,
            long_entries=long_entries,
            long_exits=long_exits,
            fees=fees,
            save_returns=True
        )
        returns = sim_out.in_outputs.returns
        _sharpe = vbt.ret_nb.sharpe_ratio_nb(returns, ann_factor, ddof=1)
        sharpe[k:k + n_cols] = _sharpe
        k += n_cols

    return sharpe

# %%
# Annualization factor for hourly data, then a run with the default
# period/multiplier arrays on the raw NumPy price arrays.
ann_factor = vbt.pd_acc.returns.get_ann_factor(freq='1h')
pipeline_nb(
    high.values,
    low.values,
    close.values,
    ann_factor=ann_factor
)

# %%
%%timeit
# Time the Numba pipeline with the default parameters.
pipeline_nb(
    high.values,
    low.values,
    close.values,
    ann_factor=ann_factor
)

# %%
def merge_func(arrs, ann_args, input_columns):
    """Concatenate per-chunk arrays into one Series with a labeled index.

    Args:
        arrs: Sequence of arrays, one per chunk.
        ann_args: Mapping of annotated arguments; the ``'value'`` entries of
            ``'periods'`` and ``'multipliers'`` supply the parameter values.
        input_columns: Column index of the input data.

    Returns:
        A Series whose index combines the stacked parameter levels
        ('st_period', 'st_multiplier') with `input_columns`.
    """
    flat_values = np.concatenate(arrs)

    period_level = pd.Index(ann_args['periods']['value'], name='st_period')
    multiplier_level = pd.Index(
        ann_args['multipliers']['value'], name='st_multiplier'
    )
    stacked_params = vbt.stack_indexes((period_level, multiplier_level))
    full_index = vbt.combine_indexes((stacked_params, input_columns))

    return pd.Series(flat_values, index=full_index)

# Reusable chunking decorator for the Numba pipelines:
# - the number of elements is the length of `periods` along axis 0,
# - price arrays and `ann_factor` go to every chunk untouched (None),
#   `periods` and `multipliers` are sliced along axis 0,
# - results are merged by `merge_func`, which receives the annotated
#   arguments through the "ann_args" template.
nb_chunked = vbt.chunked(
    size=vbt.ArraySizer(arg_query='periods', axis=0),
    arg_take_spec={
        'high': None,
        'low': None,
        'close': None,
        'periods': vbt.ArraySlicer(axis=0),
        'multipliers': vbt.ArraySlicer(axis=0),
        'ann_factor': None,
    },
    merge_func=merge_func,
    merge_kwargs={'ann_args': vbt.Rep("ann_args")},
)
chunked_pipeline_nb = nb_chunked(pipeline_nb)

# %%
# Run the first four parameter combinations in two chunks with a progress
# bar; the input columns are handed to the merge function to label the result.
chunked_pipeline_nb(
    high.values,
    low.values,
    close.values,
    periods=period_product[:4],
    multipliers=multiplier_product[:4],
    ann_factor=ann_factor,
    _n_chunks=2,
    _execute_kwargs=dict(show_progress=True),
    _merge_kwargs=dict(input_columns=close.columns)
)

# %%
%%timeit
# Time all parameter combinations with the default execution settings.
chunked_pipeline_nb(
    high.values,
    low.values,
    close.values,
    periods=period_product,
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _merge_kwargs=dict(input_columns=close.columns)
)

# %%
%%timeit
# Time all parameter combinations using the "dask" execution engine.
chunked_pipeline_nb(
    high.values,
    low.values,
    close.values,
    periods=period_product,
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _execute_kwargs=dict(engine="dask"),
    _merge_kwargs=dict(input_columns=close.columns)
)

# %% [markdown]
# ## Contextualized pipeline
# ### Streaming Sharpe

# %%
class RollSharpeAIS(tp.NamedTuple):
    """Input state of one step of the rolling Sharpe accumulator (`rolling_sharpe_acc_nb`)."""
    # Index of the current step
    i: int
    # Return at the current step
    ret: float
    # Return that drops out of the window at this step (NaN if none)
    pre_window_ret: float
    # Running sum carried over from the previous step
    cumsum: float
    # Running sum of squares carried over from the previous step
    cumsum_sq: float
    # Running NaN counter carried over from the previous step
    nancnt: int
    # Window length
    window: int
    # Minimum-periods setting forwarded to the rolling mean/std accumulators
    minp: tp.Optional[int]
    # Delta degrees of freedom forwarded to the rolling std accumulator
    ddof: int
    # Annualization factor; its square root scales the mean/std ratio
    ann_factor: float

class RollSharpeAOS(tp.NamedTuple):
    """Output state of one step of the rolling Sharpe accumulator (`rolling_sharpe_acc_nb`)."""
    # Updated running sum, to be fed into the next step
    cumsum: float
    # Updated running sum of squares, to be fed into the next step
    cumsum_sq: float
    # Updated NaN counter, to be fed into the next step
    nancnt: int
    # Sharpe ratio at the current step (NaN when the rolling std is zero)
    value: float

@njit(nogil=True)
def rolling_sharpe_acc_nb(in_state):
    """Advance the rolling Sharpe ratio by one step.

    Feeds the current return into the rolling mean and rolling standard
    deviation accumulators and combines their values into an annualized
    Sharpe ratio.

    Args:
        in_state: `RollSharpeAIS` holding the current return and the
            running statistics from the previous step.

    Returns:
        `RollSharpeAOS` with the running statistics reported by the std
        accumulator and the Sharpe value (NaN when the std is zero).
    """
    # Rolling mean of returns
    mean_state = vbt.nb.rolling_mean_acc_nb(vbt.nb.RollMeanAIS(
        i=in_state.i,
        value=in_state.ret,
        pre_window_value=in_state.pre_window_ret,
        cumsum=in_state.cumsum,
        nancnt=in_state.nancnt,
        window=in_state.window,
        minp=in_state.minp
    ))

    # Rolling standard deviation of returns
    std_state = vbt.nb.rolling_std_acc_nb(vbt.nb.RollStdAIS(
        i=in_state.i,
        value=in_state.ret,
        pre_window_value=in_state.pre_window_ret,
        cumsum=in_state.cumsum,
        cumsum_sq=in_state.cumsum_sq,
        nancnt=in_state.nancnt,
        window=in_state.window,
        minp=in_state.minp,
        ddof=in_state.ddof
    ))

    # Annualized mean-to-std ratio; undefined for a zero std
    rolling_std = std_state.value
    if rolling_std != 0:
        sharpe_value = (
            mean_state.value / rolling_std * np.sqrt(in_state.ann_factor)
        )
    else:
        sharpe_value = np.nan

    return RollSharpeAOS(
        cumsum=std_state.cumsum,
        cumsum_sq=std_state.cumsum_sq,
        nancnt=std_state.nancnt,
        value=sharpe_value
    )

# %%
@njit(nogil=True)
def rolling_sharpe_ratio_nb(returns, window, minp=None, ddof=0, ann_factor=365):
    """Compute the rolling Sharpe ratio of a one-dimensional returns array.

    Args:
        returns: Array of returns, iterated along its first axis.
        window: Window length; None means the full length of `returns`.
        minp: Minimum-periods setting; None means `window`.
        ddof: Delta degrees of freedom for the standard deviation.
        ann_factor: Annualization factor.

    Returns:
        Float array with the shape of `returns` holding the Sharpe ratio at
        each step (an empty array for empty input).
    """
    if window is None:
        window = returns.shape[0]
    if minp is None:
        minp = window
    # np.float64 instead of np.float_: the alias was removed in NumPy 2.0.
    out = np.empty(returns.shape, dtype=np.float64)

    if returns.shape[0] == 0:
        return out

    # Running statistics threaded through the accumulator
    cumsum = 0.
    cumsum_sq = 0.
    nancnt = 0

    for i in range(returns.shape[0]):
        in_state = RollSharpeAIS(
            i=i,
            ret=returns[i],
            # Value leaving the window at this step, if the window is full
            pre_window_ret=returns[i - window] if i - window >= 0 else np.nan,
            cumsum=cumsum,
            cumsum_sq=cumsum_sq,
            nancnt=nancnt,
            window=window,
            minp=minp,
            ddof=ddof,
            ann_factor=ann_factor
        )

        out_state = rolling_sharpe_acc_nb(in_state)

        cumsum = out_state.cumsum
        cumsum_sq = out_state.cumsum_sq
        nancnt = out_state.nancnt
        out[i] = out_state.value

    return out

# Annualization factor for hourly data.
ann_factor = vbt.pd_acc.returns.get_ann_factor(freq='1h')

# Returns of a single symbol, used as the test input.
returns = close['BTCUSDT'].vbt.to_returns()

# Validate the streaming implementation against the library's own rolling
# Sharpe ratio over a window of 10.
np.testing.assert_allclose(
    rolling_sharpe_ratio_nb(
        returns=returns.values,
        window=10,
        ddof=1,
        ann_factor=ann_factor),
    returns.vbt.returns(freq='1h').rolling_sharpe_ratio(10).values
)

# %% [markdown]
# ### Callbacks

# %%
class Memory(tp.NamedTuple):
    """Per-column state carried between simulation callbacks.

    Each field is a one-dimensional array indexed by column.
    """
    # SuperTrend accumulator state, read and written by `order_func_nb`
    nobs: tp.Array1d
    old_wt: tp.Array1d
    weighted_avg: tp.Array1d
    prev_upper: tp.Array1d
    prev_lower: tp.Array1d
    prev_dir_: tp.Array1d
    # Rolling Sharpe accumulator state, read and written by `post_segment_func_nb`
    cumsum: tp.Array1d
    cumsum_sq: tp.Array1d
    nancnt: tp.Array1d
    # Signals detected at the latest `order_func_nb` call, acted on at the next one
    was_entry: tp.Array1d
    was_exit: tp.Array1d

@njit(nogil=True)
def pre_sim_func_nb(c):
    """Allocate the per-column `Memory` before the simulation starts.

    Args:
        c: Simulation context; only ``c.target_shape[1]`` (the number of
            columns) is read.

    Returns:
        A one-element tuple containing the initialized `Memory`.
    """
    n_cols = c.target_shape[1]
    # np.float64 instead of np.float_: the alias was removed in NumPy 2.0.
    memory = Memory(
        nobs=np.full(n_cols, 0, dtype=np.int_),
        old_wt=np.full(n_cols, 1., dtype=np.float64),
        weighted_avg=np.full(n_cols, np.nan, dtype=np.float64),
        prev_upper=np.full(n_cols, np.nan, dtype=np.float64),
        prev_lower=np.full(n_cols, np.nan, dtype=np.float64),
        prev_dir_=np.full(n_cols, np.nan, dtype=np.float64),
        cumsum=np.full(n_cols, 0., dtype=np.float64),
        cumsum_sq=np.full(n_cols, 0., dtype=np.float64),
        nancnt=np.full(n_cols, 0, dtype=np.int_),
        was_entry=np.full(n_cols, False, dtype=np.bool_),
        was_exit=np.full(n_cols, False, dtype=np.bool_)
    )
    return (memory,)

# %%
@njit(nogil=True)
def order_func_nb(c, memory, period, multiplier):
    """Advance SuperTrend by one bar and emit an order from the previous bar's signal.

    Args:
        c: Order context providing the current row ``i``, column ``col``,
            price arrays and the current position.
        memory: `Memory` with per-column state; updated in place.
        period: SuperTrend period.
        multiplier: SuperTrend multiplier.

    Returns:
        A long-only order with 0.1% fees: buy everything on an entry signal
        when flat, sell everything on an exit signal when in position,
        otherwise an order with NaN size.
    """
    row = c.i
    col = c.col

    # Signals stored by the previous call; read them before they are
    # overwritten below.
    enter_now = memory.was_entry[col]
    exit_now = memory.was_exit[col]

    if row > 0:
        last_close = c.close[row - 1, col]
    else:
        last_close = np.nan

    # One SuperTrend step on the current bar
    st_state = superfast_supertrend_acc_nb(SuperTrendAIS(
        i=row,
        high=c.high[row, col],
        low=c.low[row, col],
        close=c.close[row, col],
        prev_close=last_close,
        prev_upper=memory.prev_upper[col],
        prev_lower=memory.prev_lower[col],
        prev_dir_=memory.prev_dir_[col],
        nobs=memory.nobs[col],
        weighted_avg=memory.weighted_avg[col],
        old_wt=memory.old_wt[col],
        period=period,
        multiplier=multiplier
    ))

    # Persist the new state and the freshly detected signals
    memory.nobs[col] = st_state.nobs
    memory.weighted_avg[col] = st_state.weighted_avg
    memory.old_wt[col] = st_state.old_wt
    memory.prev_upper[col] = st_state.upper
    memory.prev_lower[col] = st_state.lower
    memory.prev_dir_[col] = st_state.dir_
    memory.was_entry[col] = not np.isnan(st_state.long)
    memory.was_exit[col] = not np.isnan(st_state.short)

    # Translate the previous bar's signal into an order size
    holds_position = c.position_now > 0
    size = np.nan
    if enter_now and not holds_position:
        size = np.inf
    elif exit_now and holds_position:
        size = -np.inf

    return vbt.pf_nb.order_nb(
        size=size,
        direction=vbt.pf_enums.Direction.LongOnly,
        fees=0.001
    )

# %%
@njit(nogil=True)
def post_segment_func_nb(c, memory, ann_factor):
    """Update the running Sharpe ratio of every column in the current segment.

    The window grows by one with each row and no value ever leaves it, so
    the statistic covers all returns seen so far.

    Args:
        c: Post-segment context providing the row ``i``, the column range
            ``from_col``..``to_col``, ``last_return`` and ``in_outputs``.
        memory: `Memory` with per-column running statistics; updated in place.
        ann_factor: Annualization factor.
    """
    for col in range(c.from_col, c.to_col):
        sharpe_state = rolling_sharpe_acc_nb(RollSharpeAIS(
            i=c.i,
            ret=c.last_return[col],
            pre_window_ret=np.nan,
            cumsum=memory.cumsum[col],
            cumsum_sq=memory.cumsum_sq[col],
            nancnt=memory.nancnt[col],
            window=c.i + 1,
            minp=0,
            ddof=1,
            ann_factor=ann_factor
        ))

        # Carry the statistics over to the next row and publish the value
        memory.cumsum[col] = sharpe_state.cumsum
        memory.cumsum_sq[col] = sharpe_state.cumsum_sq
        memory.nancnt[col] = sharpe_state.nancnt
        c.in_outputs.sharpe[col] = sharpe_state.value

# %% [markdown]
# ### Pipeline

# %%
class InOutputs(tp.NamedTuple):
    """In-place outputs filled during simulation."""
    # Latest Sharpe ratio per column, written by `post_segment_func_nb`
    sharpe: tp.Array1d

@njit(nogil=True)
def ctx_pipeline_nb(high, low, close,
                    periods=np.asarray([7]),
                    multipliers=np.asarray([3]),
                    ann_factor=365):
    """Compute Sharpe ratios with a callback-based simulation per parameter pair.

    SuperTrend and the Sharpe ratio are both computed inside the simulation
    callbacks (`order_func_nb`, `post_segment_func_nb`).

    Args:
        high, low, close: Two-dimensional price arrays (rows x columns).
        periods: One-dimensional array of SuperTrend periods.
        multipliers: One-dimensional array of SuperTrend multipliers,
            paired element-wise with `periods`.
        ann_factor: Annualization factor.

    Returns:
        One-dimensional float array of length ``periods.size * close.shape[1]``;
        the values of each parameter pair occupy a contiguous run of
        ``close.shape[1]`` elements, in column order.
    """
    n_cols = close.shape[1]

    # np.float64 instead of np.float_: the alias was removed in NumPy 2.0.
    # The same in-output buffer is reused by every simulation; its content is
    # copied into `sharpe` after each run.
    in_outputs = InOutputs(sharpe=np.empty(n_cols, dtype=np.float64))
    sharpe = np.empty(periods.size * n_cols, dtype=np.float64)
    # One column per group.
    group_lens = np.full(n_cols, 1)
    init_cash = 100.
    k = 0  # write offset into `sharpe`

    for i in range(periods.size):
        sim_out = vbt.pf_nb.from_order_func_nb(
            target_shape=close.shape,
            group_lens=group_lens,
            cash_sharing=False,
            init_cash=init_cash,
            pre_sim_func_nb=pre_sim_func_nb,
            order_func_nb=order_func_nb,
            order_args=(periods[i], multipliers[i]),
            post_segment_func_nb=post_segment_func_nb,
            post_segment_args=(ann_factor,),
            high=high,
            low=low,
            close=close,
            in_outputs=in_outputs,
            fill_pos_info=False,
            max_order_records=0
        )
        sharpe[k:k + n_cols] = in_outputs.sharpe
        k += n_cols

    return sharpe

# %%
# Run the contextualized pipeline with the default period/multiplier arrays.
ctx_pipeline_nb(
    high.values,
    low.values,
    close.values,
    ann_factor=ann_factor
)

# %%
# Wrap it with the shared chunking decorator and run the first four parameter
# combinations in two chunks with a progress bar.
chunked_ctx_pipeline_nb = nb_chunked(ctx_pipeline_nb)
chunked_ctx_pipeline_nb(
    high.values,
    low.values,
    close.values,
    periods=period_product[:4],
    multipliers=multiplier_product[:4],
    ann_factor=ann_factor,
    _n_chunks=2,
    _execute_kwargs=dict(show_progress=True),
    _merge_kwargs=dict(input_columns=close.columns)
)

# %%
%%timeit
# Time all parameter combinations with the default execution settings.
chunked_ctx_pipeline_nb(
    high.values,
    low.values,
    close.values,
    periods=period_product,
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _merge_kwargs=dict(input_columns=close.columns)
)

# %%
%%timeit
# Time all parameter combinations using the "dask" execution engine.
chunked_ctx_pipeline_nb(
    high.values,
    low.values,
    close.values,
    periods=period_product,
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _execute_kwargs=dict(engine="dask"),
    _merge_kwargs=dict(input_columns=close.columns)
)

# %% [markdown]
# ## Bonus: Own simulator

# %%
@njit(nogil=True)
def raw_pipeline_nb(high, low, close,
                    periods=np.array([7]),
                    multipliers=np.array([3]),
                    ann_factor=365):
    """Compute Sharpe ratios with a hand-written simulation loop.

    For every parameter pair and column, a single pass over the rows updates
    SuperTrend, executes the order implied by the previous bar's signal and
    updates the running Sharpe ratio.

    Args:
        high, low, close: Two-dimensional price arrays (rows x columns).
        periods: One-dimensional array of SuperTrend periods.
        multipliers: One-dimensional array of SuperTrend multipliers,
            paired element-wise with `periods`.
        ann_factor: Annualization factor.

    Returns:
        One-dimensional float array of length ``periods.size * close.shape[1]``
        where element ``k * close.shape[1] + col`` is the final Sharpe ratio
        of parameter pair ``k`` on column ``col``. With zero rows the array
        is returned without being filled.
    """
    # np.float64 instead of np.float_: the alias was removed in NumPy 2.0.
    out = np.empty(periods.size * close.shape[1], dtype=np.float64)

    if close.shape[0] == 0:
        return out

    for k in range(len(periods)):

        for col in range(close.shape[1]):

            # SuperTrend and rolling Sharpe state, reset for each run
            nobs = 0
            old_wt = 1.
            weighted_avg = np.nan
            prev_close_ = np.nan
            prev_upper = np.nan
            prev_lower = np.nan
            prev_dir_ = 1
            cumsum = 0.
            cumsum_sq = 0.
            nancnt = 0
            was_entry = False
            was_exit = False

            # Account state, reset for each run
            init_cash = 100.
            cash = init_cash
            position = 0.
            debt = 0.
            locked_cash = 0.
            free_cash = init_cash
            val_price = np.nan
            value = init_cash
            prev_value = init_cash
            return_ = 0.

            for i in range(close.shape[0]):

                # Act on the signals detected at the previous bar
                is_entry = was_entry
                is_exit = was_exit

                st_in_state = SuperTrendAIS(
                    i=i,
                    high=high[i, col],
                    low=low[i, col],
                    close=close[i, col],
                    prev_close=prev_close_,
                    prev_upper=prev_upper,
                    prev_lower=prev_lower,
                    prev_dir_=prev_dir_,
                    nobs=nobs,
                    weighted_avg=weighted_avg,
                    old_wt=old_wt,
                    period=periods[k],
                    multiplier=multipliers[k]
                )

                st_out_state = superfast_supertrend_acc_nb(st_in_state)

                # Carry the SuperTrend state and new signals to the next bar
                nobs = st_out_state.nobs
                weighted_avg = st_out_state.weighted_avg
                old_wt = st_out_state.old_wt
                prev_close_ = close[i, col]
                prev_upper = st_out_state.upper
                prev_lower = st_out_state.lower
                prev_dir_ = st_out_state.dir_
                was_entry = not np.isnan(st_out_state.long)
                was_exit = not np.isnan(st_out_state.short)

                # Buy everything when flat, sell everything when long,
                # otherwise no order (NaN size)
                if is_entry and position == 0:
                    size = np.inf
                elif is_exit and position > 0:
                    size = -np.inf
                else:
                    size = np.nan

                # Value the account at the current close before ordering
                val_price = close[i, col]
                value = cash + position * val_price
                if not np.isnan(size):
                    exec_state = vbt.pf_enums.ExecState(
                        cash=cash,
                        position=position,
                        debt=debt,
                        locked_cash=locked_cash,
                        free_cash=free_cash,
                        val_price=val_price,
                        value=value
                    )
                    price_area = vbt.pf_enums.PriceArea(
                        open=np.nan,
                        high=high[i, col],
                        low=low[i, col],
                        close=close[i, col]
                    )
                    order = vbt.pf_nb.order_nb(
                        size=size,
                        direction=vbt.pf_enums.Direction.LongOnly,
                        fees=0.001
                    )
                    _, new_exec_state = vbt.pf_nb.execute_order_nb(
                        exec_state, order, price_area)
                    cash, position, debt, locked_cash, free_cash, val_price, value = new_exec_state

                # Re-value after the (possible) order and derive the bar's return
                value = cash + position * val_price
                return_ = vbt.ret_nb.get_return_nb(prev_value, value)
                prev_value = value

                # Running Sharpe over all returns so far: the window grows
                # with each bar and no value leaves it
                sharpe_in_state = RollSharpeAIS(
                    i=i,
                    ret=return_,
                    pre_window_ret=np.nan,
                    cumsum=cumsum,
                    cumsum_sq=cumsum_sq,
                    nancnt=nancnt,
                    window=i + 1,
                    minp=0,
                    ddof=1,
                    ann_factor=ann_factor
                )
                sharpe_out_state = rolling_sharpe_acc_nb(sharpe_in_state)
                cumsum = sharpe_out_state.cumsum
                cumsum_sq = sharpe_out_state.cumsum_sq
                nancnt = sharpe_out_state.nancnt
                sharpe = sharpe_out_state.value

            # Keep only the value from the last bar
            out[k * close.shape[1] + col] = sharpe

    return out

# %%
chunked_raw_pipeline_nb = nb_chunked(raw_pipeline_nb)

%%timeit
chunked_raw_pipeline_nb(
    high.values,
    low.values,
    close.values,
    periods=period_product,
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _merge_kwargs=dict(input_columns=close.columns)
)

# %%
%%timeit
# Time all parameter combinations using the "dask" execution engine.
chunked_raw_pipeline_nb(
    high.values,
    low.values,
    close.values,
    periods=period_product,
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _execute_kwargs=dict(engine="dask"),
    _merge_kwargs=dict(input_columns=close.columns)
)

# %%
# Number of 1-hour bars in 365 days; used as the length of each of the
# 100 rolling ranges built over the price index.
range_len = int(pd.Timedelta('365d') / pd.Timedelta('1h'))
splitter = vbt.Splitter.from_n_rolling(
    high.index,
    n=100,
    length=range_len
)

# Take the ranges from each price frame.
# NOTE(review): "reset_stacked" presumably stacks the ranges as columns with a
# reset index — confirm against the Splitter documentation.
roll_high = splitter.take(high, into="reset_stacked")
roll_low = splitter.take(low, into="reset_stacked")
roll_close = splitter.take(close, into="reset_stacked")
roll_close.columns

# %%
# Keep the original index of every range (used later for plot titles).
range_indexes = splitter.take(high.index)
range_indexes[0]

# %%
# Sharpe ratio of every parameter combination on every range and symbol.
sharpe_ratios = chunked_raw_pipeline_nb(
    roll_high.values,
    roll_low.values,
    roll_close.values,
    periods=period_product,
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _execute_kwargs=dict(engine="dask"),
    _merge_kwargs=dict(input_columns=roll_close.columns)
)

sharpe_ratios

# %%
# Sharpe ratio of simply holding, per range and symbol, as a baseline.
pf_hold = vbt.Portfolio.from_holding(roll_close, freq='1h')
sharpe_ratios_hold = pf_hold.sharpe_ratio

sharpe_ratios_hold

# %%
def plot_subperiod_sharpe(index,
                          sharpe_ratios,
                          sharpe_ratios_hold,
                          range_indexes,
                          symbol):
    """Plot a period/multiplier heatmap of Sharpe ratios for one split.

    Args:
        index: Sequence whose first element is the split to plot.
        sharpe_ratios: Series of strategy Sharpe ratios with the index
            levels 'symbol', 'split', 'st_period' and 'st_multiplier'.
        sharpe_ratios_hold: Baseline Sharpe ratios keyed by (split, symbol);
            the selected value becomes the midpoint of the color scale.
        range_indexes: Per-split datetime indexes; the first and last
            timestamps of the split form the title.
        symbol: Symbol to select.

    Returns:
        The heatmap figure.
    """
    split = index[0]

    # Narrow down to the requested symbol, then to the requested split
    sub_sharpe = (
        sharpe_ratios
        .xs(symbol, level='symbol', drop_level=True)
        .xs(split, level='split', drop_level=True)
    )

    split_range = range_indexes[split]
    time_fmt = "%d %b, %Y %H:%M:%S"
    title = "{} - {}".format(
        split_range[0].strftime(time_fmt),
        split_range[-1].strftime(time_fmt)
    )

    # Color scale spans the observed values and is centered on the baseline
    color_settings = dict(
        zmin=sub_sharpe.min(),
        zmid=sharpe_ratios_hold[(split, symbol)],
        zmax=sub_sharpe.max(),
        colorscale='Spectral'
    )
    return sub_sharpe.vbt.heatmap(
        x_level='st_period',
        y_level='st_multiplier',
        title=title,
        trace_kwargs=color_settings
    )

# %%
# Output file and the unique split labels taken from the 'split' index level.
fname = 'raw_pipeline.gif'
level_idx = sharpe_ratios.index.names.index('split')
split_indices = sharpe_ratios.index.levels[level_idx]

# Render one frame per split with `plot_subperiod_sharpe`; the positional
# arguments after the plotting function are forwarded to it.
# NOTE(review): `delta=1` presumably passes one split per frame and `loop=0`
# presumably makes the GIF loop forever — confirm against the API docs.
vbt.save_animation(
    fname,
    split_indices,
    plot_subperiod_sharpe,
    sharpe_ratios,
    sharpe_ratios_hold,
    range_indexes,
    'BTCUSDT',
    delta=1,
    fps=7,
    writer_kwargs=dict(loop=0)
)

# %%
from IPython.display import Image, display

# Read the saved animation and embed it in the notebook output.
# NOTE(review): the file is a GIF but is declared as format='png' — possibly a
# deliberate workaround for front ends that do not render image/gif; confirm.
with open(fname,'rb') as f:
    display(Image(data=f.read(), format='png'))

# %%