# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Building blocks
# ## Utilities

# %% [markdown]
# ### Formatting

# %%
from vectorbtpro import *

# Nested sample dict used to demonstrate vbt.prettify.
dct = {
    'planet': {
        'has': {
            'plants': 'yes',
            'animals': 'yes',
            'cryptonite': 'no',
        },
        'name': 'Earth',
    },
}
print(vbt.prettify(dct))

# %%
# Round-trip check: evaluate the prettified text as a Python expression and
# compare the result with the original dict.
# NOTE(review): eval() is tolerable here only because the text is generated
# locally from `dct`; never eval text that comes from an untrusted source.
eval(vbt.prettify(dct)) == dct

# %% [markdown]
# ### Pickling
# ### Configuring

# %%
# Print the `field_config` attribute of vbt.Records.
print(vbt.Records.field_config)

# %%
# Build a config with a single key; `options_` asks for readonly=True.
cfg = vbt.Config(
    hello="world",
    options_=dict(readonly=True)
)
print(cfg)

# %%
# NOTE(review): presumably raises, because `cfg` was created with
# readonly=True above — this cell looks like a deliberate demonstration of
# that; confirm against the vbt.Config documentation.
cfg["change"] = "something"

# %%
class CorrStats(vbt.Configured):
    """Correlation statistics between two pandas objects.

    Both objects are handed to ``vbt.Configured`` as keyword arguments and are
    additionally kept on the instance behind getter-only properties.
    """

    def __init__(self, obj1, obj2):
        """Forward ``obj1`` and ``obj2`` to the base initializer and store them."""
        # The base initializer is called explicitly with the same keyword pair.
        vbt.Configured.__init__(self, obj1=obj1, obj2=obj2)

        self._obj1, self._obj2 = obj1, obj2

    @property
    def obj1(self):
        """First object (no setter is defined)."""
        return self._obj1

    @property
    def obj2(self):
        """Second object (no setter is defined)."""
        return self._obj2

    def corr(self):
        """Correlate ``obj1`` with ``obj2``.

        ``Series.corr`` is used when ``obj1`` is a Series, ``corrwith`` otherwise.
        """
        first, second = self.obj1, self.obj2
        if not isinstance(first, pd.Series):
            return first.corrwith(second)
        return first.corr(second)

    def rolling_corr(self, window):
        """Correlate ``obj1`` with ``obj2`` over the rolling window ``window``."""
        rolling_first = self.obj1.rolling(window)
        return rolling_first.corr(self.obj2)

# %%
# Five consecutive daily timestamps starting on 2020-01-01.
index = [datetime(2020, 1, day) for day in range(1, 6)]

# Two small frames with the same columns and the same index.
df1 = pd.DataFrame(
    dict(
        a=[1, 5, 2, 4, 3],
        b=[3, 2, 4, 1, 5],
    ),
    index=index,
)
df2 = pd.DataFrame(
    dict(
        a=[1, 2, 3, 4, 5],
        b=[5, 4, 3, 2, 1],
    ),
    index=index,
)

# Wrap both frames and print the config held by the instance.
corrstats = CorrStats(df1, df2)
print(corrstats.config)

# %%
df3 = pd.DataFrame({
    'a': [3, 2, 1, 5, 4],
    'b': [4, 5, 1, 2, 3]
}, index=index)
# `obj1` is exposed through a property that defines no setter, so this plain
# attribute assignment is expected to fail.
corrstats.obj1 = df3

# %%
# NOTE(review): writes into the config directly — presumably rejected because
# the config of a vbt.Configured instance is read-only; confirm.
corrstats.config['obj1'] = df3

# %%
# Rebinding the private attribute bypasses the property; the `.iloc[:] = ...`
# assignment then overwrites the values of the stored frame in place.
# NOTE(review): neither statement goes through `corrstats.config`, so the
# config and the attribute can presumably drift apart this way; confirm.
corrstats._obj1 = df3
corrstats.obj1.iloc[:] = df3

# %%
# Build a new instance via `replace`, passing df3 as obj1; the bare expression
# displays the new instance's obj1.
new_corrstats = corrstats.replace(obj1=df3)
new_corrstats.obj1

# %%
# obj2 was not passed to `replace` above; display what the new instance holds.
new_corrstats.obj2

# %%
# Save the instance under the name 'corrstats' and load it back.
# NOTE(review): presumably writes a file into the working directory — confirm
# the path/extension and clean up if needed.
corrstats.save('corrstats')

corrstats = CorrStats.load('corrstats')

# %% [markdown]
# ### Attribute resolution

# %%
sr = pd.Series([1, 2, 3, 4, 5])
# Chain specification: a (name, args) tuple followed by two plain names.
# NOTE(review): presumably resolved as sr.rolling(3).mean().min() — confirm
# against the vbt.deep_getattr documentation.
attr_chain = [('rolling', (3,)), 'mean', 'min']
vbt.deep_getattr(sr, attr_chain)

# %% [markdown]
# ### Templating

# %%
def some_function(*args, **kwargs):
    """Substitute templates in the arguments twice and print each outcome.

    The first pass runs against an empty context with ``strict=False``; the
    second pass runs on the output of the first, after ``'result'`` has been
    added to the context, and does not pass ``strict``.
    """
    context = dict()

    # Pass 1: empty context, non-strict.
    lenient_args = vbt.substitute_templates(args, context=context, strict=False)
    lenient_kwargs = vbt.substitute_templates(kwargs, context=context, strict=False)
    print(lenient_args, lenient_kwargs, sep="\n")

    # Pass 2: same context object, now holding 'result'.
    context.update(result=100)
    final_args = vbt.substitute_templates(lenient_args, context=context)
    final_kwargs = vbt.substitute_templates(lenient_kwargs, context=context)
    print(final_args, final_kwargs, sep="\n")


# One positional template and one keyword template, both referring to 'result'.
some_function(vbt.Rep('result'), double_result=vbt.RepEval('result * 2'))

# %% [markdown]
# ## Base

# %% [markdown]
# ### Grouping

# %%
# Two-level column index: the 'symbol' level and the 'group' it belongs to.
columns = pd.MultiIndex.from_arrays(
    [
        ['BTC-USD', 'ETH-USD', 'ADA-USD', 'SOL-USD'],
        ['group1', 'group1', 'group2', 'group2'],
    ],
    names=['symbol', 'group'],
)
# Build a grouper over `columns`, passing 'group' as the second argument, and
# display the result of get_groups().
vbt.Grouper(columns, 'group').get_groups()

# %% [markdown]
# ### Indexing

# %%
class CorrStats(vbt.Configured, vbt.PandasIndexer):
    """Correlation statistics between two pandas objects, with pandas indexing.

    On top of the config handling from ``vbt.Configured``, the class defines
    ``indexing_func``, which applies one indexing operation to both objects.
    """

    def __init__(self, obj1, obj2):
        """Initialize both base classes and store the two objects."""
        # Each base initializer is called explicitly, one after the other.
        vbt.Configured.__init__(self, obj1=obj1, obj2=obj2)
        vbt.PandasIndexer.__init__(self)

        self._obj1, self._obj2 = obj1, obj2

    def indexing_func(self, pd_indexing_func):
        """Return a copy in which ``pd_indexing_func`` was applied to both objects."""
        indexed_first = pd_indexing_func(self.obj1)
        indexed_second = pd_indexing_func(self.obj2)
        return self.replace(obj1=indexed_first, obj2=indexed_second)

    @property
    def obj1(self):
        """First object (no setter is defined)."""
        return self._obj1

    @property
    def obj2(self):
        """Second object (no setter is defined)."""
        return self._obj2

    def corr(self):
        """Correlate ``obj1`` with ``obj2``.

        ``Series.corr`` is used when ``obj1`` is a Series, ``corrwith`` otherwise.
        """
        if isinstance(self.obj1, pd.Series):
            result = self.obj1.corr(self.obj2)
        else:
            result = self.obj1.corrwith(self.obj2)
        return result

    def rolling_corr(self, window):
        """Correlate ``obj1`` with ``obj2`` over the rolling window ``window``."""
        windowed = self.obj1.rolling(window)
        return windowed.corr(self.obj2)

# Correlation over the full frames.
corrstats = CorrStats(df1, df2)
corrstats.corr()

# %%
# Index the instance like a pandas object (a label slice of rows plus column
# 'a') and correlate the selection; the `.loc` call ends up in
# `indexing_func`, which indexes obj1 and obj2 alike.
corrstats.loc['2020-01-01':'2020-01-03', 'a'].corr()

# %% [markdown]
# ### Wrapping

# %%
# Four integer columns of five rows each, on the shared `index`.
df = pd.DataFrame(
    dict(
        a=range(5),
        b=range(5, 10),
        c=range(10, 15),
        d=range(15, 20),
    ),
    index=index,
)
# Build an array wrapper from the frame and print it.
wrapper = vbt.ArrayWrapper.from_obj(df)
print(wrapper)

# %%
def sum_per_column(df):
    """Sum every column of ``df`` on its NumPy values and wrap the result.

    The sums are taken along axis 0 of ``df.values`` and passed to
    ``wrap_reduced`` of an array wrapper built from ``df``.
    """
    column_totals = df.values.sum(axis=0)
    return vbt.ArrayWrapper.from_obj(df).wrap_reduced(column_totals)


sum_per_column(df)

# %%
# 1000 x 1000 frame of uniform random floats used as the benchmark input.
big_df = pd.DataFrame(np.random.uniform(size=(1000, 1000)))

# IPython magic (not plain Python): time the pandas column sum.
%timeit big_df.sum()

# %%
# IPython magic: time the wrapper-based column sum on the same frame.
%timeit sum_per_column(big_df)

# %%
def sum_per_group(df, group_by):
    """Sum all values of ``df`` within each column group and wrap the result.

    A wrapper is built from ``df`` with the given ``group_by``; for every set
    of column indices yielded by its grouper, the matching columns of
    ``df.values`` are summed into one number.
    """
    wrapper = vbt.ArrayWrapper.from_obj(df, group_by=group_by)
    group_totals = [
        np.sum(df.values[:, idxs])
        for idxs in wrapper.grouper.yield_group_idxs()
    ]
    return wrapper.wrap_reduced(group_totals)


# NOTE(review): group_by=False presumably means "no grouping" — confirm.
sum_per_group(df, False)

# %%
# NOTE(review): group_by=True presumably puts all columns into one group —
# confirm against the vbt grouping documentation.
sum_per_group(df, True)

# %%
# One group label per column of `df` (four labels for columns a-d).
group_by = pd.Index(['group1', 'group1', 'group2', 'group2'])
sum_per_group(df, group_by)

# %%
class CorrStats(vbt.Wrapping):
    """Correlation statistics over two arrays that share one array wrapper.

    The inputs are stored after passing through ``vbt.to_2d_array``; results
    computed by the ``vbt.nb`` functions are converted back with the wrapper.
    """

    # Keys accepted by this class in addition to those of vbt.Wrapping.
    _expected_keys = vbt.Wrapping._expected_keys | {"obj1", "obj2"}

    @classmethod
    def from_objs(cls, obj1, obj2):
        """Broadcast ``obj1`` and ``obj2`` and build an instance from the result."""
        broadcasted, wrapper = vbt.broadcast(
            obj1,
            obj2,
            to_pd=False,
            return_wrapper=True,
        )
        first, second = broadcasted
        return cls(wrapper, first, second)

    def __init__(self, wrapper, obj1, obj2):
        """Initialize the base class and store both inputs via ``vbt.to_2d_array``."""
        vbt.Wrapping.__init__(self, wrapper, obj1=obj1, obj2=obj2)

        self._obj1 = vbt.to_2d_array(obj1)
        self._obj2 = vbt.to_2d_array(obj2)

    def indexing_func(self, pd_indexing_func, **kwargs):
        """Return a copy with the wrapper and both arrays indexed consistently."""
        meta = self.wrapper.indexing_func_meta(pd_indexing_func, **kwargs)
        rows, cols = meta["row_idxs"], meta["col_idxs"]

        def _select(arr):
            # Rows first, then columns, as two separate indexing steps.
            return arr[rows, :][:, cols]

        return self.replace(
            wrapper=meta["new_wrapper"],
            obj1=_select(self.obj1),
            obj2=_select(self.obj2),
        )

    @property
    def obj1(self):
        """First stored array (no setter is defined)."""
        return self._obj1

    @property
    def obj2(self):
        """Second stored array (no setter is defined)."""
        return self._obj2

    def corr(self):
        """Run ``vbt.nb.nancorr_nb`` on both arrays and wrap the reduced result."""
        coefficients = vbt.nb.nancorr_nb(self.obj1, self.obj2)
        return self.wrapper.wrap_reduced(coefficients)

    def rolling_corr(self, window):
        """Run ``vbt.nb.rolling_corr_nb`` on both arrays and wrap the result.

        ``minp`` is set to ``window``.
        """
        rolled = vbt.nb.rolling_corr_nb(self.obj1, self.obj2, window, minp=window)
        return self.wrapper.wrap(rolled)

# %%
# Reference result computed by pandas.
df1.corrwith(df2)

# %%
# The same computation through the wrapper-based class.
corrstats = CorrStats.from_objs(df1, df2)
corrstats.corr()

# %%
# Concatenate df2 with a shuffled version of itself (fixed seed) under the
# keys 'plain' and 'shuffled'.
df2_sh = vbt.pd_acc.concat(
    df2, df2.vbt.shuffle(seed=42),
    keys=['plain', 'shuffled'])
df2_sh

# %%
# NOTE(review): df2_sh presumably has an extra column level from `keys`, so
# pandas may not align it with df1's columns — shown for comparison; confirm.
df1.corrwith(df2_sh)

# %%
# from_objs broadcasts both inputs before correlating.
corrstats = CorrStats.from_objs(df1, df2_sh)
corrstats.corr()

# %%
# Two 1000 x 1000 frames of uniform random floats used as benchmark inputs.
big_df1 = pd.DataFrame(np.random.uniform(size=(1000, 1000)))
big_df2 = pd.DataFrame(np.random.uniform(size=(1000, 1000)))

# IPython magic (not plain Python): time the pandas rolling correlation.
%timeit big_df1.rolling(10).corr(big_df2)

# %%
corrstats = CorrStats.from_objs(big_df1, big_df2)
# IPython magic: time the rolling correlation of the wrapper-based class.
%timeit corrstats.rolling_corr(10)

# %%
# Select a label slice of rows, then compute a rolling correlation with a
# window of 3 on the selection.
corrstats = CorrStats.from_objs(df1, df2_sh)
corrstats.loc['2020-01-02':'2020-01-05'].rolling_corr(3)

# %% [markdown]
# ### Base accessor

# %%
# Display the `vbt` accessor object attached to the frame.
df.vbt

# %%
# Call to_2d_array() through the accessor.
df.vbt.to_2d_array()

# %%
# Combine a 3-element Series with a 1 x 3 array using np.add.
pd.Series([1, 2, 3]).vbt.combine(np.array([[4, 5, 6]]), np.add)

# %%
# The same operands added with plain pandas, for comparison.
# NOTE(review): pandas may not broadcast a 3-element Series against a 1 x 3
# array the way the accessor does — confirm the expected outcome of this cell.
pd.Series([1, 2, 3]) + np.array([[4, 5, 6]])

# %%
# The same addition applied to the accessor instead of the Series itself.
pd.Series([1, 2, 3]).vbt + np.array([[4, 5, 6]])

# %% [markdown]
# ## Generic

# %% [markdown]
# ### Builder mixins
# ### Analyzing

# %%
class CorrStats(vbt.Analyzable):
    """Correlation statistics over two arrays, with one metric and one subplot.

    The inputs are stored after passing through ``vbt.to_2d_array``; results
    computed by the ``vbt.nb`` functions are converted back with the wrapper.
    ``_metrics`` and ``_subplots`` declare a ``corr`` entry and a
    ``rolling_corr`` entry respectively.
    """

    # Keys accepted by this class in addition to those of vbt.Analyzable.
    _expected_keys = vbt.Analyzable._expected_keys | {"obj1", "obj2"}

    @classmethod
    def from_objs(cls, obj1, obj2):
        """Broadcast ``obj1`` and ``obj2`` and build an instance from the result."""
        arrays, wrapper = vbt.broadcast(
            obj1,
            obj2,
            to_pd=False,
            return_wrapper=True,
        )
        left, right = arrays
        return cls(wrapper, left, right)

    def __init__(self, wrapper, obj1, obj2):
        """Initialize the base class and store both inputs via ``vbt.to_2d_array``."""
        vbt.Analyzable.__init__(self, wrapper, obj1=obj1, obj2=obj2)

        self._obj1 = vbt.to_2d_array(obj1)
        self._obj2 = vbt.to_2d_array(obj2)

    def indexing_func(self, pd_indexing_func, **kwargs):
        """Return a copy with the wrapper and both arrays indexed consistently."""
        wrapper_meta = self.wrapper.indexing_func_meta(pd_indexing_func, **kwargs)
        row_idxs = wrapper_meta["row_idxs"]
        col_idxs = wrapper_meta["col_idxs"]
        # Rows first, then columns, as two separate indexing steps per array.
        indexed_obj1 = self.obj1[row_idxs, :][:, col_idxs]
        indexed_obj2 = self.obj2[row_idxs, :][:, col_idxs]
        return self.replace(
            wrapper=wrapper_meta["new_wrapper"],
            obj1=indexed_obj1,
            obj2=indexed_obj2,
        )

    @property
    def obj1(self):
        """First stored array (no setter is defined)."""
        return self._obj1

    @property
    def obj2(self):
        """Second stored array (no setter is defined)."""
        return self._obj2

    def corr(self):
        """Run ``vbt.nb.nancorr_nb`` on both arrays and wrap the reduced result."""
        coefficients = vbt.nb.nancorr_nb(self.obj1, self.obj2)
        return self.wrapper.wrap_reduced(coefficients)

    def rolling_corr(self, window):
        """Run ``vbt.nb.rolling_corr_nb`` on both arrays and wrap the result.

        ``minp`` is set to ``window``.
        """
        rolled = vbt.nb.rolling_corr_nb(self.obj1, self.obj2, window, minp=window)
        return self.wrapper.wrap(rolled)

    # Metric declaration: `calc_func` names the `corr` method defined above.
    _metrics = vbt.HybridConfig(
        corr=dict(
            title='Corr. Coefficient',
            calc_func='corr',
        ),
    )

    # Subplot declaration: title and plot function are templates that contain
    # a `$window` placeholder.
    _subplots = vbt.HybridConfig(
        rolling_corr=dict(
            title=vbt.Sub("Rolling Corr. Coefficient (window=$window)"),
            plot_func=vbt.Sub('rolling_corr($window).vbt.plot'),
            pass_trace_names=False,
        ),
    )

# %%
# Stats for column 'a', selected through the `column` argument.
corrstats = CorrStats.from_objs(df1, df2)
corrstats.stats(column='a')

# %%
# The same column selected by indexing the instance first.
corrstats['a'].stats()

# %%
# `window` is supplied through template_context; the subplot templates in
# CorrStats._subplots refer to it as `$window`.
corrstats.plots(template_context=dict(window=3)).show()

# %% [markdown]
# ### Generic accessor

# %%
# Call stats() through the frame's `vbt` accessor for column 'a'.
df.vbt.stats(column='a')

# %% [markdown]
# ## Records

# %%
# Two short series over seven consecutive days starting on 2020-01-01.
dd_df = pd.DataFrame(
    dict(
        a=[10, 11, 12, 11, 12, 13, 12],
        b=[14, 13, 12, 11, 12, 13, 14],
    ),
    index=[datetime(2020, 1, day) for day in range(1, 8)],
)
# Take the `drawdowns` attribute of the accessor and display its
# `records_readable` view.
drawdowns = dd_df.vbt.drawdowns
drawdowns.records_readable

# %%
# The same view after selecting column 'b'.
drawdowns['b'].records_readable

# %%
# Values of the `status` field as stored.
drawdowns.status.values

# %%
# NOTE(review): presumably returns the `status` values with the field's
# mapping applied — confirm against the vbt.Records documentation.
drawdowns.get_apply_mapping_arr('status')

# %% [markdown]
# ### Column mapper

# %%
# Display the `col_map` attribute of the records' column mapper.
drawdowns.col_mapper.col_map

# %% [markdown]
# ### Mapped arrays

# %%
# Take the `drawdown` attribute of the drawdown records and display it.
dd_ma = drawdowns.drawdown
dd_ma

# %%
# Its `values` attribute.
dd_ma.values

# %%
# Stats for column 'a'.
dd_ma.stats(column='a')

# %%
# Conversion to a pandas object.
dd_ma.to_pd()

# %%
# Values after selecting column 'b'.
dd_ma['b'].values

# %%