Files
strategy-lab/to_explore/notebooks/SuperTrend.ipynb

2473 lines
68 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "d2ac023c-531a-487a-81f0-fcce3c9925b6",
"metadata": {},
"source": [
"# SuperFast SuperTrend"
]
},
{
"cell_type": "markdown",
"id": "e052b3c7-c1ec-4245-8025-fbacbc71e1a6",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ea9cb06-d59c-4c2c-807a-23854fa16676",
"metadata": {},
"outputs": [],
"source": [
"from vectorbtpro import *\n",
"# whats_imported()\n",
"\n",
"vbt.settings.set_theme('dark')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f8c4e61-3d13-4659-b559-3cef404adfb4",
"metadata": {},
"outputs": [],
"source": [
"# data = vbt.BinanceData.pull(\n",
"# ['BTCUSDT', 'ETHUSDT'], \n",
"# start='2020-01-01 UTC',\n",
"# end='2022-01-01 UTC',\n",
"# timeframe='1h'\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "508e2929-eab9-4db9-b581-d3a78959e173",
"metadata": {},
"outputs": [],
"source": [
"# data.to_hdf('my_data.h5')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f4587b7-a9d2-4765-a52d-1ea9fa32bdbf",
"metadata": {},
"outputs": [],
"source": [
"data = vbt.HDFData.pull('my_data.h5')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87de3e25-9836-45a4-9452-a3a0976a1d5d",
"metadata": {},
"outputs": [],
"source": [
"data.data['BTCUSDT'].info()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9389c053-5429-4fe7-a05f-cd2961ebaf94",
"metadata": {},
"outputs": [],
"source": [
"data.stats()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b10cf7a1-88a5-42d4-820b-ae7752ab85f6",
"metadata": {},
"outputs": [],
"source": [
"high = data.get('High')\n",
"low = data.get('Low')\n",
"close = data.get('Close')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d086a72b-de2e-4b0f-875d-54de8753b1ab",
"metadata": {},
"outputs": [],
"source": [
"print(close)"
]
},
{
"cell_type": "markdown",
"id": "cc9fb395-725a-4376-8d89-f3cfbc0fb9dd",
"metadata": {},
"source": [
"## Design"
]
},
{
"cell_type": "markdown",
"id": "94ad1367-4ca1-4e8a-8628-39242e93ceea",
"metadata": {},
"source": [
"### Pandas"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b94872b9-23db-42d4-b5d5-ec15176c4f6e",
"metadata": {},
"outputs": [],
"source": [
"def get_med_price(high, low):\n",
" return (high + low) / 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "45fdd6c7-4566-4468-8872-6f1f6c62991e",
"metadata": {},
"outputs": [],
"source": [
"def get_atr(high, low, close, period):\n",
" tr0 = abs(high - low)\n",
" tr1 = abs(high - close.shift())\n",
" tr2 = abs(low - close.shift())\n",
" tr = pd.concat((tr0, tr1, tr2), axis=1).max(axis=1)\n",
" atr = tr.ewm(alpha=1 / period, adjust=False, min_periods=period).mean()\n",
" return atr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37baaa9e-901d-4f49-a143-a9e1a24a20e7",
"metadata": {},
"outputs": [],
"source": [
"def get_basic_bands(med_price, atr, multiplier):\n",
" matr = multiplier * atr\n",
" upper = med_price + matr\n",
" lower = med_price - matr\n",
" return upper, lower"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b67ade24-1bec-411e-b570-e54ab89a47c5",
"metadata": {},
"outputs": [],
"source": [
"def get_final_bands(close, upper, lower):\n",
" trend = pd.Series(np.full(close.shape, np.nan), index=close.index)\n",
" dir_ = pd.Series(np.full(close.shape, 1), index=close.index)\n",
" long = pd.Series(np.full(close.shape, np.nan), index=close.index)\n",
" short = pd.Series(np.full(close.shape, np.nan), index=close.index)\n",
"\n",
" for i in range(1, close.shape[0]):\n",
" if close.iloc[i] > upper.iloc[i - 1]:\n",
" dir_.iloc[i] = 1\n",
" elif close.iloc[i] < lower.iloc[i - 1]:\n",
" dir_.iloc[i] = -1\n",
" else:\n",
" dir_.iloc[i] = dir_.iloc[i - 1]\n",
" if dir_.iloc[i] > 0 and lower.iloc[i] < lower.iloc[i - 1]:\n",
" lower.iloc[i] = lower.iloc[i - 1]\n",
" if dir_.iloc[i] < 0 and upper.iloc[i] > upper.iloc[i - 1]:\n",
" upper.iloc[i] = upper.iloc[i - 1]\n",
"\n",
" if dir_.iloc[i] > 0:\n",
" trend.iloc[i] = long.iloc[i] = lower.iloc[i]\n",
" else:\n",
" trend.iloc[i] = short.iloc[i] = upper.iloc[i]\n",
" \n",
" return trend, dir_, long, short"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb3af20c-d59e-4bef-854e-b59bad25acce",
"metadata": {},
"outputs": [],
"source": [
"def supertrend(high, low, close, period=7, multiplier=3):\n",
" med_price = get_med_price(high, low)\n",
" atr = get_atr(high, low, close, period)\n",
" upper, lower = get_basic_bands(med_price, atr, multiplier)\n",
" return get_final_bands(close, upper, lower)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac2ba411-f0e2-49ba-87b0-c2510e0d35f0",
"metadata": {},
"outputs": [],
"source": [
"supert, superd, superl, supers = supertrend(\n",
" high['BTCUSDT'], \n",
" low['BTCUSDT'], \n",
" close['BTCUSDT']\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7baab77d-8681-49ac-af7e-7ca5671adbb2",
"metadata": {},
"outputs": [],
"source": [
"supert"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d8c9e67-84e0-4102-8c33-1b3b6fc3ad0e",
"metadata": {},
"outputs": [],
"source": [
"superd"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bde7f7ec-7564-447a-bf48-27873ff72a8a",
"metadata": {},
"outputs": [],
"source": [
"superl"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f21db639-c324-4325-b9b9-0def69c37497",
"metadata": {},
"outputs": [],
"source": [
"supers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44cdaed3-f1a3-4202-8d74-fb9509aca2c8",
"metadata": {},
"outputs": [],
"source": [
"date_range = slice('2020-01-01', '2020-02-01')\n",
"fig = close.loc[date_range, 'BTCUSDT'].rename('Close').vbt.plot()\n",
"supers.loc[date_range].rename('Short').vbt.plot(fig=fig)\n",
"superl.loc[date_range].rename('Long').vbt.plot(fig=fig).show_svg()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fc5caf2-d308-4401-b36a-dcfdc6c16b55",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"supertrend(high['BTCUSDT'], low['BTCUSDT'], close['BTCUSDT'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a73d8c09-378b-4c27-b306-8ea61153207b",
"metadata": {},
"outputs": [],
"source": [
"SUPERTREND = vbt.pandas_ta('SUPERTREND')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed795d82-9324-413a-a343-9d2ad3b06750",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"SUPERTREND.run(high['BTCUSDT'], low['BTCUSDT'], close['BTCUSDT'])"
]
},
{
"cell_type": "markdown",
"id": "42175d89-0ee6-42c4-9ce0-4124603cf90c",
"metadata": {},
"source": [
"### NumPy + Numba"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c401b752-fb8f-4e4f-a626-5b1bc1b00abf",
"metadata": {},
"outputs": [],
"source": [
"def get_atr_np(high, low, close, period):\n",
" shifted_close = vbt.nb.fshift_1d_nb(close)\n",
" tr0 = np.abs(high - low)\n",
" tr1 = np.abs(high - shifted_close)\n",
" tr2 = np.abs(low - shifted_close)\n",
" tr = np.column_stack((tr0, tr1, tr2)).max(axis=1)\n",
" atr = vbt.nb.wwm_mean_1d_nb(tr, period)\n",
" return atr"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5020dffd-48cb-4805-b258-5b6987792d1d",
"metadata": {},
"outputs": [],
"source": [
"@njit\n",
"def get_final_bands_nb(close, upper, lower):\n",
" trend = np.full(close.shape, np.nan)\n",
" dir_ = np.full(close.shape, 1)\n",
" long = np.full(close.shape, np.nan)\n",
" short = np.full(close.shape, np.nan)\n",
"\n",
" for i in range(1, close.shape[0]):\n",
" if close[i] > upper[i - 1]:\n",
" dir_[i] = 1\n",
" elif close[i] < lower[i - 1]:\n",
" dir_[i] = -1\n",
" else:\n",
" dir_[i] = dir_[i - 1]\n",
" if dir_[i] > 0 and lower[i] < lower[i - 1]:\n",
" lower[i] = lower[i - 1]\n",
" if dir_[i] < 0 and upper[i] > upper[i - 1]:\n",
" upper[i] = upper[i - 1]\n",
"\n",
" if dir_[i] > 0:\n",
" trend[i] = long[i] = lower[i]\n",
" else:\n",
" trend[i] = short[i] = upper[i]\n",
" \n",
" return trend, dir_, long, short"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0db7eb3a-c3ac-4495-a91e-ba056c09b7af",
"metadata": {},
"outputs": [],
"source": [
"def faster_supertrend(high, low, close, period=7, multiplier=3):\n",
" med_price = get_med_price(high, low)\n",
" atr = get_atr_np(high, low, close, period)\n",
" upper, lower = get_basic_bands(med_price, atr, multiplier)\n",
" return get_final_bands_nb(close, upper, lower)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1faaaafb-871c-46c9-a03e-2a81ec075e8f",
"metadata": {},
"outputs": [],
"source": [
"supert, superd, superl, supers = faster_supertrend(\n",
" high['BTCUSDT'].values, \n",
" low['BTCUSDT'].values, \n",
" close['BTCUSDT'].values\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1337f05f-b7e5-405a-abc0-7bab2099cabc",
"metadata": {},
"outputs": [],
"source": [
"supert"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30a8e2dd-c24c-4b25-8854-89ab4f256911",
"metadata": {},
"outputs": [],
"source": [
"superd"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd5eb5ad-cf92-47e8-9eba-3b222ce8a3c6",
"metadata": {},
"outputs": [],
"source": [
"superl"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f87b9091-a8b8-49a5-a48d-d35f21d46461",
"metadata": {},
"outputs": [],
"source": [
"supers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08d97bbe-5c1b-44fe-94bb-0ce136de5c1c",
"metadata": {},
"outputs": [],
"source": [
"pd.Series(supert, index=close.index)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d9e597d-0cf8-415e-a199-eb55a4a513d5",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"faster_supertrend(\n",
" high['BTCUSDT'].values, \n",
" low['BTCUSDT'].values,\n",
" close['BTCUSDT'].values\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2e52b679-5838-46ad-9e91-7674021261e0",
"metadata": {},
"source": [
"### NumPy + Numba + TA-Lib"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2de5960-18c0-4b21-a0a8-856738c14bc9",
"metadata": {},
"outputs": [],
"source": [
"import talib\n",
"\n",
"def faster_supertrend_talib(high, low, close, period=7, multiplier=3):\n",
" avg_price = talib.MEDPRICE(high, low)\n",
" atr = talib.ATR(high, low, close, period)\n",
" upper, lower = get_basic_bands(avg_price, atr, multiplier)\n",
" return get_final_bands_nb(close, upper, lower)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ee2ed2c-2feb-4ef4-a8b7-914cd837610d",
"metadata": {},
"outputs": [],
"source": [
"faster_supertrend_talib(\n",
" high['BTCUSDT'].values, \n",
" low['BTCUSDT'].values, \n",
" close['BTCUSDT'].values\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92ff0625-ad73-44d7-8984-8103a702a03b",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"faster_supertrend_talib(\n",
" high['BTCUSDT'].values, \n",
" low['BTCUSDT'].values, \n",
" close['BTCUSDT'].values\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9f2d1d3c-80f6-4bdf-972b-af349ec3e1ea",
"metadata": {},
"source": [
"## Indicator factory"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcd4ca20-aa82-42f9-916b-7b9024890b14",
"metadata": {},
"outputs": [],
"source": [
"SuperTrend = vbt.IF(\n",
" class_name='SuperTrend',\n",
" short_name='st',\n",
" input_names=['high', 'low', 'close'],\n",
" param_names=['period', 'multiplier'],\n",
" output_names=['supert', 'superd', 'superl', 'supers']\n",
").with_apply_func(\n",
" faster_supertrend_talib, \n",
" takes_1d=True,\n",
" period=7, \n",
" multiplier=3\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9b84719-41b9-4622-87a4-1afc4da1cafd",
"metadata": {},
"outputs": [],
"source": [
"help(SuperTrend.run)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98812dfa-3e5b-487f-9bb0-738ff2e9ffdf",
"metadata": {},
"outputs": [],
"source": [
"st = SuperTrend.run(high, low, close)\n",
"print(st.supert)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "200ef4a9-cd6e-4f04-a279-0e4941933a8c",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"SuperTrend.run(high, low, close)"
]
},
{
"cell_type": "markdown",
"id": "a038fda3-4789-4756-a298-f57a48cd2e99",
"metadata": {},
"source": [
"### Using expressions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffcfb664-d323-49d2-8544-9696dd74a316",
"metadata": {},
"outputs": [],
"source": [
"expr = \"\"\"\n",
"SuperTrend[st]:\n",
"medprice = @talib_medprice(high, low)\n",
"atr = @talib_atr(high, low, close, @p_period)\n",
"upper, lower = get_basic_bands(medprice, atr, @p_multiplier)\n",
"supert, superd, superl, supers = get_final_bands(close, upper, lower)\n",
"supert, superd, superl, supers\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "522edc62-b6ad-4f1e-95f2-d8b871609b09",
"metadata": {},
"outputs": [],
"source": [
"SuperTrend = vbt.IF.from_expr(\n",
" expr, \n",
" takes_1d=True,\n",
" get_basic_bands=get_basic_bands,\n",
" get_final_bands=get_final_bands_nb,\n",
" period=7, \n",
" multiplier=3\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dcd6acd6-ee0e-4953-9bb0-223fbb0158f1",
"metadata": {},
"outputs": [],
"source": [
"st = SuperTrend.run(high, low, close)\n",
"print(st.supert)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fcb34c2-d1cc-48e7-8eb4-fdd0a1a9b1bc",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"SuperTrend.run(high, low, close)"
]
},
{
"cell_type": "markdown",
"id": "5b834219-4c05-478a-b581-9bb6ec0db8fb",
"metadata": {},
"source": [
"## Plot indicator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21bb38bf-b6da-42b1-ab48-039d9e75a16d",
"metadata": {},
"outputs": [],
"source": [
"class SuperTrend(SuperTrend):\n",
" def plot(self, \n",
" column=None, \n",
" close_kwargs=None,\n",
" superl_kwargs=None,\n",
" supers_kwargs=None,\n",
" fig=None, \n",
" **layout_kwargs):\n",
" close_kwargs = close_kwargs if close_kwargs else {}\n",
" superl_kwargs = superl_kwargs if superl_kwargs else {}\n",
" supers_kwargs = supers_kwargs if supers_kwargs else {}\n",
" \n",
" close = self.select_col_from_obj(self.close, column).rename('Close')\n",
" supers = self.select_col_from_obj(self.supers, column).rename('Short')\n",
" superl = self.select_col_from_obj(self.superl, column).rename('Long')\n",
" \n",
" fig = close.vbt.plot(fig=fig, **close_kwargs, **layout_kwargs)\n",
" supers.vbt.plot(fig=fig, **supers_kwargs)\n",
" superl.vbt.plot(fig=fig, **superl_kwargs)\n",
" \n",
" return fig"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99b0be24-8256-468a-bdaf-bd82467e4819",
"metadata": {},
"outputs": [],
"source": [
"st = SuperTrend.run(high, low, close)\n",
"st.loc[date_range, 'BTCUSDT'].plot(\n",
" superl_kwargs=dict(trace_kwargs=dict(line_color='limegreen')),\n",
" supers_kwargs=dict(trace_kwargs=dict(line_color='red'))\n",
").show_svg()"
]
},
{
"cell_type": "markdown",
"id": "2d7956a9-a24c-4db1-871e-ef5f938e193e",
"metadata": {},
"source": [
"## Test indicator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df6dfa41-4206-4d61-a270-6d5f7fc77c2b",
"metadata": {},
"outputs": [],
"source": [
"entries = (~st.superl.isnull()).vbt.signals.fshift()\n",
"exits = (~st.supers.isnull()).vbt.signals.fshift()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1cd2f65d-2682-4cfd-bbcf-54702bc115b2",
"metadata": {},
"outputs": [],
"source": [
"pf = vbt.Portfolio.from_signals(\n",
" close=close, \n",
" entries=entries, \n",
" exits=exits, \n",
" fees=0.001, \n",
" freq='1h'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "757609ba-f3ed-4c70-8e3c-9c6f0cdca05c",
"metadata": {},
"outputs": [],
"source": [
"pf['ETHUSDT'].stats()"
]
},
{
"cell_type": "markdown",
"id": "961fa4a3-74c3-4b90-9678-aaee56cf60af",
"metadata": {},
"source": [
"### Optimization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1b757f6-0c30-4102-8580-d6d82bf249f8",
"metadata": {},
"outputs": [],
"source": [
"periods = np.arange(4, 20)\n",
"multipliers = np.arange(20, 41) / 10"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65b4d9ed-2866-4044-8023-829d5e9b0f30",
"metadata": {},
"outputs": [],
"source": [
"st = SuperTrend.run(\n",
" high, low, close, \n",
" period=periods, \n",
" multiplier=multipliers,\n",
" param_product=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00adab7e-1ac1-4aca-a2b5-e00926381009",
"metadata": {},
"outputs": [],
"source": [
"st.wrapper.columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcea7660-bf82-4ea2-ae3d-6407914167bc",
"metadata": {},
"outputs": [],
"source": [
"st.loc[date_range, (19, 4, 'ETHUSDT')].plot().show_svg()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1ed2c21-e61c-49e1-8b08-1324380ca566",
"metadata": {},
"outputs": [],
"source": [
"print(st.getsize())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cc44d67-cabe-41e5-9d4b-c37527ff6dbb",
"metadata": {},
"outputs": [],
"source": [
"input_size = st.wrapper.shape[0] * st.wrapper.shape[1]\n",
"n_outputs = 4\n",
"data_type_size = 8\n",
"input_size * n_outputs * data_type_size / 1024 / 1024"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b150641-4aa6-4634-98a9-8bd698448f4c",
"metadata": {},
"outputs": [],
"source": [
"entries = (~st.superl.isnull()).vbt.signals.fshift()\n",
"exits = (~st.supers.isnull()).vbt.signals.fshift()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f5a43d0-9b85-4347-9bb9-f0e9b096d143",
"metadata": {},
"outputs": [],
"source": [
"pf = vbt.Portfolio.from_signals(close, entries, exits, fees=0.001, freq='1h')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a9155ae-4971-4009-b1dc-607dc4116fed",
"metadata": {},
"outputs": [],
"source": [
"pf.sharpe_ratio.vbt.heatmap(\n",
" x_level='st_period', \n",
" y_level='st_multiplier',\n",
" slider_level='symbol'\n",
").show_svg()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "855fe67d-6c43-4a9a-9e03-df18d4d2e77c",
"metadata": {},
"outputs": [],
"source": [
"vbt.Portfolio.from_holding(close, freq='1h').sharpe_ratio"
]
},
{
"cell_type": "markdown",
"id": "25a4b932-81a0-42b5-98b5-f58988319952",
"metadata": {},
"source": [
"## Streaming"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a132681-34a2-4a7d-ab67-5cf72645fd7c",
"metadata": {},
"outputs": [],
"source": [
"class SuperTrendAIS(tp.NamedTuple):\n",
" i: int\n",
" high: float\n",
" low: float\n",
" close: float\n",
" prev_close: float\n",
" prev_upper: float\n",
" prev_lower: float\n",
" prev_dir_: float\n",
" nobs: int\n",
" weighted_avg: float\n",
" old_wt: float\n",
" period: int\n",
" multiplier: float\n",
" \n",
"class SuperTrendAOS(tp.NamedTuple):\n",
" nobs: int\n",
" weighted_avg: float\n",
" old_wt: float\n",
" upper: float\n",
" lower: float\n",
" trend: float\n",
" dir_: float\n",
" long: float\n",
" short: float"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6dca1372-1c55-48f3-b569-7a193142fff6",
"metadata": {},
"outputs": [],
"source": [
"@njit(nogil=True)\n",
"def get_tr_one_nb(high, low, prev_close):\n",
" tr0 = abs(high - low)\n",
" tr1 = abs(high - prev_close)\n",
" tr2 = abs(low - prev_close)\n",
" if np.isnan(tr0) or np.isnan(tr1) or np.isnan(tr2):\n",
" tr = np.nan\n",
" else:\n",
" tr = max(tr0, tr1, tr2)\n",
" return tr\n",
"\n",
"@njit(nogil=True)\n",
"def get_med_price_one_nb(high, low):\n",
" return (high + low) / 2\n",
"\n",
"@njit(nogil=True)\n",
"def get_basic_bands_one_nb(high, low, atr, multiplier):\n",
" med_price = get_med_price_one_nb(high, low)\n",
" matr = multiplier * atr\n",
" upper = med_price + matr\n",
" lower = med_price - matr\n",
" return upper, lower\n",
" \n",
"@njit(nogil=True)\n",
"def get_final_bands_one_nb(close, upper, lower, prev_upper, prev_lower, prev_dir_):\n",
" if close > prev_upper:\n",
" dir_ = 1\n",
" elif close < prev_lower:\n",
" dir_ = -1\n",
" else:\n",
" dir_ = prev_dir_\n",
" if dir_ > 0 and lower < prev_lower:\n",
" lower = prev_lower\n",
" if dir_ < 0 and upper > prev_upper:\n",
" upper = prev_upper\n",
"\n",
" if dir_ > 0:\n",
" trend = long = lower\n",
" short = np.nan\n",
" else:\n",
" trend = short = upper\n",
" long = np.nan\n",
" return upper, lower, trend, dir_, long, short"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e787868d-99b9-41ea-ad97-e70480e021e8",
"metadata": {},
"outputs": [],
"source": [
"@njit(nogil=True)\n",
"def superfast_supertrend_acc_nb(in_state):\n",
" i = in_state.i\n",
" high = in_state.high\n",
" low = in_state.low\n",
" close = in_state.close\n",
" prev_close = in_state.prev_close\n",
" prev_upper = in_state.prev_upper\n",
" prev_lower = in_state.prev_lower\n",
" prev_dir_ = in_state.prev_dir_\n",
" nobs = in_state.nobs\n",
" weighted_avg = in_state.weighted_avg\n",
" old_wt = in_state.old_wt\n",
" period = in_state.period\n",
" multiplier = in_state.multiplier\n",
" \n",
" tr = get_tr_one_nb(high, low, prev_close)\n",
"\n",
" alpha = vbt.nb.alpha_from_wilder_nb(period)\n",
" ewm_mean_in_state = vbt.nb.EWMMeanAIS(\n",
" i=i,\n",
" value=tr,\n",
" old_wt=old_wt,\n",
" weighted_avg=weighted_avg,\n",
" nobs=nobs,\n",
" alpha=alpha,\n",
" minp=period,\n",
" adjust=False\n",
" )\n",
" ewm_mean_out_state = vbt.nb.ewm_mean_acc_nb(ewm_mean_in_state)\n",
" atr = ewm_mean_out_state.value\n",
" \n",
" upper, lower = get_basic_bands_one_nb(high, low, atr, multiplier)\n",
" \n",
" if i == 0:\n",
" trend, dir_, long, short = np.nan, 1, np.nan, np.nan\n",
" else:\n",
" upper, lower, trend, dir_, long, short = get_final_bands_one_nb(\n",
" close, upper, lower, prev_upper, prev_lower, prev_dir_)\n",
" \n",
" return SuperTrendAOS(\n",
" nobs=ewm_mean_out_state.nobs,\n",
" weighted_avg=ewm_mean_out_state.weighted_avg,\n",
" old_wt=ewm_mean_out_state.old_wt,\n",
" upper=upper,\n",
" lower=lower,\n",
" trend=trend,\n",
" dir_=dir_,\n",
" long=long,\n",
" short=short\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d5a76d7-f981-4e64-bf5f-9cd69046cfb2",
"metadata": {},
"outputs": [],
"source": [
"@njit(nogil=True)\n",
"def superfast_supertrend_nb(high, low, close, period=7, multiplier=3):\n",
" trend = np.empty(close.shape, dtype=np.float_)\n",
" dir_ = np.empty(close.shape, dtype=np.int_)\n",
" long = np.empty(close.shape, dtype=np.float_)\n",
" short = np.empty(close.shape, dtype=np.float_)\n",
" \n",
" if close.shape[0] == 0:\n",
" return trend, dir_, long, short\n",
"\n",
" nobs = 0\n",
" old_wt = 1.\n",
" weighted_avg = np.nan\n",
" prev_upper = np.nan\n",
" prev_lower = np.nan\n",
"\n",
" for i in range(close.shape[0]):\n",
" in_state = SuperTrendAIS(\n",
" i=i,\n",
" high=high[i],\n",
" low=low[i],\n",
" close=close[i],\n",
" prev_close=close[i - 1] if i > 0 else np.nan,\n",
" prev_upper=prev_upper,\n",
" prev_lower=prev_lower,\n",
" prev_dir_=dir_[i - 1] if i > 0 else 1,\n",
" nobs=nobs,\n",
" weighted_avg=weighted_avg,\n",
" old_wt=old_wt,\n",
" period=period,\n",
" multiplier=multiplier\n",
" )\n",
" \n",
" out_state = superfast_supertrend_acc_nb(in_state)\n",
" \n",
" nobs = out_state.nobs\n",
" weighted_avg = out_state.weighted_avg\n",
" old_wt = out_state.old_wt\n",
" prev_upper = out_state.upper\n",
" prev_lower = out_state.lower\n",
" trend[i] = out_state.trend\n",
" dir_[i] = out_state.dir_\n",
" long[i] = out_state.long\n",
" short[i] = out_state.short\n",
" \n",
" return trend, dir_, long, short"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a16e5ed-f62c-4523-a5e1-6b75e6af9a7b",
"metadata": {},
"outputs": [],
"source": [
"superfast_out = superfast_supertrend_nb(\n",
" high['BTCUSDT'].values,\n",
" low['BTCUSDT'].values,\n",
" close['BTCUSDT'].values\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94757fa3-9008-4320-b66f-c72c94ab19b6",
"metadata": {},
"outputs": [],
"source": [
"faster_out = faster_supertrend(\n",
" high['BTCUSDT'].values,\n",
" low['BTCUSDT'].values,\n",
" close['BTCUSDT'].values\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0733554-9a70-4ebb-9962-7ac87ddb63b9",
"metadata": {},
"outputs": [],
"source": [
"np.testing.assert_array_equal(superfast_out[0], faster_out[0])\n",
"np.testing.assert_array_equal(superfast_out[1], faster_out[1])\n",
"np.testing.assert_array_equal(superfast_out[2], faster_out[2])\n",
"np.testing.assert_array_equal(superfast_out[3], faster_out[3])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1cc3b596-f88e-4222-947f-2fb1ca0e879c",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"superfast_supertrend_nb(\n",
" high['BTCUSDT'].values, \n",
" low['BTCUSDT'].values, \n",
" close['BTCUSDT'].values\n",
")"
]
},
{
"cell_type": "markdown",
"id": "08682671-abae-4f68-98d6-da623ed67c7d",
"metadata": {},
"source": [
"## Multithreading"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da0c1330-ee15-409c-a57a-a2d95facaa3c",
"metadata": {},
"outputs": [],
"source": [
"SuperTrend = vbt.IF(\n",
" class_name='SuperTrend',\n",
" short_name='st',\n",
" input_names=['high', 'low', 'close'],\n",
" param_names=['period', 'multiplier'],\n",
" output_names=['supert', 'superd', 'superl', 'supers']\n",
").with_apply_func(\n",
" superfast_supertrend_nb, \n",
" takes_1d=True,\n",
" period=7, \n",
" multiplier=3\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72fbbf4b-107f-4efd-9a13-01b3230b8ec0",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"SuperTrend.run(high, low, close)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01592b3c-cba0-4007-82d5-0026ade663c2",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"SuperTrend.run(\n",
" high, low, close, \n",
" period=periods, \n",
" multiplier=multipliers,\n",
" param_product=True,\n",
" execute_kwargs=dict(show_progress=False)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee578ba8-c32c-4c4e-9da7-e39a92bfcb7e",
"metadata": {},
"outputs": [],
"source": [
"270 / 336 / 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "527a2a56-f5e8-4897-8288-40edeb023639",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"SuperTrend.run(\n",
" high, low, close, \n",
" period=periods, \n",
" multiplier=multipliers,\n",
" param_product=True,\n",
" execute_kwargs=dict(\n",
" engine='dask', \n",
" chunk_len='auto', \n",
" show_progress=False\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ac158ed3-1c7c-498f-843d-09a6d02439c4",
"metadata": {},
"source": [
"## Pipelines"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5bcfaed-2d4c-4929-8eae-3505d682577d",
"metadata": {},
"outputs": [],
"source": [
"def pipeline(data, period=7, multiplier=3):\n",
" high = data.get('High')\n",
" low = data.get('Low')\n",
" close = data.get('Close')\n",
" st = SuperTrend.run(\n",
" high, \n",
" low, \n",
" close, \n",
" period=period, \n",
" multiplier=multiplier\n",
" )\n",
" entries = (~st.superl.isnull()).vbt.signals.fshift()\n",
" exits = (~st.supers.isnull()).vbt.signals.fshift()\n",
" pf = vbt.Portfolio.from_signals(\n",
" close, \n",
" entries=entries, \n",
" exits=exits, \n",
" fees=0.001,\n",
" save_returns=True,\n",
" max_order_records=0,\n",
" freq='1h'\n",
" )\n",
" return pf.sharpe_ratio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e57ec5e5-71a3-46fa-a8d1-edc9d77876f3",
"metadata": {},
"outputs": [],
"source": [
"pipeline(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c87e3133-9238-4cf6-8fb2-518b36e1962d",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"pipeline(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3c94c2a-7c02-484f-9b5b-83892cc0d2e0",
"metadata": {},
"outputs": [],
"source": [
"336 * 32"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b23cfc27-81f4-43d6-9a3c-3e2136db980c",
"metadata": {},
"outputs": [],
"source": [
"op_tree = (product, periods, multipliers)\n",
"period_product, multiplier_product = vbt.generate_param_combs(op_tree)\n",
"period_product = np.asarray(period_product)\n",
"multiplier_product = np.asarray(multiplier_product)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0c220d9-bd6f-48c5-a794-f02d6a6e3eeb",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"pipeline(data, period_product, multiplier_product)"
]
},
{
"cell_type": "markdown",
"id": "52529de0-d7a9-4e4e-9a91-4c2449f3405f",
"metadata": {},
"source": [
"### Chunked pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8413e9c-5fe1-4339-a289-f9d1f7595ce6",
"metadata": {},
"outputs": [],
"source": [
"chunked_pipeline = vbt.chunked(\n",
" size=vbt.LenSizer(arg_query='period', single_type=int),\n",
" arg_take_spec=dict(\n",
" data=None,\n",
" period=vbt.ChunkSlicer(),\n",
" multiplier=vbt.ChunkSlicer()\n",
" ),\n",
" merge_func=lambda x: pd.concat(x).sort_index()\n",
")(pipeline)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb15c6bc-d73e-4449-bb98-d43432dc9dad",
"metadata": {},
"outputs": [],
"source": [
"chunked_pipeline(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "728cadd0-a62f-49bb-8c86-c24df3642ad2",
"metadata": {},
"outputs": [],
"source": [
"chunked_pipeline(\n",
" data, \n",
" period_product[:4], \n",
" multiplier_product[:4],\n",
" _n_chunks=2,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc3206ff-4b53-459c-91cd-7183193d76f2",
"metadata": {},
"outputs": [],
"source": [
"chunk_meta, tasks = chunked_pipeline(\n",
" data, \n",
" period_product[:4], \n",
" multiplier_product[:4],\n",
" _n_chunks=2,\n",
" _return_raw_chunks=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc340146-43ee-40a2-8bb2-60494469d39e",
"metadata": {},
"outputs": [],
"source": [
"chunk_meta"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f734466d-59f9-49c6-ae2b-b3ff038833df",
"metadata": {},
"outputs": [],
"source": [
"list(tasks)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d279086d-743e-4044-9fe7-1d46e74d04e4",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"chunked_pipeline(data, period_product, multiplier_product)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5125c3af-b330-47a2-a24d-49e58d37e471",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"chunked_pipeline(data, period_product, multiplier_product, _chunk_len=1)"
]
},
{
"cell_type": "markdown",
"id": "0e6c55bd-6e54-4ebc-bb9a-9f2a2805403d",
"metadata": {},
"source": [
"### Numba pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8830575-46de-4884-a278-214c258ecc00",
"metadata": {},
"outputs": [],
"source": [
"@njit(nogil=True)\n",
"def pipeline_nb(high, low, close, periods=np.array([7]), multipliers=np.array([3]), ann_factor=365):\n",
" sharpe = np.empty(periods.size * close.shape[1], dtype=np.float_)\n",
" long_entries = np.empty(close.shape, dtype=np.bool_)\n",
" long_exits = np.empty(close.shape, dtype=np.bool_)\n",
" group_lens = np.full(close.shape[1], 1)\n",
" init_cash = 100.\n",
" fees = 0.001\n",
" k = 0\n",
" \n",
" for i in range(periods.size):\n",
" for col in range(close.shape[1]):\n",
" _, _, superl, supers = superfast_supertrend_nb(\n",
" high[:, col], \n",
" low[:, col], \n",
" close[:, col], \n",
" periods[i], \n",
" multipliers[i]\n",
" )\n",
" long_entries[:, col] = vbt.nb.fshift_1d_nb(~np.isnan(superl), fill_value=False)\n",
" long_exits[:, col] = vbt.nb.fshift_1d_nb(~np.isnan(supers), fill_value=False)\n",
" \n",
" sim_out = vbt.pf_nb.from_signals_nb(\n",
" target_shape=close.shape,\n",
" group_lens=group_lens,\n",
" init_cash=init_cash,\n",
" high=high,\n",
" low=low,\n",
" close=close,\n",
" long_entries=long_entries,\n",
" long_exits=long_exits,\n",
" fees=fees,\n",
" save_returns=True\n",
" )\n",
" returns = sim_out.in_outputs.returns\n",
" sharpe[k:k + close.shape[1]] = vbt.ret_nb.sharpe_ratio_nb(returns, ann_factor, ddof=1)\n",
" k += close.shape[1]\n",
" \n",
" return sharpe"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2eabd3cf-5f03-4cbb-8b1d-6bba6bcbbd3d",
"metadata": {},
"outputs": [],
"source": [
"ann_factor = vbt.pd_acc.returns.get_ann_factor(freq='1h')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5a82762-332a-4b38-9739-0da01d06a02d",
"metadata": {},
"outputs": [],
"source": [
"pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" ann_factor=ann_factor\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd786362-5c17-4d9a-a8b7-77604c47942f",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" ann_factor=ann_factor\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1796f5ea-c0de-4249-98e5-ce14c550108f",
"metadata": {},
"outputs": [],
"source": [
"def merge_func(arrs, ann_args, input_columns):\n",
"    \"\"\"Merge per-chunk Sharpe arrays into one Series.\n",
"\n",
"    The index is (st_period, st_multiplier) combined with `input_columns`.\n",
"    `ann_args['periods']['value']` and `ann_args['multipliers']['value']` are read\n",
"    from the annotated arguments of the full call, so the index spans all pairs.\n",
"    \"\"\"\n",
"    period_level = pd.Index(ann_args['periods']['value'], name='st_period')\n",
"    multiplier_level = pd.Index(ann_args['multipliers']['value'], name='st_multiplier')\n",
"    param_index = vbt.stack_indexes((period_level, multiplier_level))\n",
"    full_index = vbt.combine_indexes((param_index, input_columns))\n",
"    return pd.Series(np.concatenate(arrs), index=full_index)\n",
"\n",
"# Chunk along the parameter axis: `periods` and `multipliers` are sliced per chunk,\n",
"# while entries mapped to None (prices, ann_factor) are not sliced\n",
"nb_chunked = vbt.chunked(\n",
"    size=vbt.ArraySizer(arg_query='periods', axis=0),\n",
"    arg_take_spec={\n",
"        'high': None,\n",
"        'low': None,\n",
"        'close': None,\n",
"        'periods': vbt.ArraySlicer(axis=0),\n",
"        'multipliers': vbt.ArraySlicer(axis=0),\n",
"        'ann_factor': None,\n",
"    },\n",
"    merge_func=merge_func,\n",
"    merge_kwargs={'ann_args': vbt.Rep(\"ann_args\")}\n",
")\n",
"chunked_pipeline_nb = nb_chunked(pipeline_nb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f0a9058-bc57-4331-b9e5-31fd511c8862",
"metadata": {},
"outputs": [],
"source": [
"chunked_pipeline_nb(\n",
" high.values, \n",
" low.values,\n",
" close.values,\n",
" periods=period_product[:4], \n",
" multipliers=multiplier_product[:4],\n",
" ann_factor=ann_factor,\n",
" _n_chunks=2,\n",
" _merge_kwargs=dict(input_columns=close.columns)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "badba6cb-3405-498d-b366-cd9205721541",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"chunked_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" periods=period_product, \n",
" multipliers=multiplier_product,\n",
" ann_factor=ann_factor,\n",
" _merge_kwargs=dict(input_columns=close.columns)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5208092-e88f-428a-a2da-f43d325ab20e",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"chunked_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" periods=period_product, \n",
" multipliers=multiplier_product,\n",
" ann_factor=ann_factor,\n",
" _execute_kwargs=dict(engine='dask'),\n",
" _merge_kwargs=dict(input_columns=close.columns)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b03f9632-ebfd-4073-a80b-a1671a3d3361",
"metadata": {},
"source": [
"### Contextualized pipeline"
]
},
{
"cell_type": "markdown",
"id": "3eded1d5-1a4a-410c-96ab-3f715be479a1",
"metadata": {},
"source": [
"#### Streaming Sharpe"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03f8b212-4ae3-404e-90fe-4c44a867e556",
"metadata": {},
"outputs": [],
"source": [
"class RollSharpeAIS(tp.NamedTuple):\n",
"    # Input state for one step of the streaming Sharpe ratio\n",
"    i: int\n",
"    ret: float  # return at step i\n",
"    pre_window_ret: float  # return leaving the window at this step (NaN if none)\n",
"    cumsum: float\n",
"    cumsum_sq: float\n",
"    nancnt: int\n",
"    window: int\n",
"    minp: tp.Optional[int]\n",
"    ddof: int\n",
"    ann_factor: float\n",
"\n",
"class RollSharpeAOS(tp.NamedTuple):\n",
"    # Output state: updated running sums plus the Sharpe ratio at step i\n",
"    cumsum: float\n",
"    cumsum_sq: float\n",
"    nancnt: int\n",
"    value: float\n",
"\n",
"@njit(nogil=True)\n",
"def rolling_sharpe_acc_nb(in_state):\n",
"    \"\"\"Advance the rolling Sharpe ratio by one step.\n",
"\n",
"    Feeds the same running sums into vectorbt's rolling mean and rolling std\n",
"    accumulators and returns the std accumulator's updated sums together with\n",
"    the annualized Sharpe ratio (NaN when the standard deviation is zero).\n",
"    \"\"\"\n",
"    s = in_state\n",
"    mean_state = vbt.nb.rolling_mean_acc_nb(vbt.nb.RollMeanAIS(\n",
"        i=s.i,\n",
"        value=s.ret,\n",
"        pre_window_value=s.pre_window_ret,\n",
"        cumsum=s.cumsum,\n",
"        nancnt=s.nancnt,\n",
"        window=s.window,\n",
"        minp=s.minp\n",
"    ))\n",
"    std_state = vbt.nb.rolling_std_acc_nb(vbt.nb.RollStdAIS(\n",
"        i=s.i,\n",
"        value=s.ret,\n",
"        pre_window_value=s.pre_window_ret,\n",
"        cumsum=s.cumsum,\n",
"        cumsum_sq=s.cumsum_sq,\n",
"        nancnt=s.nancnt,\n",
"        window=s.window,\n",
"        minp=s.minp,\n",
"        ddof=s.ddof\n",
"    ))\n",
"    sharpe = np.nan if std_state.value == 0 else mean_state.value / std_state.value * np.sqrt(s.ann_factor)\n",
"    # The std accumulator carries cumsum, cumsum_sq and nancnt forward\n",
"    return RollSharpeAOS(\n",
"        cumsum=std_state.cumsum,\n",
"        cumsum_sq=std_state.cumsum_sq,\n",
"        nancnt=std_state.nancnt,\n",
"        value=sharpe\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1339d8a8-c186-4fea-abc3-fd926b17149d",
"metadata": {},
"outputs": [],
"source": [
"@njit(nogil=True)\n",
"def rolling_sharpe_ratio_nb(returns, window, minp=None, ddof=0, ann_factor=365):\n",
"    \"\"\"Rolling annualized Sharpe ratio of a 1-dim returns array, computed in one pass.\n",
"\n",
"    `window=None` uses the full array length as the window; `minp=None` defaults to\n",
"    `window`. Each step updates running sums through `rolling_sharpe_acc_nb`.\n",
"    Returns an array of the same shape as `returns`.\n",
"    \"\"\"\n",
"    if window is None:\n",
"        window = returns.shape[0]\n",
"    if minp is None:\n",
"        minp = window\n",
"    # np.float64 instead of the np.float_ alias, which was removed in NumPy 2.0\n",
"    out = np.empty(returns.shape, dtype=np.float64)\n",
"\n",
"    if returns.shape[0] == 0:\n",
"        return out\n",
"\n",
"    cumsum = 0.\n",
"    cumsum_sq = 0.\n",
"    nancnt = 0\n",
"\n",
"    for i in range(returns.shape[0]):\n",
"        in_state = RollSharpeAIS(\n",
"            i=i,\n",
"            ret=returns[i],\n",
"            # Value dropping out of the window at this step (NaN while the window fills)\n",
"            pre_window_ret=returns[i - window] if i - window >= 0 else np.nan,\n",
"            cumsum=cumsum,\n",
"            cumsum_sq=cumsum_sq,\n",
"            nancnt=nancnt,\n",
"            window=window,\n",
"            minp=minp,\n",
"            ddof=ddof,\n",
"            ann_factor=ann_factor\n",
"        )\n",
"\n",
"        out_state = rolling_sharpe_acc_nb(in_state)\n",
"\n",
"        cumsum = out_state.cumsum\n",
"        cumsum_sq = out_state.cumsum_sq\n",
"        nancnt = out_state.nancnt\n",
"        out[i] = out_state.value\n",
"\n",
"    return out"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66ff280b-0057-4095-b473-ceee2ee1e0f1",
"metadata": {},
"outputs": [],
"source": [
"returns = close['BTCUSDT'].vbt.to_returns()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "084c24fc-53b0-4d85-a184-0d3a8503cd00",
"metadata": {},
"outputs": [],
"source": [
"np.testing.assert_allclose(\n",
" rolling_sharpe_ratio_nb(\n",
" returns=returns.values, \n",
" window=10, \n",
" ddof=1, \n",
" ann_factor=ann_factor),\n",
" returns.vbt.returns(freq='1h').rolling_sharpe_ratio(10).values\n",
")"
]
},
{
"cell_type": "markdown",
"id": "47c7edb2-f5de-4d7f-99a7-a35e5dcdec46",
"metadata": {},
"source": [
"#### Callbacks"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28fba511-1030-4c6b-8592-5a965b8b6b1e",
"metadata": {},
"outputs": [],
"source": [
"class Memory(tp.NamedTuple):\n",
"    # Per-column state shared by the callbacks of the contextualized simulation\n",
"    # SuperTrend accumulator state\n",
"    nobs: tp.Array1d\n",
"    old_wt: tp.Array1d\n",
"    weighted_avg: tp.Array1d\n",
"    prev_upper: tp.Array1d\n",
"    prev_lower: tp.Array1d\n",
"    prev_dir_: tp.Array1d\n",
"    # Streaming Sharpe accumulator state\n",
"    cumsum: tp.Array1d\n",
"    cumsum_sq: tp.Array1d\n",
"    nancnt: tp.Array1d\n",
"    # Signals detected at the previous bar, acted upon at the current bar\n",
"    was_entry: tp.Array1d\n",
"    was_exit: tp.Array1d\n",
"\n",
"@njit(nogil=True)\n",
"def pre_sim_func_nb(c):\n",
"    \"\"\"Allocate a fresh `Memory` with one slot per column before the simulation starts.\n",
"\n",
"    Returns a 1-tuple; its element is received as the `memory` argument of\n",
"    `order_func_nb` and `post_segment_func_nb`.\n",
"    \"\"\"\n",
"    n_cols = c.target_shape[1]\n",
"    # np.float64 instead of the np.float_ alias, which was removed in NumPy 2.0\n",
"    memory = Memory(\n",
"        nobs=np.full(n_cols, 0, dtype=np.int_),\n",
"        old_wt=np.full(n_cols, 1., dtype=np.float64),\n",
"        weighted_avg=np.full(n_cols, np.nan, dtype=np.float64),\n",
"        prev_upper=np.full(n_cols, np.nan, dtype=np.float64),\n",
"        prev_lower=np.full(n_cols, np.nan, dtype=np.float64),\n",
"        prev_dir_=np.full(n_cols, np.nan, dtype=np.float64),\n",
"        cumsum=np.full(n_cols, 0., dtype=np.float64),\n",
"        cumsum_sq=np.full(n_cols, 0., dtype=np.float64),\n",
"        nancnt=np.full(n_cols, 0, dtype=np.int_),\n",
"        was_entry=np.full(n_cols, False, dtype=np.bool_),\n",
"        was_exit=np.full(n_cols, False, dtype=np.bool_)\n",
"    )\n",
"    return (memory,)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a3f8ffd-82d1-4b86-80c5-0d9d4e9a5680",
"metadata": {},
"outputs": [],
"source": [
"@njit(nogil=True)\n",
"def order_func_nb(c, memory, period, multiplier):\n",
" is_entry = memory.was_entry[c.col]\n",
" is_exit = memory.was_exit[c.col]\n",
" \n",
" in_state = SuperTrendAIS(\n",
" i=c.i,\n",
" high=c.high[c.i, c.col],\n",
" low=c.low[c.i, c.col],\n",
" close=c.close[c.i, c.col],\n",
" prev_close=c.close[c.i - 1, c.col] if c.i > 0 else np.nan,\n",
" prev_upper=memory.prev_upper[c.col],\n",
" prev_lower=memory.prev_lower[c.col],\n",
" prev_dir_=memory.prev_dir_[c.col],\n",
" nobs=memory.nobs[c.col],\n",
" weighted_avg=memory.weighted_avg[c.col],\n",
" old_wt=memory.old_wt[c.col],\n",
" period=period,\n",
" multiplier=multiplier\n",
" )\n",
"\n",
" out_state = superfast_supertrend_acc_nb(in_state)\n",
"\n",
" memory.nobs[c.col] = out_state.nobs\n",
" memory.weighted_avg[c.col] = out_state.weighted_avg\n",
" memory.old_wt[c.col] = out_state.old_wt\n",
" memory.prev_upper[c.col] = out_state.upper\n",
" memory.prev_lower[c.col] = out_state.lower\n",
" memory.prev_dir_[c.col] = out_state.dir_\n",
" memory.was_entry[c.col] = not np.isnan(out_state.long)\n",
" memory.was_exit[c.col] = not np.isnan(out_state.short)\n",
" \n",
" in_position = c.position_now > 0\n",
" if is_entry and not in_position:\n",
" size = np.inf\n",
" elif is_exit and in_position:\n",
" size = -np.inf\n",
" else:\n",
" size = 0.\n",
" return vbt.pf_nb.order_nb(\n",
" size=size, \n",
" direction=vbt.pf_enums.Direction.LongOnly,\n",
" fees=0.001\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73e93b02-3a10-4143-abf4-24ec71afb42f",
"metadata": {},
"outputs": [],
"source": [
"@njit(nogil=True)\n",
"def post_segment_func_nb(c, memory, ann_factor):\n",
"    \"\"\"After each segment, update the expanding Sharpe ratio of every column in it.\n",
"\n",
"    Processes row `c.i` for columns `c.from_col` .. `c.to_col - 1`. The window is\n",
"    `c.i + 1`, so it always covers every bar so far and nothing leaves it\n",
"    (`pre_window_ret` is NaN). The latest value goes into `c.in_outputs.sharpe`.\n",
"    \"\"\"\n",
"    sharpe_out = c.in_outputs.sharpe\n",
"    window = c.i + 1\n",
"    for column in range(c.from_col, c.to_col):\n",
"        state = rolling_sharpe_acc_nb(RollSharpeAIS(\n",
"            i=c.i,\n",
"            ret=c.last_return[column],\n",
"            pre_window_ret=np.nan,\n",
"            cumsum=memory.cumsum[column],\n",
"            cumsum_sq=memory.cumsum_sq[column],\n",
"            nancnt=memory.nancnt[column],\n",
"            window=window,\n",
"            minp=0,\n",
"            ddof=1,\n",
"            ann_factor=ann_factor\n",
"        ))\n",
"        memory.cumsum[column] = state.cumsum\n",
"        memory.cumsum_sq[column] = state.cumsum_sq\n",
"        memory.nancnt[column] = state.nancnt\n",
"        sharpe_out[column] = state.value"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe428781-32c6-4be7-a65a-97827372a31f",
"metadata": {},
"outputs": [],
"source": [
"class InOutputs(tp.NamedTuple):\n",
"    sharpe: tp.Array1d  # latest Sharpe ratio per column, written by post_segment_func_nb\n",
"\n",
"@njit(nogil=True)\n",
"def ctx_pipeline_nb(high, low, close, periods=np.array([7]), multipliers=np.array([3]), ann_factor=365):\n",
"    \"\"\"Same output layout as `pipeline_nb`, with signals and Sharpe computed inside callbacks.\n",
"\n",
"    Runs one `from_order_func_nb` simulation per (period, multiplier) pair: SuperTrend\n",
"    is advanced in `order_func_nb` and the expanding Sharpe ratio in\n",
"    `post_segment_func_nb`. `max_order_records=0` skips storing order records.\n",
"    Output: for each parameter pair in order, one Sharpe ratio per column of `close`.\n",
"    \"\"\"\n",
"    n_cols = close.shape[1]\n",
"    # np.float64 instead of the np.float_ alias, which was removed in NumPy 2.0.\n",
"    # The in-output buffer is reused across pairs; its values are copied out after each run.\n",
"    in_outputs = InOutputs(sharpe=np.empty(n_cols, dtype=np.float64))\n",
"    sharpe = np.empty(periods.size * n_cols, dtype=np.float64)\n",
"    group_lens = np.full(n_cols, 1)  # each column is its own group\n",
"    init_cash = 100.\n",
"    k = 0\n",
"\n",
"    for i in range(periods.size):\n",
"        # The simulation result itself is not needed: Sharpe ratios land in in_outputs\n",
"        vbt.pf_nb.from_order_func_nb(\n",
"            target_shape=close.shape,\n",
"            group_lens=group_lens,\n",
"            cash_sharing=False,\n",
"            init_cash=init_cash,\n",
"            pre_sim_func_nb=pre_sim_func_nb,\n",
"            order_func_nb=order_func_nb,\n",
"            order_args=(periods[i], multipliers[i]),\n",
"            post_segment_func_nb=post_segment_func_nb,\n",
"            post_segment_args=(ann_factor,),\n",
"            high=high,\n",
"            low=low,\n",
"            close=close,\n",
"            in_outputs=in_outputs,\n",
"            fill_pos_info=False,\n",
"            max_order_records=0\n",
"        )\n",
"        sharpe[k:k + n_cols] = in_outputs.sharpe\n",
"        k += n_cols\n",
"\n",
"    return sharpe"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3a0e59c-b063-442b-bf9c-9007cabac426",
"metadata": {},
"outputs": [],
"source": [
"ctx_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" ann_factor=ann_factor\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2876b2d3-c306-4340-80ea-1106d9953f95",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"ctx_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" ann_factor=ann_factor\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c85d1df-5094-4664-888e-90f31e0727b9",
"metadata": {},
"outputs": [],
"source": [
"chunked_ctx_pipeline_nb = nb_chunked(ctx_pipeline_nb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d4ed0b0-975e-4367-b17c-d233142dddb1",
"metadata": {},
"outputs": [],
"source": [
"chunked_ctx_pipeline_nb(\n",
" high.values, \n",
" low.values,\n",
" close.values,\n",
" periods=period_product[:4], \n",
" multipliers=multiplier_product[:4],\n",
" ann_factor=ann_factor,\n",
" _n_chunks=2,\n",
" _merge_kwargs=dict(input_columns=close.columns)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e837057-bfbf-404f-916a-fb1826cc911a",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"chunked_ctx_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" periods=period_product, \n",
" multipliers=multiplier_product,\n",
" ann_factor=ann_factor,\n",
" _merge_kwargs=dict(input_columns=close.columns)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03d97ac5-d061-4b8c-a45e-857c21e7b724",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"chunked_ctx_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" periods=period_product, \n",
" multipliers=multiplier_product,\n",
" ann_factor=ann_factor,\n",
" _execute_kwargs=dict(engine='dask'),\n",
" _merge_kwargs=dict(input_columns=close.columns)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c0d07ec5-9650-4db1-a3cd-def9a0950af2",
"metadata": {},
"source": [
"### Bonus: Own simulator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56dcbce3-1df4-4577-8a49-122c231f309f",
"metadata": {},
"outputs": [],
"source": [
"@njit(nogil=True)\n",
"def raw_pipeline_nb(high, low, close, periods=np.array([7]), multipliers=np.array([3]), ann_factor=365):\n",
"    \"\"\"Fused pipeline: SuperTrend, order execution and Sharpe ratio in a single pass.\n",
"\n",
"    No vectorbt simulator is used: every column is simulated bar by bar with\n",
"    `execute_order_nb`, and the expanding Sharpe ratio is updated on the fly. The\n",
"    output layout matches `pipeline_nb`: for each (period, multiplier) pair in order,\n",
"    one Sharpe ratio per column of `close`. If `close` has no rows, all values are NaN.\n",
"    \"\"\"\n",
"    n_cols = close.shape[1]\n",
"    # Pre-fill with NaN so the early return below never hands back uninitialized memory\n",
"    # (np.float64 instead of the np.float_ alias, which was removed in NumPy 2.0)\n",
"    out = np.full(periods.size * n_cols, np.nan, dtype=np.float64)\n",
"\n",
"    if close.shape[0] == 0:\n",
"        return out\n",
"\n",
"    for k in range(len(periods)):\n",
"        for col in range(n_cols):\n",
"            # SuperTrend accumulator state\n",
"            nobs = 0\n",
"            old_wt = 1.\n",
"            weighted_avg = np.nan\n",
"            prev_close_ = np.nan\n",
"            prev_upper = np.nan\n",
"            prev_lower = np.nan\n",
"            prev_dir_ = 1\n",
"            # Streaming Sharpe accumulator state\n",
"            cumsum = 0.\n",
"            cumsum_sq = 0.\n",
"            nancnt = 0\n",
"            # Signals detected at the previous bar, acted upon at the current bar\n",
"            was_entry = False\n",
"            was_exit = False\n",
"\n",
"            # Account state\n",
"            init_cash = 100.\n",
"            cash = init_cash\n",
"            position = 0.\n",
"            debt = 0.\n",
"            locked_cash = 0.\n",
"            free_cash = init_cash\n",
"            val_price = np.nan\n",
"            value = init_cash\n",
"            prev_value = init_cash\n",
"            return_ = 0.\n",
"\n",
"            for i in range(close.shape[0]):\n",
"                is_entry = was_entry\n",
"                is_exit = was_exit\n",
"\n",
"                st_in_state = SuperTrendAIS(\n",
"                    i=i,\n",
"                    high=high[i, col],\n",
"                    low=low[i, col],\n",
"                    close=close[i, col],\n",
"                    prev_close=prev_close_,\n",
"                    prev_upper=prev_upper,\n",
"                    prev_lower=prev_lower,\n",
"                    prev_dir_=prev_dir_,\n",
"                    nobs=nobs,\n",
"                    weighted_avg=weighted_avg,\n",
"                    old_wt=old_wt,\n",
"                    period=periods[k],\n",
"                    multiplier=multipliers[k]\n",
"                )\n",
"\n",
"                st_out_state = superfast_supertrend_acc_nb(st_in_state)\n",
"\n",
"                nobs = st_out_state.nobs\n",
"                weighted_avg = st_out_state.weighted_avg\n",
"                old_wt = st_out_state.old_wt\n",
"                prev_close_ = close[i, col]\n",
"                prev_upper = st_out_state.upper\n",
"                prev_lower = st_out_state.lower\n",
"                prev_dir_ = st_out_state.dir_\n",
"                was_entry = not np.isnan(st_out_state.long)\n",
"                was_exit = not np.isnan(st_out_state.short)\n",
"\n",
"                # A NaN size means: no order at this bar\n",
"                if is_entry and position == 0:\n",
"                    size = np.inf\n",
"                elif is_exit and position > 0:\n",
"                    size = -np.inf\n",
"                else:\n",
"                    size = np.nan\n",
"\n",
"                # Value the account at the current close before any order\n",
"                val_price = close[i, col]\n",
"                value = cash + position * val_price\n",
"                if not np.isnan(size):\n",
"                    exec_state = vbt.pf_enums.ExecState(\n",
"                        cash=cash,\n",
"                        position=position,\n",
"                        debt=debt,\n",
"                        locked_cash=locked_cash,\n",
"                        free_cash=free_cash,\n",
"                        val_price=val_price,\n",
"                        value=value\n",
"                    )\n",
"                    price_area = vbt.pf_enums.PriceArea(\n",
"                        open=np.nan,\n",
"                        high=high[i, col],\n",
"                        low=low[i, col],\n",
"                        close=close[i, col]\n",
"                    )\n",
"                    order = vbt.pf_nb.order_nb(\n",
"                        size=size,\n",
"                        direction=vbt.pf_enums.Direction.LongOnly,\n",
"                        fees=0.001\n",
"                    )\n",
"                    _, new_exec_state = vbt.pf_nb.execute_order_nb(exec_state, order, price_area)\n",
"                    cash, position, debt, locked_cash, free_cash, val_price, value = new_exec_state\n",
"\n",
"                # Bar return from the change in account value\n",
"                value = cash + position * val_price\n",
"                return_ = vbt.ret_nb.get_return_nb(prev_value, value)\n",
"                prev_value = value\n",
"\n",
"                # Expanding window: it always covers bars 0..i, so nothing leaves it\n",
"                sharpe_in_state = RollSharpeAIS(\n",
"                    i=i,\n",
"                    ret=return_,\n",
"                    pre_window_ret=np.nan,\n",
"                    cumsum=cumsum,\n",
"                    cumsum_sq=cumsum_sq,\n",
"                    nancnt=nancnt,\n",
"                    window=i + 1,\n",
"                    minp=0,\n",
"                    ddof=1,\n",
"                    ann_factor=ann_factor\n",
"                )\n",
"                sharpe_out_state = rolling_sharpe_acc_nb(sharpe_in_state)\n",
"                cumsum = sharpe_out_state.cumsum\n",
"                cumsum_sq = sharpe_out_state.cumsum_sq\n",
"                nancnt = sharpe_out_state.nancnt\n",
"                sharpe = sharpe_out_state.value\n",
"\n",
"            # Sharpe ratio over the full history of this column\n",
"            out[k * n_cols + col] = sharpe\n",
"\n",
"    return out"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99f0e6a6-68ca-4d6e-8160-a9dfa5fc6153",
"metadata": {},
"outputs": [],
"source": [
"raw_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" ann_factor=ann_factor\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e16670e-ecd3-4d94-b926-e7773eb088dd",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"raw_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" ann_factor=ann_factor\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14812b10-2e4f-4171-ba91-eecd5c329394",
"metadata": {},
"outputs": [],
"source": [
"chunked_raw_pipeline_nb = nb_chunked(raw_pipeline_nb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8339de75-6c3c-4e59-9df0-17f969e8d14c",
"metadata": {},
"outputs": [],
"source": [
"chunked_raw_pipeline_nb(\n",
" high.values, \n",
" low.values,\n",
" close.values,\n",
" periods=period_product[:4], \n",
" multipliers=multiplier_product[:4],\n",
" ann_factor=ann_factor,\n",
" _n_chunks=2,\n",
" _merge_kwargs=dict(input_columns=close.columns)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1fa1ac51-3cea-4ae8-86dc-d2b5068994b0",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"chunked_raw_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" periods=period_product, \n",
" multipliers=multiplier_product,\n",
" ann_factor=ann_factor,\n",
" _merge_kwargs=dict(input_columns=close.columns)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8115735-80ca-463e-ba6d-f0892e413b17",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"chunked_raw_pipeline_nb(\n",
" high.values, \n",
" low.values, \n",
" close.values,\n",
" periods=period_product, \n",
" multipliers=multiplier_product,\n",
" ann_factor=ann_factor,\n",
" _execute_kwargs=dict(engine=\"dask\"),\n",
" _merge_kwargs=dict(input_columns=close.columns)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "382e71b5-5cbc-46fa-b2de-ce031ad902a0",
"metadata": {},
"outputs": [],
"source": [
"range_len = int(vbt.timedelta('365d') / vbt.timedelta('1h'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dabf85c8-62ac-4a93-8903-2255e66b2d09",
"metadata": {},
"outputs": [],
"source": [
"splitter = vbt.Splitter.from_n_rolling(high.index, n=100, length=range_len)\n",
"\n",
"roll_high = splitter.take(high, into=\"reset_stacked\")\n",
"roll_low = splitter.take(low, into=\"reset_stacked\")\n",
"roll_close = splitter.take(close, into=\"reset_stacked\")\n",
"\n",
"range_indexes = splitter.take(high.index)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "722b1a84-5307-42fa-97f1-0027301e4367",
"metadata": {},
"outputs": [],
"source": [
"roll_close.columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5832339d-885a-4daa-982e-22b3c665d45c",
"metadata": {},
"outputs": [],
"source": [
"range_indexes[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5232615f-9745-48d4-8cea-3e32d19d4c77",
"metadata": {},
"outputs": [],
"source": [
"sharpe_ratios = chunked_raw_pipeline_nb(\n",
" roll_high.values, \n",
" roll_low.values,\n",
" roll_close.values,\n",
" periods=period_product, \n",
" multipliers=multiplier_product,\n",
" ann_factor=ann_factor,\n",
" _execute_kwargs=dict(engine=\"dask\"),\n",
" _merge_kwargs=dict(input_columns=roll_close.columns)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "190dd192-7fc0-48e1-8ad9-c9653c9c01ab",
"metadata": {},
"outputs": [],
"source": [
"sharpe_ratios"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "077421ca-b5d6-4e22-b53b-258c4033a8c0",
"metadata": {},
"outputs": [],
"source": [
"pf_hold = vbt.Portfolio.from_holding(roll_close, freq='1h')\n",
"sharpe_ratios_hold = pf_hold.sharpe_ratio"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "838776aa-3700-4d86-921b-d0bbabbdb47e",
"metadata": {},
"outputs": [],
"source": [
"sharpe_ratios_hold"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e8a6477-3235-4243-bb93-662d17daf28d",
"metadata": {},
"outputs": [],
"source": [
"def plot_subperiod_sharpe(index, sharpe_ratios, sharpe_ratios_hold, range_indexes, symbol):\n",
"    \"\"\"Plot one animation frame: Sharpe heatmap of all parameter pairs for one split.\n",
"\n",
"    `index` holds the split ids for this frame (presumably the slice passed by\n",
"    `vbt.save_animation`); only its first element is used. The colour scale's\n",
"    midpoint is the buy-and-hold Sharpe ratio of the same split and symbol.\n",
"    \"\"\"\n",
"    split = index[0]\n",
"    split_sharpe = (\n",
"        sharpe_ratios\n",
"        .xs(symbol, level='symbol', drop_level=True)\n",
"        .xs(split, level='split', drop_level=True)\n",
"    )\n",
"    split_dates = range_indexes[split]\n",
"    date_fmt = \"%d %b, %Y %H:%M:%S\"\n",
"    title = f\"{split_dates[0].strftime(date_fmt)} - {split_dates[-1].strftime(date_fmt)}\"\n",
"    return split_sharpe.vbt.heatmap(\n",
"        x_level='st_period',\n",
"        y_level='st_multiplier',\n",
"        title=title,\n",
"        trace_kwargs=dict(\n",
"            zmin=split_sharpe.min(),\n",
"            zmid=sharpe_ratios_hold[(split, symbol)],\n",
"            zmax=split_sharpe.max(),\n",
"            colorscale='Spectral'\n",
"        )\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5306f36f-765f-4d52-b8a8-b4c3a4ba7b6d",
"metadata": {},
"outputs": [],
"source": [
"fname = 'raw_pipeline.gif'\n",
"level_idx = sharpe_ratios.index.names.index('split')\n",
"split_indices = sharpe_ratios.index.levels[level_idx]\n",
"\n",
"vbt.save_animation(\n",
" fname,\n",
" split_indices, \n",
" plot_subperiod_sharpe,\n",
" sharpe_ratios,\n",
" sharpe_ratios_hold,\n",
" range_indexes,\n",
" 'BTCUSDT',\n",
" delta=1,\n",
" fps=7,\n",
" writer_kwargs=dict(loop=0)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd826ac3-a4da-4065-a5d0-9c77ea5ad6f0",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Image, display\n",
"\n",
"# The animation is a GIF: let IPython infer the format from the file extension\n",
"# instead of labelling the GIF bytes as PNG (wrong MIME type in the saved output)\n",
"display(Image(filename=fname))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4743b6d1-c6a1-4fa7-85f0-0c1e7366a44c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}