Files
strategy-lab/to_explore/notebooks/SuperTrend.ipynb

68 KiB

SuperFast SuperTrend

Data

In [ ]:
from vectorbtpro import *
# whats_imported()
# NOTE(review): this star import is what provides `vbt`, `np`, `pd`, `tp`,
# `njit` and `product` to the rest of the notebook (none are imported
# explicitly anywhere) — confirm against vectorbtpro's exported names.

vbt.settings.set_theme('dark')
In [ ]:
# One-time download (network required): two years of hourly BTC/ETH candles
# from Binance. Uncomment together with the next cell to rebuild the cache.
# data = vbt.BinanceData.pull(
#     ['BTCUSDT', 'ETHUSDT'], 
#     start='2020-01-01 UTC',
#     end='2022-01-01 UTC',
#     timeframe='1h'
# )
In [ ]:
# Persist the downloaded data so later runs do not hit the exchange.
# data.to_hdf('my_data.h5')
In [ ]:
# Load the cached candles; requires 'my_data.h5' in the working directory.
data = vbt.HDFData.pull('my_data.h5')
In [ ]:
data.data['BTCUSDT'].info()
In [ ]:
data.stats()
In [ ]:
# One DataFrame per price field, one column per symbol.
high = data.get('High')
low = data.get('Low')
close = data.get('Close')
In [ ]:
print(close)

Design

Pandas

In [ ]:
def get_med_price(high, low):
    """Return the median (mid) price of each bar, i.e. the mean of high and low.

    Works element-wise on scalars, NumPy arrays and pandas Series alike.
    """
    return 0.5 * (high + low)
In [ ]:
def get_atr(high, low, close, period):
    """Average true range using Wilder smoothing, in pure pandas.

    Args:
        high, low, close: Price Series sharing the same index.
        period: Smoothing period; the EWM uses ``alpha = 1 / period`` and
            yields NaN until ``period`` observations are available.

    Returns:
        pd.Series of ATR values indexed like the inputs.

    The row-wise max skips NaN, so the first bar's true range (which has no
    previous close) falls back to ``high - low``.
    """
    prev_close = close.shift()
    true_range = pd.DataFrame({
        'hl': (high - low).abs(),
        'hc': (high - prev_close).abs(),
        'lc': (low - prev_close).abs(),
    }).max(axis=1)
    return true_range.ewm(alpha=1 / period, adjust=False, min_periods=period).mean()
In [ ]:
def get_basic_bands(med_price, atr, multiplier):
    """Return the basic (un-ratcheted) SuperTrend bands.

    The bands sit ``multiplier * atr`` above and below the median price.

    Returns:
        Tuple ``(upper, lower)`` with the same type/shape as the inputs.
    """
    band_width = atr * multiplier
    return med_price + band_width, med_price - band_width
In [ ]:
def get_final_bands(close, upper, lower):
    """Compute the final SuperTrend bands with a bar-by-bar pandas loop.

    Direction flips to long (``1``) when ``close`` closes above the previous
    upper band and to short (``-1``) when it closes below the previous lower
    band. While the direction is unchanged, the active band is ratcheted so
    it can only tighten: the lower band never decreases while long, and the
    upper band never increases while short.

    Args:
        close: Close prices (``pd.Series``).
        upper: Basic upper band aligned with ``close`` (``pd.Series``).
        lower: Basic lower band aligned with ``close`` (``pd.Series``).

    Returns:
        Tuple ``(trend, dir_, long, short)`` of Series indexed like ``close``:
        the active band, the integer direction, and the active band split
        into a long-only and a short-only series (NaN elsewhere).

    The caller's ``upper`` and ``lower`` are left untouched; ratcheting is
    applied to private copies.
    """
    # Work on copies: the ratchet below overwrites band values, which would
    # otherwise leak back into the caller's Series.
    upper = upper.copy()
    lower = lower.copy()

    trend = pd.Series(np.full(close.shape, np.nan), index=close.index)
    dir_ = pd.Series(np.full(close.shape, 1), index=close.index)
    long = pd.Series(np.full(close.shape, np.nan), index=close.index)
    short = pd.Series(np.full(close.shape, np.nan), index=close.index)

    # Bar 0 has no previous band to compare against, so it keeps the
    # defaults (NaN bands, direction 1).
    for i in range(1, close.shape[0]):
        if close.iloc[i] > upper.iloc[i - 1]:
            dir_.iloc[i] = 1
        elif close.iloc[i] < lower.iloc[i - 1]:
            dir_.iloc[i] = -1
        else:
            # No breakout: keep the direction and only let the active band tighten.
            dir_.iloc[i] = dir_.iloc[i - 1]
            if dir_.iloc[i] > 0 and lower.iloc[i] < lower.iloc[i - 1]:
                lower.iloc[i] = lower.iloc[i - 1]
            if dir_.iloc[i] < 0 and upper.iloc[i] > upper.iloc[i - 1]:
                upper.iloc[i] = upper.iloc[i - 1]

        if dir_.iloc[i] > 0:
            trend.iloc[i] = long.iloc[i] = lower.iloc[i]
        else:
            trend.iloc[i] = short.iloc[i] = upper.iloc[i]

    return trend, dir_, long, short
In [ ]:
def supertrend(high, low, close, period=7, multiplier=3):
    """Pure-pandas SuperTrend for a single price series.

    Builds the basic bands from the median price and the ATR, then applies
    the direction/ratchet loop.

    Returns:
        Tuple ``(trend, dir_, long, short)`` of Series (see `get_final_bands`).
    """
    upper, lower = get_basic_bands(
        get_med_price(high, low),
        get_atr(high, low, close, period),
        multiplier,
    )
    return get_final_bands(close, upper, lower)
In [ ]:
# Run the pure-pandas implementation on a single symbol.
supert, superd, superl, supers = supertrend(
    high['BTCUSDT'], 
    low['BTCUSDT'], 
    close['BTCUSDT']
)
In [ ]:
supert
In [ ]:
superd
In [ ]:
superl
In [ ]:
supers
In [ ]:
# Visual check over January 2020: close price with the short and long bands.
date_range = slice('2020-01-01', '2020-02-01')
fig = close.loc[date_range, 'BTCUSDT'].rename('Close').vbt.plot()
supers.loc[date_range].rename('Short').vbt.plot(fig=fig)
superl.loc[date_range].rename('Long').vbt.plot(fig=fig).show_svg()
In [ ]:
%%timeit
# Baseline timing of the pandas loop (scalar .iloc access on every bar).
supertrend(high['BTCUSDT'], low['BTCUSDT'], close['BTCUSDT'])
In [ ]:
# pandas-ta's SUPERTREND wrapped by vectorbt, used as a speed reference.
SUPERTREND = vbt.pandas_ta('SUPERTREND')
In [ ]:
%%timeit
SUPERTREND.run(high['BTCUSDT'], low['BTCUSDT'], close['BTCUSDT'])

NumPy + Numba

In [ ]:
def get_atr_np(high, low, close, period):
    """NumPy/Numba version of the average true range (Wilder smoothing).

    Args:
        high, low, close: 1D NumPy arrays of equal length.
        period: Smoothing period passed to `vbt.nb.wwm_mean_1d_nb`.

    Returns:
        1D array of ATR values.

    Unlike the pandas version, ``np.max`` propagates NaN, so the first bar's
    true range is NaN whenever the shifted close is NaN there (presumably the
    default fill of `vbt.nb.fshift_1d_nb` — confirm).
    """
    prev_close = vbt.nb.fshift_1d_nb(close)
    candidates = np.vstack((
        np.abs(high - low),
        np.abs(high - prev_close),
        np.abs(low - prev_close),
    ))
    true_range = candidates.max(axis=0)
    return vbt.nb.wwm_mean_1d_nb(true_range, period)
In [ ]:
@njit
def get_final_bands_nb(close, upper, lower):
    """Numba-compiled SuperTrend direction/ratchet loop over 1D arrays.

    Same rule as the pandas `get_final_bands`: flip long above the previous
    upper band, flip short below the previous lower band, otherwise keep the
    direction and only let the active band tighten.

    Args:
        close: 1D array of close prices.
        upper: 1D array with the basic upper band.
        lower: 1D array with the basic lower band.

    Returns:
        Tuple ``(trend, dir_, long, short)`` of 1D arrays; ``dir_`` is integer.

    Ratcheting is applied to copies, so the caller's ``upper`` and ``lower``
    arrays are not modified.
    """
    # Copy before ratcheting; writing into the arguments would silently alter
    # the caller's arrays. Fresh names keep Numba from having to unify the
    # argument type with the copy's type.
    fin_upper = upper.copy()
    fin_lower = lower.copy()

    trend = np.full(close.shape, np.nan)
    dir_ = np.full(close.shape, 1)
    long = np.full(close.shape, np.nan)
    short = np.full(close.shape, np.nan)

    # Bar 0 keeps the defaults (NaN bands, direction 1).
    for i in range(1, close.shape[0]):
        if close[i] > fin_upper[i - 1]:
            dir_[i] = 1
        elif close[i] < fin_lower[i - 1]:
            dir_[i] = -1
        else:
            dir_[i] = dir_[i - 1]
            if dir_[i] > 0 and fin_lower[i] < fin_lower[i - 1]:
                fin_lower[i] = fin_lower[i - 1]
            if dir_[i] < 0 and fin_upper[i] > fin_upper[i - 1]:
                fin_upper[i] = fin_upper[i - 1]

        if dir_[i] > 0:
            trend[i] = long[i] = fin_lower[i]
        else:
            trend[i] = short[i] = fin_upper[i]

    return trend, dir_, long, short
In [ ]:
def faster_supertrend(high, low, close, period=7, multiplier=3):
    """SuperTrend on 1D NumPy arrays: NumPy bands plus a Numba-compiled loop.

    Returns:
        Tuple ``(trend, dir_, long, short)`` of arrays from `get_final_bands_nb`.
    """
    upper, lower = get_basic_bands(
        get_med_price(high, low),
        get_atr_np(high, low, close, period),
        multiplier,
    )
    return get_final_bands_nb(close, upper, lower)
In [ ]:
# Same indicator on raw NumPy arrays (.values); the final loop is Numba-compiled.
supert, superd, superl, supers = faster_supertrend(
    high['BTCUSDT'].values, 
    low['BTCUSDT'].values, 
    close['BTCUSDT'].values
)
In [ ]:
supert
In [ ]:
superd
In [ ]:
superl
In [ ]:
supers
In [ ]:
# Re-attach the datetime index to the raw output array.
pd.Series(supert, index=close.index)
In [ ]:
%%timeit
faster_supertrend(
    high['BTCUSDT'].values, 
    low['BTCUSDT'].values,
    close['BTCUSDT'].values
)

NumPy + Numba + TA-Lib

In [ ]:
import talib

def faster_supertrend_talib(high, low, close, period=7, multiplier=3):
    """SuperTrend with TA-Lib computing the median price and the ATR.

    Args:
        high, low, close: 1D float arrays.
        period: ATR period passed to ``talib.ATR``.
        multiplier: ATR multiplier for the bands.

    Returns:
        Tuple ``(trend, dir_, long, short)`` of arrays from `get_final_bands_nb`.
    """
    upper, lower = get_basic_bands(
        talib.MEDPRICE(high, low),
        talib.ATR(high, low, close, period),
        multiplier
    )
    return get_final_bands_nb(close, upper, lower)
In [ ]:
# TA-Lib variant on the same symbol.
faster_supertrend_talib(
    high['BTCUSDT'].values, 
    low['BTCUSDT'].values, 
    close['BTCUSDT'].values
)
In [ ]:
%%timeit
faster_supertrend_talib(
    high['BTCUSDT'].values, 
    low['BTCUSDT'].values, 
    close['BTCUSDT'].values
)

Indicator factory

In [ ]:
# Wrap the TA-Lib variant in a vectorbt indicator class so it can run on
# multi-column inputs and parameter grids. NOTE(review): takes_1d=True
# presumably feeds the apply function one 1D column at a time — confirm.
SuperTrend = vbt.IF(
    class_name='SuperTrend',
    short_name='st',
    input_names=['high', 'low', 'close'],
    param_names=['period', 'multiplier'],
    output_names=['supert', 'superd', 'superl', 'supers']
).with_apply_func(
    faster_supertrend_talib, 
    takes_1d=True,
    period=7, 
    multiplier=3
)
In [ ]:
help(SuperTrend.run)
In [ ]:
# Runs on all symbols at once with the defaults period=7, multiplier=3.
st = SuperTrend.run(high, low, close)
print(st.supert)
In [ ]:
%%timeit
SuperTrend.run(high, low, close)

Using expressions

In [ ]:
# The same indicator written in vectorbt's expression syntax. The first line
# names the class and its short name; the last line lists the outputs.
# NOTE(review): `@p_*` presumably marks parameters and `@talib_*` TA-Lib
# calls — confirm against the vbt expression docs.
expr = """
SuperTrend[st]:
medprice = @talib_medprice(high, low)
atr = @talib_atr(high, low, close, @p_period)
upper, lower = get_basic_bands(medprice, atr, @p_multiplier)
supert, superd, superl, supers = get_final_bands(close, upper, lower)
supert, superd, superl, supers
"""
In [ ]:
# Plain names used inside the expression are supplied as keyword arguments;
# `get_final_bands` maps to the Numba implementation here.
SuperTrend = vbt.IF.from_expr(
    expr, 
    takes_1d=True,
    get_basic_bands=get_basic_bands,
    get_final_bands=get_final_bands_nb,
    period=7, 
    multiplier=3
)
In [ ]:
st = SuperTrend.run(high, low, close)
print(st.supert)
In [ ]:
%%timeit
SuperTrend.run(high, low, close)

Plot indicator

In [ ]:
class SuperTrend(SuperTrend):
    """SuperTrend indicator extended with a close-price + bands plot helper."""

    def plot(self,
             column=None,
             close_kwargs=None,
             superl_kwargs=None,
             supers_kwargs=None,
             fig=None,
             **layout_kwargs):
        """Plot the close price together with the short and long bands.

        Args:
            column: Column to plot; forwarded to `select_col_from_obj`.
            close_kwargs: Extra keyword arguments for the close-price trace.
            superl_kwargs: Extra keyword arguments for the "Long" band trace.
            supers_kwargs: Extra keyword arguments for the "Short" band trace.
            fig: Figure to draw on (passed through to the close-price plot).
            **layout_kwargs: Forwarded to the close-price plot call only.

        Returns:
            The figure holding the three traces.
        """
        def _select(obj, label):
            # Restrict to one column and name the series so it labels the trace.
            return self.select_col_from_obj(obj, column).rename(label)

        close_series = _select(self.close, 'Close')
        fig = close_series.vbt.plot(fig=fig, **(close_kwargs or {}), **layout_kwargs)

        # The short band is drawn before the long band.
        band_specs = (
            (self.supers, 'Short', supers_kwargs),
            (self.superl, 'Long', superl_kwargs),
        )
        for band, label, trace_opts in band_specs:
            _select(band, label).vbt.plot(fig=fig, **(trace_opts or {}))

        return fig
In [ ]:
st = SuperTrend.run(high, low, close)
# January 2020 for BTC: long band in green, short band in red.
st.loc[date_range, 'BTCUSDT'].plot(
    superl_kwargs=dict(trace_kwargs=dict(line_color='limegreen')),
    supers_kwargs=dict(trace_kwargs=dict(line_color='red'))
).show_svg()

Test indicator

In [ ]:
# Enter when the long band is active and exit when the short band is active.
# Signals are shifted forward one bar, so a band computed on bar t's close is
# only acted upon at bar t+1 (no look-ahead).
entries = (~st.superl.isnull()).vbt.signals.fshift()
exits = (~st.supers.isnull()).vbt.signals.fshift()
In [ ]:
# 0.1% fees per order on hourly bars.
pf = vbt.Portfolio.from_signals(
    close=close, 
    entries=entries, 
    exits=exits, 
    fees=0.001, 
    freq='1h'
)
In [ ]:
pf['ETHUSDT'].stats()

Optimization

In [ ]:
# Grid: periods 4..19 and multipliers 2.0..4.0 in 0.1 steps
# (16 x 21 = 336 combinations).
periods = np.arange(4, 20)
multipliers = np.arange(20, 41) / 10
In [ ]:
# param_product=True runs every (period, multiplier) combination.
st = SuperTrend.run(
    high, low, close, 
    period=periods, 
    multiplier=multipliers,
    param_product=True,
)
In [ ]:
st.wrapper.columns
In [ ]:
# Column key is (period, multiplier, symbol).
st.loc[date_range, (19, 4, 'ETHUSDT')].plot().show_svg()
In [ ]:
print(st.getsize())
In [ ]:
# Back-of-the-envelope size in MiB: rows x columns x 4 outputs x 8-byte floats.
input_size = st.wrapper.shape[0] * st.wrapper.shape[1]
n_outputs = 4
data_type_size = 8
input_size * n_outputs * data_type_size / 1024 / 1024
In [ ]:
entries = (~st.superl.isnull()).vbt.signals.fshift()
exits = (~st.supers.isnull()).vbt.signals.fshift()
In [ ]:
pf = vbt.Portfolio.from_signals(close, entries, exits, fees=0.001, freq='1h')
In [ ]:
# Sharpe ratio per (period, multiplier); the slider switches between symbols.
pf.sharpe_ratio.vbt.heatmap(
    x_level='st_period', 
    y_level='st_multiplier',
    slider_level='symbol'
).show_svg()
In [ ]:
# Buy-and-hold Sharpe ratio as a benchmark.
vbt.Portfolio.from_holding(close, freq='1h').sharpe_ratio

Streaming

In [ ]:
class SuperTrendAIS(tp.NamedTuple):
    """Input state for one step of the streaming SuperTrend accumulator.

    Holds the current bar plus everything carried over from the previous step.
    """
    i: int  # index of the current bar; 0 skips the direction logic
    high: float  # current bar high
    low: float  # current bar low
    close: float  # current bar close
    prev_close: float  # previous bar close (callers pass NaN on the first bar)
    prev_upper: float  # final upper band from the previous step
    prev_lower: float  # final lower band from the previous step
    prev_dir_: float  # previous direction (1 = long, -1 = short)
    nobs: int  # ATR (EWM accumulator) state carried between steps
    weighted_avg: float  # ATR (EWM accumulator) state carried between steps
    old_wt: float  # ATR (EWM accumulator) state carried between steps
    period: int  # ATR smoothing period
    multiplier: float  # ATR multiplier for the bands
    
class SuperTrendAOS(tp.NamedTuple):
    """Output state of one streaming SuperTrend step."""
    nobs: int  # updated ATR (EWM) state, fed back as the next `nobs`
    weighted_avg: float  # updated ATR (EWM) state
    old_wt: float  # updated ATR (EWM) state
    upper: float  # final upper band, fed back as the next `prev_upper`
    lower: float  # final lower band, fed back as the next `prev_lower`
    trend: float  # active band (NaN on the first bar)
    dir_: float  # direction (1 = long, -1 = short)
    long: float  # active band while long, NaN otherwise
    short: float  # active band while short, NaN otherwise
In [ ]:
@njit(nogil=True)
def get_tr_one_nb(high, low, prev_close):
    """True range of a single bar; NaN if any of its three components is NaN."""
    hl_range = abs(high - low)
    high_gap = abs(high - prev_close)
    low_gap = abs(low - prev_close)
    if np.isnan(hl_range) or np.isnan(high_gap) or np.isnan(low_gap):
        return np.nan
    return max(hl_range, high_gap, low_gap)

@njit(nogil=True)
def get_med_price_one_nb(high, low):
    """Median price of a single bar (mean of its high and low)."""
    return 0.5 * (high + low)

@njit(nogil=True)
def get_basic_bands_one_nb(high, low, atr, multiplier):
    """Basic upper/lower SuperTrend bands of a single bar.

    Returns:
        Tuple ``(upper, lower)``: median price plus/minus ``multiplier * atr``.
    """
    center = get_med_price_one_nb(high, low)
    offset = multiplier * atr
    return center + offset, center - offset
    
@njit(nogil=True)
def get_final_bands_one_nb(close, upper, lower, prev_upper, prev_lower, prev_dir_):
    """Apply the SuperTrend direction/ratchet rule to a single bar.

    Scalar counterpart of the loop body in `get_final_bands_nb`.

    Args:
        close: Close price of the current bar.
        upper, lower: Basic bands of the current bar.
        prev_upper, prev_lower: Final bands of the previous bar.
        prev_dir_: Direction of the previous bar (1 = long, -1 = short).

    Returns:
        Tuple ``(upper, lower, trend, dir_, long, short)``: the possibly
        ratcheted final bands to carry into the next bar, the active band,
        the direction, and the active band split into long/short (the
        inactive side is NaN).
    """
    # Comparisons against NaN are False, so warm-up bars with NaN bands fall
    # through to the else branch and keep the previous direction.
    if close > prev_upper:
        dir_ = 1
    elif close < prev_lower:
        dir_ = -1
    else:
        # No breakout: keep direction and only allow the active band to tighten.
        dir_ = prev_dir_
        if dir_ > 0 and lower < prev_lower:
            lower = prev_lower
        if dir_ < 0 and upper > prev_upper:
            upper = prev_upper

    if dir_ > 0:
        trend = long = lower
        short = np.nan
    else:
        trend = short = upper
        long = np.nan
    return upper, lower, trend, dir_, long, short
In [ ]:
@njit(nogil=True)
def superfast_supertrend_acc_nb(in_state):
    """Advance the streaming SuperTrend by one bar.

    Args:
        in_state: `SuperTrendAIS` with the current bar and the state carried
            over from the previous bar.

    Returns:
        `SuperTrendAOS` with the updated ATR (EWM) state, the final bands to
        carry forward, and this bar's trend/direction/long/short values.
    """
    # True range needs the previous close (NaN on the first bar).
    tr = get_tr_one_nb(in_state.high, in_state.low, in_state.prev_close)

    # ATR: true range smoothed incrementally by vbt's EWM-mean accumulator
    # using Wilder's alpha for the given period.
    ewm_out = vbt.nb.ewm_mean_acc_nb(
        vbt.nb.EWMMeanAIS(
            i=in_state.i,
            value=tr,
            old_wt=in_state.old_wt,
            weighted_avg=in_state.weighted_avg,
            nobs=in_state.nobs,
            alpha=vbt.nb.alpha_from_wilder_nb(in_state.period),
            minp=in_state.period,
            adjust=False
        )
    )

    upper, lower = get_basic_bands_one_nb(
        in_state.high, in_state.low, ewm_out.value, in_state.multiplier)

    if in_state.i != 0:
        upper, lower, trend, dir_, long, short = get_final_bands_one_nb(
            in_state.close,
            upper,
            lower,
            in_state.prev_upper,
            in_state.prev_lower,
            in_state.prev_dir_
        )
    else:
        # First bar: nothing to compare against yet.
        trend, dir_, long, short = np.nan, 1, np.nan, np.nan

    return SuperTrendAOS(
        nobs=ewm_out.nobs,
        weighted_avg=ewm_out.weighted_avg,
        old_wt=ewm_out.old_wt,
        upper=upper,
        lower=lower,
        trend=trend,
        dir_=dir_,
        long=long,
        short=short
    )
In [ ]:
@njit(nogil=True)
def superfast_supertrend_nb(high, low, close, period=7, multiplier=3):
    """Streaming SuperTrend over 1D arrays, one accumulator step per bar.

    Args:
        high, low, close: 1D arrays of equal length.
        period: ATR smoothing period.
        multiplier: ATR multiplier for the bands.

    Returns:
        Tuple ``(trend, dir_, long, short)`` of 1D arrays (``dir_`` is
        integer). Empty arrays are returned for empty input.
    """
    # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed.
    trend = np.empty(close.shape, dtype=np.float64)
    dir_ = np.empty(close.shape, dtype=np.int_)
    long = np.empty(close.shape, dtype=np.float64)
    short = np.empty(close.shape, dtype=np.float64)

    if close.shape[0] == 0:
        return trend, dir_, long, short

    # Initial ATR (EWM) state and "previous" bands before the first bar.
    nobs = 0
    old_wt = 1.
    weighted_avg = np.nan
    prev_upper = np.nan
    prev_lower = np.nan

    for i in range(close.shape[0]):
        in_state = SuperTrendAIS(
            i=i,
            high=high[i],
            low=low[i],
            close=close[i],
            prev_close=close[i - 1] if i > 0 else np.nan,
            prev_upper=prev_upper,
            prev_lower=prev_lower,
            prev_dir_=dir_[i - 1] if i > 0 else 1,
            nobs=nobs,
            weighted_avg=weighted_avg,
            old_wt=old_wt,
            period=period,
            multiplier=multiplier
        )

        out_state = superfast_supertrend_acc_nb(in_state)

        # Carry state to the next bar and record this bar's outputs.
        nobs = out_state.nobs
        weighted_avg = out_state.weighted_avg
        old_wt = out_state.old_wt
        prev_upper = out_state.upper
        prev_lower = out_state.lower
        trend[i] = out_state.trend
        dir_[i] = out_state.dir_
        long[i] = out_state.long
        short[i] = out_state.short

    return trend, dir_, long, short
In [ ]:
# Cross-check: the streaming implementation must reproduce the array version exactly.
superfast_out = superfast_supertrend_nb(
    high['BTCUSDT'].values,
    low['BTCUSDT'].values,
    close['BTCUSDT'].values
)
In [ ]:
faster_out = faster_supertrend(
    high['BTCUSDT'].values,
    low['BTCUSDT'].values,
    close['BTCUSDT'].values
)
In [ ]:
# trend, direction, long band, short band
np.testing.assert_array_equal(superfast_out[0], faster_out[0])
np.testing.assert_array_equal(superfast_out[1], faster_out[1])
np.testing.assert_array_equal(superfast_out[2], faster_out[2])
np.testing.assert_array_equal(superfast_out[3], faster_out[3])
In [ ]:
%%timeit
superfast_supertrend_nb(
    high['BTCUSDT'].values, 
    low['BTCUSDT'].values, 
    close['BTCUSDT'].values
)

Multithreading

In [ ]:
# Rebuild the indicator on the streaming Numba kernel. Note: this rebinds
# `SuperTrend`, so the subclass with the custom `plot` method is replaced.
SuperTrend = vbt.IF(
    class_name='SuperTrend',
    short_name='st',
    input_names=['high', 'low', 'close'],
    param_names=['period', 'multiplier'],
    output_names=['supert', 'superd', 'superl', 'supers']
).with_apply_func(
    superfast_supertrend_nb, 
    takes_1d=True,
    period=7, 
    multiplier=3
)
In [ ]:
%%timeit
SuperTrend.run(high, low, close)
In [ ]:
%%timeit
SuperTrend.run(
    high, low, close, 
    period=periods, 
    multiplier=multipliers,
    param_product=True,
    execute_kwargs=dict(show_progress=False)
)
In [ ]:
# NOTE(review): hard-coded arithmetic from an earlier run. 336 is the number
# of parameter combinations and 2 the number of symbols; 270 is presumably
# the grid timing measured above (ms?). Update after re-running.
270 / 336 / 2
In [ ]:
%%timeit
SuperTrend.run(
    high, low, close, 
    period=periods, 
    multiplier=multipliers,
    param_product=True,
    execute_kwargs=dict(
        engine='dask', 
        chunk_len='auto', 
        show_progress=False
    )
)

Pipelines

In [ ]:
def pipeline(data, period=7, multiplier=3):
    """Backtest the SuperTrend strategy on ``data`` and return its Sharpe ratio.

    Args:
        data: vectorbt data object exposing 'High', 'Low' and 'Close'.
        period: ATR period (scalar or array of values).
        multiplier: ATR multiplier (scalar or array of values).

    Returns:
        Sharpe ratio of the resulting portfolio (one value per column).
    """
    high, low, close = data.get('High'), data.get('Low'), data.get('Close')
    st = SuperTrend.run(high, low, close, period=period, multiplier=multiplier)

    # Enter the bar after the long band appears, exit the bar after the short
    # band appears (signals shifted forward to avoid look-ahead).
    entries = st.superl.notna().vbt.signals.fshift()
    exits = st.supers.notna().vbt.signals.fshift()

    # Only returns are needed for the Sharpe ratio; max_order_records=0
    # presumably skips storing order records.
    portfolio = vbt.Portfolio.from_signals(
        close,
        entries=entries,
        exits=exits,
        fees=0.001,
        save_returns=True,
        max_order_records=0,
        freq='1h'
    )
    return portfolio.sharpe_ratio
In [ ]:
pipeline(data)
In [ ]:
%%timeit
pipeline(data)
In [ ]:
# NOTE(review): hard-coded estimate, presumably 336 parameter combinations x
# the ~32 ms single-run timing measured above. Update after re-running.
336 * 32
In [ ]:
# Materialize the full grid as two aligned 1D arrays (one entry per combination).
op_tree = (product, periods, multipliers)
period_product, multiplier_product = vbt.generate_param_combs(op_tree)
period_product = np.asarray(period_product)
multiplier_product = np.asarray(multiplier_product)
In [ ]:
%%timeit
pipeline(data, period_product, multiplier_product)

Chunked pipeline

In [ ]:
# Split the aligned `period`/`multiplier` arrays into chunks, run `pipeline`
# per chunk (`data` is not sliced), then concatenate and sort the Sharpe
# Series. NOTE(review): single_type=int presumably makes a scalar period
# count as size 1 — confirm.
chunked_pipeline = vbt.chunked(
    size=vbt.LenSizer(arg_query='period', single_type=int),
    arg_take_spec=dict(
        data=None,
        period=vbt.ChunkSlicer(),
        multiplier=vbt.ChunkSlicer()
    ),
    merge_func=lambda x: pd.concat(x).sort_index()
)(pipeline)
In [ ]:
chunked_pipeline(data)
In [ ]:
chunked_pipeline(
    data, 
    period_product[:4], 
    multiplier_product[:4],
    _n_chunks=2,
)
In [ ]:
# NOTE(review): _return_raw_chunks=True presumably returns the chunk plan and
# tasks instead of executing them.
chunk_meta, tasks = chunked_pipeline(
    data, 
    period_product[:4], 
    multiplier_product[:4],
    _n_chunks=2,
    _return_raw_chunks=True
)
In [ ]:
chunk_meta
In [ ]:
list(tasks)
In [ ]:
%%timeit
chunked_pipeline(data, period_product, multiplier_product)
In [ ]:
%%timeit
chunked_pipeline(data, period_product, multiplier_product, _chunk_len=1)

Numba pipeline

In [ ]:
@njit(nogil=True)
def pipeline_nb(high, low, close, periods=np.array([7]), multipliers=np.array([3]), ann_factor=365):
    """Numba pipeline: SuperTrend signals -> signal backtest -> Sharpe ratio.

    For each pair ``(periods[i], multipliers[i])`` and each column, computes
    SuperTrend, turns band presence into next-bar long entries/exits,
    simulates with 0.1% fees and 100 initial cash per column (one group per
    column), and stores the Sharpe ratio (``ddof=1``) of the returns.

    Args:
        high, low, close: 2D arrays (rows = bars, columns = assets).
        periods, multipliers: 1D arrays of equal size; element ``i`` of each
            forms one parameter combination.
        ann_factor: Annualization factor for the Sharpe ratio.

    Returns:
        1D array of length ``periods.size * close.shape[1]``, parameter-major
        (all columns of combination 0, then all columns of combination 1, ...).

    Raises:
        ValueError: If ``periods`` and ``multipliers`` differ in size.
    """
    # Numba does not bounds-check by default, so a size mismatch would read
    # past the end of `multipliers` instead of failing loudly.
    if multipliers.size != periods.size:
        raise ValueError("periods and multipliers must have the same size")

    # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed.
    sharpe = np.empty(periods.size * close.shape[1], dtype=np.float64)
    long_entries = np.empty(close.shape, dtype=np.bool_)
    long_exits = np.empty(close.shape, dtype=np.bool_)
    group_lens = np.full(close.shape[1], 1)
    init_cash = 100.
    fees = 0.001
    k = 0

    for i in range(periods.size):
        for col in range(close.shape[1]):
            _, _, superl, supers = superfast_supertrend_nb(
                high[:, col],
                low[:, col],
                close[:, col],
                periods[i],
                multipliers[i]
            )
            # Act on the bar after the band appears (no look-ahead).
            long_entries[:, col] = vbt.nb.fshift_1d_nb(~np.isnan(superl), fill_value=False)
            long_exits[:, col] = vbt.nb.fshift_1d_nb(~np.isnan(supers), fill_value=False)

        sim_out = vbt.pf_nb.from_signals_nb(
            target_shape=close.shape,
            group_lens=group_lens,
            init_cash=init_cash,
            high=high,
            low=low,
            close=close,
            long_entries=long_entries,
            long_exits=long_exits,
            fees=fees,
            save_returns=True
        )
        returns = sim_out.in_outputs.returns
        sharpe[k:k + close.shape[1]] = vbt.ret_nb.sharpe_ratio_nb(returns, ann_factor, ddof=1)
        k += close.shape[1]

    return sharpe
In [ ]:
# Annualization factor for hourly bars, as computed by vectorbt.
ann_factor = vbt.pd_acc.returns.get_ann_factor(freq='1h')
In [ ]:
pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    ann_factor=ann_factor
)
In [ ]:
%%timeit
pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    ann_factor=ann_factor
)
In [ ]:
def merge_func(arrs, ann_args, input_columns):
    """Merge per-chunk Sharpe arrays into one labelled Series.

    Args:
        arrs: Per-chunk 1D results of the wrapped Numba pipeline.
        ann_args: Annotated call arguments (injected via ``vbt.Rep``); the full
            ``periods`` and ``multipliers`` values are read from it.
        input_columns: Column index of the input data (e.g. symbols).

    Returns:
        pd.Series indexed by (st_period, st_multiplier) combined with the
        input columns.
    """
    sharpe_values = np.concatenate(arrs)
    period_index = pd.Index(ann_args['periods']['value'], name='st_period')
    multiplier_index = pd.Index(ann_args['multipliers']['value'], name='st_multiplier')
    param_index = vbt.stack_indexes((period_index, multiplier_index))
    full_index = vbt.combine_indexes((param_index, input_columns))
    return pd.Series(sharpe_values, index=full_index)

# Chunk along the parameter arrays; price arrays and ann_factor are not sliced.
nb_chunked = vbt.chunked(
    size=vbt.ArraySizer(arg_query='periods', axis=0),
    arg_take_spec=dict(
        high=None,
        low=None,
        close=None,
        periods=vbt.ArraySlicer(axis=0),
        multipliers=vbt.ArraySlicer(axis=0),
        ann_factor=None
    ),
    merge_func=merge_func,
    merge_kwargs=dict(ann_args=vbt.Rep("ann_args")),
)
chunked_pipeline_nb = nb_chunked(pipeline_nb)
In [ ]:
# Small run: 4 combinations split into 2 chunks.
chunked_pipeline_nb(
    high.values, 
    low.values,
    close.values,
    periods=period_product[:4], 
    multipliers=multiplier_product[:4],
    ann_factor=ann_factor,
    _n_chunks=2,
    _merge_kwargs=dict(input_columns=close.columns)
)
In [ ]:
%%timeit
chunked_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    periods=period_product, 
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _merge_kwargs=dict(input_columns=close.columns)
)
In [ ]:
%%timeit
# Same full grid, with chunks executed by the dask engine.
chunked_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    periods=period_product, 
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _execute_kwargs=dict(engine='dask'),
    _merge_kwargs=dict(input_columns=close.columns)
)

Contextualized pipeline

Streaming Sharpe

In [ ]:
class RollSharpeAIS(tp.NamedTuple):
    """Input state for one step of the streaming (rolling) Sharpe accumulator."""
    i: int  # index of the current return
    ret: float  # current return
    pre_window_ret: float  # return leaving the window (NaN if none)
    cumsum: float  # running-sum state carried between steps
    cumsum_sq: float  # running sum-of-squares state carried between steps
    nancnt: int  # running NaN-count state carried between steps
    window: int  # window length
    minp: tp.Optional[int]  # minimum-periods setting forwarded to the mean/std accumulators
    ddof: int  # delta degrees of freedom for the standard deviation
    ann_factor: float  # annualization factor; the ratio is scaled by sqrt(ann_factor)
    
class RollSharpeAOS(tp.NamedTuple):
    """Output state of one streaming Sharpe step."""
    cumsum: float  # updated running sum, fed back as the next `cumsum`
    cumsum_sq: float  # updated running sum of squares
    nancnt: int  # updated NaN count
    value: float  # Sharpe ratio at this step (NaN when the std is 0)

@njit(nogil=True)
def rolling_sharpe_acc_nb(in_state):
    """Advance the streaming Sharpe ratio by one return.

    Feeds the same window state into vbt's rolling mean and rolling std
    accumulators and combines them as ``mean / std * sqrt(ann_factor)``.

    Args:
        in_state: `RollSharpeAIS` with the current return and carried state.

    Returns:
        `RollSharpeAOS`; the std accumulator's state (which also tracks the
        sum of squares) is carried forward. The value is NaN when std is 0.
    """
    mean_out = vbt.nb.rolling_mean_acc_nb(
        vbt.nb.RollMeanAIS(
            i=in_state.i,
            value=in_state.ret,
            pre_window_value=in_state.pre_window_ret,
            cumsum=in_state.cumsum,
            nancnt=in_state.nancnt,
            window=in_state.window,
            minp=in_state.minp
        )
    )
    std_out = vbt.nb.rolling_std_acc_nb(
        vbt.nb.RollStdAIS(
            i=in_state.i,
            value=in_state.ret,
            pre_window_value=in_state.pre_window_ret,
            cumsum=in_state.cumsum,
            cumsum_sq=in_state.cumsum_sq,
            nancnt=in_state.nancnt,
            window=in_state.window,
            minp=in_state.minp,
            ddof=in_state.ddof
        )
    )

    sharpe = np.nan if std_out.value == 0 else mean_out.value / std_out.value * np.sqrt(in_state.ann_factor)

    return RollSharpeAOS(
        cumsum=std_out.cumsum,
        cumsum_sq=std_out.cumsum_sq,
        nancnt=std_out.nancnt,
        value=sharpe
    )
In [ ]:
@njit(nogil=True)
def rolling_sharpe_ratio_nb(returns, window, minp=None, ddof=0, ann_factor=365):
    """Rolling Sharpe ratio of a 1D return series via the streaming accumulator.

    Args:
        returns: 1D array of returns.
        window: Window length; ``None`` uses the whole series.
        minp: Minimum-periods setting; ``None`` defaults to ``window``.
        ddof: Delta degrees of freedom for the standard deviation.
        ann_factor: Annualization factor; the ratio is scaled by
            ``sqrt(ann_factor)``.

    Returns:
        1D float array of the same length as ``returns``.
    """
    if window is None:
        window = returns.shape[0]
    if minp is None:
        minp = window
    # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed.
    out = np.empty(returns.shape, dtype=np.float64)

    if returns.shape[0] == 0:
        return out

    cumsum = 0.
    cumsum_sq = 0.
    nancnt = 0

    for i in range(returns.shape[0]):
        in_state = RollSharpeAIS(
            i=i,
            ret=returns[i],
            # The return dropping out of the window, once the window is full.
            pre_window_ret=returns[i - window] if i - window >= 0 else np.nan,
            cumsum=cumsum,
            cumsum_sq=cumsum_sq,
            nancnt=nancnt,
            window=window,
            minp=minp,
            ddof=ddof,
            ann_factor=ann_factor
        )

        out_state = rolling_sharpe_acc_nb(in_state)

        cumsum = out_state.cumsum
        cumsum_sq = out_state.cumsum_sq
        nancnt = out_state.nancnt
        out[i] = out_state.value

    return out
In [ ]:
returns = close['BTCUSDT'].vbt.to_returns()
In [ ]:
# Validate the streaming implementation against vectorbt's rolling Sharpe
# (10-bar window, ddof=1, hourly annualization).
np.testing.assert_allclose(
    rolling_sharpe_ratio_nb(
        returns=returns.values, 
        window=10, 
        ddof=1, 
        ann_factor=ann_factor),
    returns.vbt.returns(freq='1h').rolling_sharpe_ratio(10).values
)

Callbacks

In [ ]:
class Memory(tp.NamedTuple):
    """Per-column state shared between the simulation callbacks.

    Every field is a 1D array with one slot per column; it is allocated in
    `pre_sim_func_nb` and updated in place by `order_func_nb` and
    `post_segment_func_nb`.
    """
    nobs: tp.Array1d  # ATR (EWM) state
    old_wt: tp.Array1d  # ATR (EWM) state
    weighted_avg: tp.Array1d  # ATR (EWM) state
    prev_upper: tp.Array1d  # final upper band of the previous bar
    prev_lower: tp.Array1d  # final lower band of the previous bar
    prev_dir_: tp.Array1d  # direction of the previous bar
    cumsum: tp.Array1d  # expanding-Sharpe state: running sum
    cumsum_sq: tp.Array1d  # expanding-Sharpe state: running sum of squares
    nancnt: tp.Array1d  # expanding-Sharpe state: NaN count
    was_entry: tp.Array1d  # long band was active on the previous bar
    was_exit: tp.Array1d  # short band was active on the previous bar

@njit(nogil=True)
def pre_sim_func_nb(c):
    """Pre-simulation callback: allocate the per-column `Memory`.

    Args:
        c: Simulation context; only ``c.target_shape[1]`` (number of columns)
            is used.

    Returns:
        A 1-tuple ``(memory,)``. NOTE(review): the simulator presumably
        forwards these items to the other callbacks, which all take
        ``memory`` right after the context.
    """
    n_cols = c.target_shape[1]
    # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed.
    memory = Memory(
        nobs=np.full(n_cols, 0, dtype=np.int_),
        old_wt=np.full(n_cols, 1., dtype=np.float64),
        weighted_avg=np.full(n_cols, np.nan, dtype=np.float64),
        prev_upper=np.full(n_cols, np.nan, dtype=np.float64),
        prev_lower=np.full(n_cols, np.nan, dtype=np.float64),
        # Not read on the first bar: the accumulator special-cases i == 0.
        prev_dir_=np.full(n_cols, np.nan, dtype=np.float64),
        cumsum=np.full(n_cols, 0., dtype=np.float64),
        cumsum_sq=np.full(n_cols, 0., dtype=np.float64),
        nancnt=np.full(n_cols, 0, dtype=np.int_),
        was_entry=np.full(n_cols, False, dtype=np.bool_),
        was_exit=np.full(n_cols, False, dtype=np.bool_)
    )
    return (memory,)
In [ ]:
@njit(nogil=True)
def order_func_nb(c, memory, period, multiplier):
    """Order callback: update SuperTrend for this bar, trade on last bar's signal.

    The entry/exit flags are read from ``memory`` *before* it is updated, so
    the order placed on bar ``i`` reacts to the band active on bar ``i - 1``
    (the same one-bar delay as shifting the signals forward).

    Args:
        c: Order context from the simulator.
        memory: `Memory` with per-column state (updated in place).
        period: ATR period.
        multiplier: ATR multiplier.

    Returns:
        A long-only order with 0.1% fees: infinite size on entry when flat,
        negative infinite size on exit when long, otherwise size 0.
    """
    col = c.col
    enter_now = memory.was_entry[col]
    exit_now = memory.was_exit[col]

    if c.i > 0:
        prev_close = c.close[c.i - 1, col]
    else:
        prev_close = np.nan

    out_state = superfast_supertrend_acc_nb(
        SuperTrendAIS(
            i=c.i,
            high=c.high[c.i, col],
            low=c.low[c.i, col],
            close=c.close[c.i, col],
            prev_close=prev_close,
            prev_upper=memory.prev_upper[col],
            prev_lower=memory.prev_lower[col],
            prev_dir_=memory.prev_dir_[col],
            nobs=memory.nobs[col],
            weighted_avg=memory.weighted_avg[col],
            old_wt=memory.old_wt[col],
            period=period,
            multiplier=multiplier
        )
    )

    # Persist this column's state for the next bar.
    memory.nobs[col] = out_state.nobs
    memory.weighted_avg[col] = out_state.weighted_avg
    memory.old_wt[col] = out_state.old_wt
    memory.prev_upper[col] = out_state.upper
    memory.prev_lower[col] = out_state.lower
    memory.prev_dir_[col] = out_state.dir_
    memory.was_entry[col] = not np.isnan(out_state.long)
    memory.was_exit[col] = not np.isnan(out_state.short)

    has_position = c.position_now > 0
    size = 0.
    if enter_now and not has_position:
        size = np.inf
    elif exit_now and has_position:
        size = -np.inf
    return vbt.pf_nb.order_nb(
        size=size,
        direction=vbt.pf_enums.Direction.LongOnly,
        fees=0.001
    )
In [ ]:
@njit(nogil=True)
def post_segment_func_nb(c, memory, ann_factor):
    """Post-segment callback: update the expanding Sharpe ratio per column.

    Feeds each column's last return into the streaming Sharpe accumulator
    and writes the result to ``c.in_outputs.sharpe``.

    Args:
        c: Segment context; ``c.from_col``/``c.to_col`` bound the columns.
        memory: `Memory` holding the Sharpe state (updated in place).
        ann_factor: Annualization factor for the Sharpe ratio.
    """
    # A window spanning every bar so far, with nothing ever dropping out,
    # turns the rolling statistic into an expanding one.
    expanding_window = c.i + 1
    for col in range(c.from_col, c.to_col):
        out_state = rolling_sharpe_acc_nb(
            RollSharpeAIS(
                i=c.i,
                ret=c.last_return[col],
                pre_window_ret=np.nan,
                cumsum=memory.cumsum[col],
                cumsum_sq=memory.cumsum_sq[col],
                nancnt=memory.nancnt[col],
                window=expanding_window,
                minp=0,
                ddof=1,
                ann_factor=ann_factor
            )
        )
        memory.cumsum[col] = out_state.cumsum
        memory.cumsum_sq[col] = out_state.cumsum_sq
        memory.nancnt[col] = out_state.nancnt
        c.in_outputs.sharpe[col] = out_state.value
In [ ]:
class InOutputs(tp.NamedTuple):
    """Arrays written by the callbacks during simulation and read back afterwards."""
    sharpe: tp.Array1d  # latest expanding Sharpe ratio, one slot per column

@njit(nogil=True)
def ctx_pipeline_nb(high, low, close, periods=np.array([7]), multipliers=np.array([3]), ann_factor=365):
    """Callback-based pipeline: SuperTrend, orders and Sharpe computed in one pass.

    Runs `vbt.pf_nb.from_order_func_nb` once per parameter combination, with
    SuperTrend updated bar by bar in `order_func_nb` and an expanding Sharpe
    ratio updated in `post_segment_func_nb`.

    Args:
        high, low, close: 2D arrays (rows = bars, columns = assets).
        periods, multipliers: 1D arrays of equal size; element ``i`` of each
            forms one parameter combination.
        ann_factor: Annualization factor for the Sharpe ratio.

    Returns:
        1D array of length ``periods.size * close.shape[1]``, parameter-major
        (same layout as `pipeline_nb`).

    Raises:
        ValueError: If ``periods`` and ``multipliers`` differ in size.
    """
    # Numba does not bounds-check by default, so a size mismatch would read
    # past the end of `multipliers` instead of failing loudly.
    if multipliers.size != periods.size:
        raise ValueError("periods and multipliers must have the same size")

    # NaN-initialized (not np.empty) so a column that never reaches
    # post_segment_func_nb (e.g. zero bars) reports NaN, not garbage.
    # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed.
    in_outputs = InOutputs(sharpe=np.full(close.shape[1], np.nan, dtype=np.float64))
    sharpe = np.empty(periods.size * close.shape[1], dtype=np.float64)
    group_lens = np.full(close.shape[1], 1)
    init_cash = 100.
    k = 0

    for i in range(periods.size):
        # Only the in-output Sharpe is needed; the simulation result is discarded.
        vbt.pf_nb.from_order_func_nb(
            target_shape=close.shape,
            group_lens=group_lens,
            cash_sharing=False,
            init_cash=init_cash,
            pre_sim_func_nb=pre_sim_func_nb,
            order_func_nb=order_func_nb,
            order_args=(periods[i], multipliers[i]),
            post_segment_func_nb=post_segment_func_nb,
            post_segment_args=(ann_factor,),
            high=high,
            low=low,
            close=close,
            in_outputs=in_outputs,
            fill_pos_info=False,
            max_order_records=0
        )
        sharpe[k:k + close.shape[1]] = in_outputs.sharpe
        k += close.shape[1]

    return sharpe
In [ ]:
ctx_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    ann_factor=ann_factor
)
In [ ]:
%%timeit
ctx_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    ann_factor=ann_factor
)
In [ ]:
# Reuse the chunking spec built for pipeline_nb (same signature).
chunked_ctx_pipeline_nb = nb_chunked(ctx_pipeline_nb)
In [ ]:
chunked_ctx_pipeline_nb(
    high.values, 
    low.values,
    close.values,
    periods=period_product[:4], 
    multipliers=multiplier_product[:4],
    ann_factor=ann_factor,
    _n_chunks=2,
    _merge_kwargs=dict(input_columns=close.columns)
)
In [ ]:
%%timeit
chunked_ctx_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    periods=period_product, 
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _merge_kwargs=dict(input_columns=close.columns)
)
In [ ]:
%%timeit
# Same full grid, with chunks executed by the dask engine.
chunked_ctx_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    periods=period_product, 
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _execute_kwargs=dict(engine='dask'),
    _merge_kwargs=dict(input_columns=close.columns)
)

Bonus: Writing our own simulator

In [ ]:
@njit(nogil=True)
def raw_pipeline_nb(high, low, close, periods=np.array([7]), multipliers=np.array([3]), ann_factor=365):
    """Backtest a long-only SuperTrend strategy with a hand-rolled simulator and return Sharpe ratios.

    For every (period, multiplier) pair and every column, the bars are walked once:

    1. SuperTrend is updated incrementally via ``superfast_supertrend_acc_nb``.
    2. The signal produced on the *previous* bar is executed via
       ``vbt.pf_nb.execute_order_nb``:
       - enter with ``size=np.inf`` when that bar had a long band and we are flat;
       - exit with ``size=-np.inf`` when it had a short band and we hold a position.
       Orders are long-only with ``fees=0.001``, starting from a cash balance of 100.
    3. The bar return is fed into ``rolling_sharpe_acc_nb`` with ``window=i + 1``
       (an expanding window). The value after the last bar is stored.

    Args:
        high, low, close: 2-dim arrays indexed as ``[bar, column]``.
        periods: SuperTrend periods, one per parameter combination.
        multipliers: SuperTrend multipliers, aligned element-wise with ``periods``.
        ann_factor: annualization factor forwarded to the Sharpe accumulator.

    Returns:
        1-dim float64 array of length ``periods.size * close.shape[1]``, laid out
        parameter-major (entry ``k * n_cols + col``). All NaN if ``close`` has no rows.
    """
    n_rows, n_cols = close.shape
    # NaN-fill instead of np.empty so uninitialized memory is never returned;
    # this also avoids np.float_, which was removed in NumPy 2.0.
    out = np.full(periods.size * n_cols, np.nan)
    if n_rows == 0:
        return out

    init_cash = 100.

    for k in range(periods.size):
        period = periods[k]
        multiplier = multipliers[k]

        for col in range(n_cols):
            # SuperTrend accumulator state (ATR smoothing + previous bands/direction)
            nobs = 0
            old_wt = 1.
            weighted_avg = np.nan
            prev_close_ = np.nan
            prev_upper = np.nan
            prev_lower = np.nan
            prev_dir_ = 1

            # Signals generated on the previous bar; they are executed one bar later
            was_entry = False
            was_exit = False

            # Expanding Sharpe accumulator state
            cumsum = 0.
            cumsum_sq = 0.
            nancnt = 0
            sharpe = np.nan

            # Portfolio state
            cash = init_cash
            position = 0.
            debt = 0.
            locked_cash = 0.
            free_cash = init_cash
            val_price = np.nan
            prev_value = init_cash

            for i in range(n_rows):
                is_entry = was_entry
                is_exit = was_exit

                # --- Indicator update ---
                st_in_state = SuperTrendAIS(
                    i=i,
                    high=high[i, col],
                    low=low[i, col],
                    close=close[i, col],
                    prev_close=prev_close_,
                    prev_upper=prev_upper,
                    prev_lower=prev_lower,
                    prev_dir_=prev_dir_,
                    nobs=nobs,
                    weighted_avg=weighted_avg,
                    old_wt=old_wt,
                    period=period,
                    multiplier=multiplier
                )
                st_out_state = superfast_supertrend_acc_nb(st_in_state)

                nobs = st_out_state.nobs
                weighted_avg = st_out_state.weighted_avg
                old_wt = st_out_state.old_wt
                prev_close_ = close[i, col]
                prev_upper = st_out_state.upper
                prev_lower = st_out_state.lower
                prev_dir_ = st_out_state.dir_
                was_entry = not np.isnan(st_out_state.long)
                was_exit = not np.isnan(st_out_state.short)

                # --- Order generation from the previous bar's signal ---
                if is_entry and position == 0:
                    size = np.inf
                elif is_exit and position > 0:
                    size = -np.inf
                else:
                    size = np.nan

                # --- Order execution at this bar's close ---
                val_price = close[i, col]
                value = cash + position * val_price
                if not np.isnan(size):
                    exec_state = vbt.pf_enums.ExecState(
                        cash=cash,
                        position=position,
                        debt=debt,
                        locked_cash=locked_cash,
                        free_cash=free_cash,
                        val_price=val_price,
                        value=value
                    )
                    price_area = vbt.pf_enums.PriceArea(
                        open=np.nan,
                        high=high[i, col],
                        low=low[i, col],
                        close=close[i, col]
                    )
                    order = vbt.pf_nb.order_nb(
                        size=size,
                        direction=vbt.pf_enums.Direction.LongOnly,
                        fees=0.001
                    )
                    _, new_exec_state = vbt.pf_nb.execute_order_nb(exec_state, order, price_area)
                    cash, position, debt, locked_cash, free_cash, val_price, value = new_exec_state

                # --- Bar return and expanding Sharpe update ---
                value = cash + position * val_price
                return_ = vbt.ret_nb.get_return_nb(prev_value, value)
                prev_value = value

                sharpe_in_state = RollSharpeAIS(
                    i=i,
                    ret=return_,
                    pre_window_ret=np.nan,
                    cumsum=cumsum,
                    cumsum_sq=cumsum_sq,
                    nancnt=nancnt,
                    window=i + 1,
                    minp=0,
                    ddof=1,
                    ann_factor=ann_factor
                )
                sharpe_out_state = rolling_sharpe_acc_nb(sharpe_in_state)
                cumsum = sharpe_out_state.cumsum
                cumsum_sq = sharpe_out_state.cumsum_sq
                nancnt = sharpe_out_state.nancnt
                sharpe = sharpe_out_state.value

            out[k * n_cols + col] = sharpe

    return out
In [ ]:
# Single run of the hand-rolled simulator with its default periods=[7], multipliers=[3];
# returns one Sharpe ratio per column.
raw_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    ann_factor=ann_factor
)
In [ ]:
%%timeit
# Benchmark the same single-combination run (the magic must stay on the first line of the cell).
raw_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    ann_factor=ann_factor
)
In [ ]:
# Wrap with nb_chunked (defined earlier in the notebook); presumably it splits
# `periods`/`multipliers` into chunks and merges the per-chunk outputs.
chunked_raw_pipeline_nb = nb_chunked(raw_pipeline_nb)
In [ ]:
# Sanity check: first 4 parameter combinations split into 2 chunks.
# `_merge_kwargs` forwards the column labels used when merging the chunk results
# (presumably to label the output) — TODO confirm against nb_chunked.
chunked_raw_pipeline_nb(
    high.values, 
    low.values,
    close.values,
    periods=period_product[:4], 
    multipliers=multiplier_product[:4],
    ann_factor=ann_factor,
    _n_chunks=2,
    _merge_kwargs=dict(input_columns=close.columns)
)
In [ ]:
%%timeit
# Full parameter grid with nb_chunked's default execution settings.
chunked_raw_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    periods=period_product, 
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _merge_kwargs=dict(input_columns=close.columns)
)
In [ ]:
%%timeit
# Full parameter grid with chunks executed by the Dask engine; raw_pipeline_nb is
# compiled with nogil=True, so chunks can run concurrently in threads.
chunked_raw_pipeline_nb(
    high.values, 
    low.values, 
    close.values,
    periods=period_product, 
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _execute_kwargs=dict(engine="dask"),
    _merge_kwargs=dict(input_columns=close.columns)
)
In [ ]:
# Number of 1h bars in 365 days (the data was pulled with timeframe='1h').
range_len = int(vbt.timedelta('365d') / vbt.timedelta('1h'))
In [ ]:
# 100 rolling windows of `range_len` bars each over the full history.
splitter = vbt.Splitter.from_n_rolling(high.index, n=100, length=range_len)

# into="reset_stacked": presumably stacks every window as extra (split, symbol)
# columns on a reset index, so each window can be fed to the pipeline as plain arrays.
roll_high = splitter.take(high, into="reset_stacked")
roll_low = splitter.take(low, into="reset_stacked")
roll_close = splitter.take(close, into="reset_stacked")

# Original datetime index of every window, used later for the plot titles.
range_indexes = splitter.take(high.index)
In [ ]:
roll_close.columns
In [ ]:
range_indexes[0]
In [ ]:
# Every SuperTrend parameter combination on every (split, symbol) column, via Dask.
sharpe_ratios = chunked_raw_pipeline_nb(
    roll_high.values, 
    roll_low.values,
    roll_close.values,
    periods=period_product, 
    multipliers=multiplier_product,
    ann_factor=ann_factor,
    _execute_kwargs=dict(engine="dask"),
    _merge_kwargs=dict(input_columns=roll_close.columns)
)
In [ ]:
sharpe_ratios
In [ ]:
# Buy-and-hold benchmark on the same windows. NOTE(review): no fees are passed here,
# while raw_pipeline_nb charges fees=0.001 per order — keep in mind when comparing.
pf_hold = vbt.Portfolio.from_holding(roll_close, freq='1h')
sharpe_ratios_hold = pf_hold.sharpe_ratio
In [ ]:
sharpe_ratios_hold
In [ ]:
def plot_subperiod_sharpe(index, sharpe_ratios, sharpe_ratios_hold, range_indexes, symbol):
    """Plot a period x multiplier Sharpe-ratio heatmap for one split and one symbol.

    Used as the per-frame callback of ``vbt.save_animation``.

    Args:
        index: Index-like whose first element is the split label to plot.
        sharpe_ratios: Series whose index has the levels 'st_period', 'st_multiplier',
            'split' and 'symbol'.
        sharpe_ratios_hold: buy-and-hold Sharpe ratios indexed by ``(split, symbol)``.
            This value becomes the midpoint of the diverging colorscale.
        range_indexes: sequence of datetime indexes, one per split, used for the title.
        symbol: symbol to plot.

    Returns:
        The result of ``.vbt.heatmap`` on the (st_period, st_multiplier) slice.
    """
    split = index[0]
    sub = sharpe_ratios.xs(symbol, level='symbol', drop_level=True)
    sub = sub.xs(split, level='split', drop_level=True)
    start_date = range_indexes[split][0]
    end_date = range_indexes[split][-1]

    zmin = sub.min()
    zmax = sub.max()
    zmid = sharpe_ratios_hold[(split, symbol)]
    trace_kwargs = dict(colorscale='Spectral')
    if np.isfinite(zmid):
        # Plotly ignores `zmid` when both `zmin` and `zmax` are given (zauto becomes
        # False), so center the colorscale on the buy-and-hold Sharpe explicitly by
        # making the color range symmetric around it.
        half_range = max(zmax - zmid, zmid - zmin)
        trace_kwargs.update(zmin=zmid - half_range, zmid=zmid, zmax=zmid + half_range)
    else:
        # No usable benchmark: fall back to the plain data range.
        trace_kwargs.update(zmin=zmin, zmax=zmax)

    return sub.vbt.heatmap(
        x_level='st_period',
        y_level='st_multiplier',
        title="{} - {}".format(
            start_date.strftime("%d %b, %Y %H:%M:%S"),
            end_date.strftime("%d %b, %Y %H:%M:%S")
        ),
        trace_kwargs=trace_kwargs
    )
In [ ]:
fname = 'raw_pipeline.gif'
level_idx = sharpe_ratios.index.names.index('split')
split_indices = sharpe_ratios.index.levels[level_idx]

vbt.save_animation(
    fname,
    split_indices, 
    plot_subperiod_sharpe,
    sharpe_ratios,
    sharpe_ratios_hold,
    range_indexes,
    'BTCUSDT',
    delta=1,
    fps=7,
    writer_kwargs=dict(loop=0)
)
In [ ]:
from IPython.display import Image, display
    
with open(fname,'rb') as f:
    display(Image(data=f.read(), format='png'))
In [ ]: