# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# # Performance
# ## Accumulators

# %%
@njit
def fastest_rolling_zscore_1d_nb(arr, window, minp=None, ddof=1):
    """Compute a rolling z-score of a 1-D array in one pass using vbt accumulators.

    At each position ``i`` the result is ``(arr[i] - mean) / std``. The ``mean`` is
    produced by ``vbt.nb.rolling_mean_acc_nb`` and the ``std`` by
    ``vbt.nb.rolling_std_acc_nb``. The running state (``cumsum``, ``cumsum_sq``,
    ``nancnt``) is carried from one iteration to the next instead of re-scanning
    each window.

    Args:
        arr: Input values, indexed as a 1-D sequence (``len(arr)``, ``arr[i]``).
        window: Rolling window length.
        minp: Minimum-periods value forwarded to both accumulators. Defaults to
            ``window`` when None.
        ddof: Delta degrees of freedom forwarded to the std accumulator.

    Returns:
        A float array with the same shape as ``arr``. It is pre-filled with NaN,
        and every position is then overwritten with the computed z-score. The
        accumulators presumably yield NaN where fewer than ``minp`` values are
        available; confirm against the vbt accumulator docs.
    """
    if minp is None:
        minp = window
    out = np.full(arr.shape, np.nan)
    # Running state carried across iterations; it is refreshed from the std accumulator's output
    cumsum = 0.0
    cumsum_sq = 0.0
    nancnt = 0

    for i in range(len(arr)):
        # arr[i - window] is the element just before the current window; NaN while i < window
        pre_window_value = arr[i - window] if i - window >= 0 else np.nan
        mean_in_state = vbt.nb.RollMeanAIS(
            i, arr[i], pre_window_value, cumsum, nancnt, window, minp
        )
        mean_out_state = vbt.nb.rolling_mean_acc_nb(mean_in_state)
        # Only the mean value is kept; the mean accumulator's updated state is discarded
        _, _, _, mean = mean_out_state
        std_in_state = vbt.nb.RollStdAIS(
            i, arr[i], pre_window_value, cumsum, cumsum_sq, nancnt, window, minp, ddof
        )
        std_out_state = vbt.nb.rolling_std_acc_nb(std_in_state)
        # The std accumulator's updated state feeds BOTH accumulators on the next step.
        # NOTE(review): this assumes both accumulators update cumsum/nancnt identically — confirm.
        cumsum, cumsum_sq, nancnt, _, std = std_out_state
        out[i] = (arr[i] - mean) / std
    return out

# Pull BTC-USD data and compute a 14-period rolling z-score of its returns with the jitted function
data = vbt.YFData.pull("BTC-USD")
rolling_zscore = fastest_rolling_zscore_1d_nb(data.returns.values, 14)
# Wrap the raw NumPy result back into a pandas object via the data's symbol wrapper (displayed)
data.symbol_wrapper.wrap(rolling_zscore)

# %%
# Reference computation of the same 14-period z-score using pandas rolling windows
(data.returns - data.returns.rolling(14).mean()) / data.returns.rolling(14).std()

# %% [markdown]
# ## Chunking

# %%
@vbt.chunked(
    chunk_len=100,
    merge_func="concat",
    execute_kwargs=dict(
        show_progress=True,
        clear_cache=True,
        collect_garbage=True
    )
)
def backtest(data, fast_windows, slow_windows):
    """Backtest a fast/slow moving-average crossover and return its total return.

    Long entries are generated where the fast MA crosses above the slow MA, and
    exits where it crosses below. Under ``@vbt.chunked`` the window arguments are
    presumably split into chunks of ``chunk_len=100``, and the per-chunk results
    are concatenated (``merge_func="concat"``).

    Args:
        data: vbt data object; only its ``close`` prices are used.
        fast_windows: Window(s) for the fast moving average.
        slow_windows: Window(s) for the slow moving average.

    Returns:
        ``total_return`` of the simulated portfolio.
    """
    close = data.close
    fast = vbt.MA.run(close, fast_windows, short_name="fast")
    slow = vbt.MA.run(close, slow_windows, short_name="slow")
    # Signals are evaluated left to right: entries first, then exits
    portfolio = vbt.PF.from_signals(
        close,
        fast.ma_crossed_above(slow),
        fast.ma_crossed_below(slow),
    )
    return portfolio.total_return

# Every (fast_window, slow_window) pair from 2..99 that satisfies fast_window < slow_window.
# build_index=False presumably skips building a pandas index for the combinations.
param_product = vbt.combine_params(
    dict(
        fast_window=vbt.Param(range(2, 100), condition="fast_window < slow_window"),
        slow_window=vbt.Param(range(2, 100)),
    ),
    build_index=False
)
# vbt.Chunked marks the window lists as the arguments that @vbt.chunked splits into chunks
backtest(
    vbt.YFData.pull(["BTC-USD", "ETH-USD"]),
    vbt.Chunked(param_product["fast_window"]),
    vbt.Chunked(param_product["slow_window"])
)

# %% [markdown]
# ## Parallel Numba

# %%
# 1000 x 1000 frame of uniform random values used as the benchmark input
df = pd.DataFrame(np.random.uniform(size=(1000, 1000)))

# Baseline: pandas rolling mean over a 10-row window
%timeit df.rolling(10).mean()

# %%
# Same rolling mean through the vectorbt accessor
%timeit df.vbt.rolling_mean(10)

# %%
# Same again, passing jitted=dict(parallel=True) to request a parallel Numba build
%timeit df.vbt.rolling_mean(10, jitted=dict(parallel=True))

# %% [markdown]
# ## Multithreading

# %%
data = vbt.YFData.pull(["BTC-USD", "ETH-USD"])

# Random-signal portfolios over a list of 1000 values of n=100 (single-threaded baseline)
%timeit vbt.PF.from_random_signals(data.close, n=[100] * 1000)

# %%
# Same workload with chunked="threadpool" to spread chunks across threads
%timeit vbt.PF.from_random_signals(data.close, n=[100] * 1000, chunked="threadpool")

# %% [markdown]
# ## Multiprocessing

# %%
@vbt.chunked(
    size=vbt.ArraySizer(arg_query="items", axis=1),
    arg_take_spec=dict(
        items=vbt.ArraySelector(axis=1)
    ),
    merge_func=np.column_stack
)
def bubble_sort(items):
    """Return an ascending copy of ``items`` sorted with bubble sort.

    The input is never modified; all swaps happen on a copy. Under
    ``@vbt.chunked``, each chunk is presumably a single column selected along
    axis 1. The sorted columns are merged back with ``np.column_stack``.

    NOTE(review): this assumes each chunk is 1-D. Tuple-swapping rows of a 2-D
    NumPy array goes through views and does not perform a true swap.
    """
    arr = items.copy()
    size = len(arr)
    # `last` is the final index still unsorted; after each pass the largest remaining value sits there
    for last in range(size - 1, -1, -1):
        for k in range(last):
            if arr[k] > arr[k + 1]:
                arr[k], arr[k + 1] = arr[k + 1], arr[k]
    return arr

# 1000 x 3 random input; the decorator's axis=1 sizer presumably yields one chunk per column
items = np.random.uniform(size=(1000, 3))

# Baseline: chunks executed with the default engine
%timeit bubble_sort(items)

# %%
# Same call with the chunks executed through the "pathos" engine
%timeit bubble_sort(items, _execute_kwargs=dict(engine="pathos"))

# %% [markdown]
# ## Jitting

# %%
data = vbt.YFData.pull("BTC-USD", start="7 days ago")
# Log returns: log(1 + simple percentage change) of the close price
log_returns = np.log1p(data.close.pct_change())
# Cumulative sum with the accessor's default (jitted) implementation
log_returns.vbt.cumsum()

# %%
# Same cumulative sum with jitting disabled
log_returns.vbt.cumsum(jitted=False)

# %%
@vbt.register_jitted(task_id_or_func=vbt.nb.nancumsum_nb)
def nancumsum_np(arr):
    """NumPy implementation registered for the same task as ``vbt.nb.nancumsum_nb``.

    Computes the cumulative sum down axis 0, with NaN values treated as zero.
    The ``_np`` name suffix is kept as-is; vbt presumably uses it to pick the
    NumPy jitter.
    """
    running_total = np.nancumsum(arr, axis=0)
    return running_total

log_returns.vbt.cumsum(jitted="np")

# %% [markdown]
# ## Caching

# %%
data = vbt.YFData.pull("BTC-USD")
pf = vbt.PF.from_random_signals(data.close, n=5)
# Result is discarded; calling stats() presumably populates the portfolio's caches
_ = pf.stats()

# Cache hits, misses and total size for every setup that has caching enabled
pf.get_ca_setup().get_status_overview(
    filter_func=lambda setup: setup.caching_enabled,
    include=["hits", "misses", "total_size"]
)

# %% [markdown]
# ## Hyperfast rolling metrics

# %%
import quantstats as qs

# 100,000 one-minute timestamps of synthetic normally distributed returns (mean 0, std 0.001)
index = pd.date_range("2020", periods=100000, freq="1min")
returns = pd.Series(np.random.normal(0, 0.001, size=len(index)), index=index)

# Baseline: quantstats rolling Sortino over a 10-period window.
# NOTE(review): timing comparison only; the two libraries may annualize differently — confirm before comparing values.
%timeit qs.stats.rolling_sortino(returns, rolling_period=10)

# %%
# vectorbt rolling Sortino ratio over the same 10-period window
%timeit returns.vbt.returns.rolling_sortino_ratio(window=10)

# %%