diff --git a/lightweight_charts/abstract.py b/lightweight_charts/abstract.py index 87d2f21..bf67401 100644 --- a/lightweight_charts/abstract.py +++ b/lightweight_charts/abstract.py @@ -245,9 +245,12 @@ class SeriesCommon(Pane): if format_cols: df = self._df_datetime_format(df, exclude_lowercase=self.name) if self.name: - if self.name not in df: + if self.name and len(df.columns) == 1: #if only one col rename it + df.columns = ['value'] + elif self.name not in df: raise NameError(f'No column named "{self.name}".') - df = df.rename(columns={self.name: 'value'}) + else: + df = df.rename(columns={self.name: 'value'}) self.data = df.copy() self._last_bar = df.iloc[-1] self.run_script(f'{self.id}.series.setData({js_data(df)}); ') diff --git a/lightweight_charts/helpers.py b/lightweight_charts/helpers.py index ee08dec..130c632 100644 --- a/lightweight_charts/helpers.py +++ b/lightweight_charts/helpers.py @@ -1,16 +1,46 @@ +from ast import parse from .widgets import JupyterChart from .util import ( is_vbt_indicator, get_next_color ) import pandas as pd +#default settings for each pricescale +ohlcv_cols = ['close', 'volume', 'open', 'high', 'low'] +right_cols = ['vwap'] +left_cols = ['rsi', 'cci', 'macd', 'macdsignal'] +middle1_cols = ["mom"] +middle2_cols = ["updated"] +histogram_cols = ['buyvolume', 'sellvolume', 'trades', 'macdhist'] + +def append_scales(df, right, histogram, left, middle1, middle2, name = ""): + if isinstance(df, pd.DataFrame): + for col in df.columns: + match col: + case c if c.lower() in ohlcv_cols: + continue + case c if c.lower() in right_cols: + right.append((df[c],name+c,)) + case c if c.lower() in histogram_cols: + histogram.append((df[c],name+c,)) + case c if c.lower() in left_cols: + left.append((df[c],name+c,)) + case c if c.lower() in middle1_cols: + middle1.append((df[c],name+c,)) + case c if c.lower() in middle2_cols: + middle2.append((df[c],name+c,)) + case _: + right.append((df[c],name+c,)) + else: #it is series (as df multiindex can be just envelope for series) + right.append((df,str(df.name),)) + def append_or_extend(target_list, value): if isinstance(value, list): target_list.extend(value) # Extend if it's a list else: target_list.append(value) # Append if it's a single value -def extend_kwargs(ohlcv, right, left, middle1, middle2, histogram, kwargs): +def extend_kwargs(ohlcv, right, left, middle1, middle2, histogram, auto_scale, kwargs): """ Mutate lists based on kwargs for accessor. Used when user added additional series to kwargs when using accessor. @@ -26,7 +56,9 @@ def extend_kwargs(ohlcv, right, left, middle1, middle2, histogram, kwargs): if 'middle1' in kwargs: append_or_extend(middle1, kwargs['middle1']) if 'middle2' in kwargs: - append_or_extend(middle1, kwargs['middle2']) + append_or_extend(middle2, kwargs['middle2']) + if 'auto_scale' in kwargs: + append_or_extend(auto_scale, kwargs['auto_scale']) return ohlcv #as tuple is immutable @@ -67,14 +99,16 @@ class PlotSRAccessor: middle1 = [] middle2 = [] histogram = [] + auto_scale = [] #if there are additional series in kwargs add them too #ohlcv is returned as it is tuple thus immutable - ohlcv = extend_kwargs(ohlcv, right, left, middle1, middle2, histogram, kwargs) + ohlcv = extend_kwargs(ohlcv, right, left, middle1, middle2, histogram, auto_scale, kwargs) right.append((self._obj,name)) pane1 = Panel( + auto_scale=auto_scale, ohlcv=ohlcv, histogram=histogram, right=right, @@ -125,39 +159,36 @@ class PlotDFAccessor: if "size" not in kwargs: kwargs["size"] = "xs" - #default settings for each pricescale - ohlcv_cols = ['close', 'volume', 'open', 'high', 'low'] - right_cols = ['vwap'] - left_cols = ['rsi'] - middle1_cols = [] - middle2_cols = [] - histogram_cols = ['buyvolume', 'sellvolume', 'trades'] - ohlcv = () right = [] left = [] middle1 = [] middle2 = [] histogram = [] + auto_scale = [] + + if isinstance(self._obj.columns, pd.MultiIndex): + for col_tuple in self._obj.columns: + # Access the data for each column tuple dynamically + df = self._obj.loc[:, col_tuple] + name = str(col_tuple)+" " + append_scales(df, right, histogram, left, middle1, middle2, name) + first_column_df = self._obj.loc[:, self._obj.columns[0]] + ohlcv = (first_column_df[ohlcv_cols],) if isinstance(first_column_df, pd.DataFrame) and first_column_df.columns in ohlcv else () #in case of multiindex only the first ohlcv is display only one ohlcv is allowed on the pane - for col in self._obj.columns: - if col in right_cols: - right.append((self._obj[col],col,)) - if col in histogram_cols: - histogram.append((self._obj[col],col,)) - if col in left_cols: - left.append((self._obj[col],col,)) - if col in middle1_cols: - middle1_cols.append((self._obj[col],col,)) - if col in middle2_cols: - middle2_cols.append((self._obj[col],col,)) - - ohlcv = (self._obj[ohlcv_cols],) - + else: + append_scales(self._obj, right, histogram, left, middle1, middle2) + #add ohlcv if all columns ohlcv_cols + #column mapping enables either both lowercase and first upper + column_mapping = {key: next((col for col in self._obj.columns if col.lower() == key), None) for key in ohlcv_cols} + mapped_columns = [column_mapping[key] for key in ohlcv_cols if column_mapping[key] is not None] + ohlcv = (self._obj[mapped_columns],) if isinstance(self._obj, pd.DataFrame) and all(col in self._obj.columns.str.lower() for col in ohlcv_cols) else () + #if there are additional series in kwargs add them too - ohlcv = extend_kwargs(ohlcv, right, left, middle1, middle2, histogram, kwargs) + ohlcv = extend_kwargs(ohlcv, right, left, middle1, middle2, histogram, auto_scale, kwargs) pane1 = Panel( + auto_scale=auto_scale, ohlcv=ohlcv, histogram=histogram, right=right, @@ -196,6 +227,7 @@ class Panel: * left : list of tuples, optional * middle1 : list of tuples, optional * middle2 : list of tuples, optional + * auto_scale: list of objects, optional - external objects (vbt indicators) that can be automatically parsed to given scaleID * xloc : str or slice, optional. Vectorbt indexing. Default is None. * precision: int, optional. The number of digits after the decimal point. Apply to all lines on this pane. Default is None. @@ -227,6 +259,7 @@ class Panel: ) pane2 = Panel( + auto_scale=[macd_vbt_ind], ohlcv=(t1data.data["BAC"],), right=[], left=[(sma, "sma_below", short_signals, short_exits)], @@ -258,7 +291,8 @@ class Panel: ch = chart([pane1], title="Chart with EntryShort/ExitShort (yellow) and EntryLong/ExitLong markers (pink)", sync=True, session=None, size="s") ``` """ - def __init__(self, ohlcv=None, right=None, left=None, middle1=None, middle2=None, histogram=None, title=None, xloc=None, precision=None): + def __init__(self, auto_scale=[],ohlcv=None, right=None, left=None, middle1=None, middle2=None, histogram=None, title=None, xloc=None, precision=None): + self.auto_scale = auto_scale self.ohlcv = ohlcv if ohlcv is not None else () self.right = right if right is not None else [] self.left = left if left is not None else [] @@ -382,6 +416,78 @@ def chart(panes: list[Panel], sync=False, title='', size="m", xloc=None, session active_chart.markers_set(markers=xloc_me(markers, xloc), type=type, color=color if color is not None else None) + def add_to_scale(series, right, histogram, left, middle1, middle2, column,name = None): + """ + Assigns a series to a scaleId based on its name and pre-defined col names. + + Args: + ----- + series (pd.Series): The series to be added to a scaleId + right (list): The right scale to add to + histogram (list): The histogram scale to add to + left (list): The left scale to add to + middle1 (list): The middle1 scale to add to + middle2 (list): The middle2 scale to add to + name (str): The name of the series + + Returns: + ------- + None + + Notes: + ----- + The function checks if the series name is in the pre-defined column names + (e.g. ohlcv_cols, right_cols, histogram_cols, etc.) and assigns the series to + the corresponding scaleId. If the name is not found in any of the pre-defined + column names, the series is added to the right scale by default. + """ + if name is None: + name = column + if column.lower() in ohlcv_cols: + return + elif column.lower() in right_cols: + right.append((series, name,)) + elif column.lower() in histogram_cols: + histogram.append((series, name)) + elif column.lower() in left_cols: + left.append((series, name)) + elif column.lower() in middle1_cols: + middle1.append((series, name)) + elif column.lower() in middle2_cols: + middle2.append((series, name)) + else: + right.append((series, name,)) + + # automatic scale assignment + if len(pane.auto_scale) > 0: + for obj in pane.auto_scale: + if is_vbt_indicator(obj): #for vbt indicators + for output in obj.output_names: + output_series = getattr(obj, output) + output_name = obj.short_name + ':' + output + output = obj.short_name if output == "real" else output + #if output_series is multiindex - add each combination to respective scaleId + if isinstance(output_series, pd.DataFrame) and isinstance(output_series.columns, pd.MultiIndex): + for col_tuple in output_series.columns: + name=output_name + " " + str(col_tuple) + series_copy = output_series.loc[:, col_tuple].copy(deep=True) + add_to_scale(series_copy, pane.right, pane.histogram, pane.left, pane.middle1, pane.middle2, output, name) + elif isinstance(output_series, pd.DataFrame) and len(output_series.columns) > 1: #in case of multicolumns + for col in output_series.columns: + name=output_name + " " + col + series_copy = output_series.loc[:, col].copy(deep=True) + add_to_scale(series_copy, pane.right, pane.histogram, pane.left, pane.middle1, pane.middle2, output, name) + elif isinstance(output_series, pd.DataFrame) and len(output_series.columns) == 1: + name=output_name + " " + output_series.columns[0] + series_copy = output_series.squeeze() + add_to_scale(series_copy, pane.right, pane.histogram, pane.left, pane.middle1, pane.middle2, output, name) + else: #add output to respective scale + series_copy = output_series.copy(deep=True) + add_to_scale(series_copy, pane.right, pane.histogram, pane.left, pane.middle1, pane.middle2, output, output_name) + + # zde jsem skoncil + #vbt ind + if pane.ohlcv != (): series, entries, exits, markers = (pane.ohlcv + (None,) * 4)[:4] if series is None: @@ -404,8 +510,14 @@ def chart(panes: list[Panel], sync=False, title='', size="m", xloc=None, session kwargs['color'] = color if opacity is not None: kwargs['opacity'] = opacity - tmp = active_chart.create_histogram(**kwargs) #green transparent "rgba(53, 94, 59, 0.6)" - tmp.set(xloc_me(series, xloc)) + if isinstance(series, pd.DataFrame) and isinstance(series.columns, pd.MultiIndex): #multiindex handling + for col_tuple in series.columns: + kwargs = {'name': name + str(col_tuple)} + tmp = active_chart.create_histogram(**kwargs) #green transparent "rgba(53, 94, 59, 0.6)" + tmp.set(xloc_me(series.loc[:, col_tuple], xloc)) + else: + tmp = active_chart.create_histogram(**kwargs) #green transparent "rgba(53, 94, 59, 0.6)" + tmp.set(xloc_me(series, xloc)) if pane.title is not None: active_chart.topbar.textbox("title",pane.title) @@ -413,7 +525,7 @@ def chart(panes: list[Panel], sync=False, title='', size="m", xloc=None, session #iterate over keys - they are all priceScaleId except of these for att_name, att_value_tuple in vars(pane).items(): - if att_name in ["ohlcv","histogram","title","xloc","precision"]: + if att_name in ["ohlcv","histogram","title","xloc","precision", "auto_scale"]: continue for tup in att_value_tuple: series, name, entries, exits, markers = (tup + (None, None, None, None, None))[:5] @@ -425,9 +537,20 @@ def chart(panes: list[Panel], sync=False, title='', size="m", xloc=None, session series = series.xloc[xloc] if xloc is not None else series for output in series.output_names: output_series = getattr(series, output) - output = name + ':' + output if name is not None else output - tmp = active_chart.create_line(name=output, priceScaleId=att_name)#, color="blue") - tmp.set(output_series) + output = name + ':' + output if name is not None else series.short_name + ":" + output + #if output_series is multiindex - create aline for each combination + if isinstance(output_series, pd.DataFrame) and isinstance(output_series.columns, pd.MultiIndex): + for col_tuple in output_series.columns: + tmp = active_chart.create_line(name=output + " " + str(col_tuple), priceScaleId=att_name)#, color="blue") + tmp.set(output_series.loc[:, col_tuple]) + else: + tmp = active_chart.create_line(name=output, priceScaleId=att_name)#, color="blue") + tmp.set(output_series) + #if multiindex then unpack them all with tuple as names + elif isinstance(series, pd.DataFrame) and isinstance(series.columns, pd.MultiIndex): + for col_tuple in series.columns: + tmp = active_chart.create_line(name=str(col_tuple) if name is None else name+" "+str(col_tuple), priceScaleId=att_name)#, color="blue") + tmp.set(xloc_me(series.loc[:, col_tuple], xloc)) else: if name is None: name = "no_name" if not hasattr(series, 'name') or series.name is None else str(series.name) @@ -449,13 +572,14 @@ def chart(panes: list[Panel], sync=False, title='', size="m", xloc=None, session active_chart.fit() if session is not None and session: try: - last_used_series = output_series if is_vbt_indicator(series) else series #pokud byl posledni series vbt, pak pouzijeme jeho outputy + last_used_series = output_series.loc[:, col_tuple] if isinstance(output_series, pd.DataFrame) and isinstance(output_series.columns, pd.MultiIndex) else output_series if is_vbt_indicator(series) else series #pokud byl posledni series vbt, pak pouzijeme jeho outputy + last_used_series = last_used_series.iloc[:,0] if isinstance(last_used_series, pd.DataFrame) else last_used_series #if df then use just first column t1 = xloc_me(last_used_series, xloc) t1 = t1.vbt.xloc[session] target_data = t1.obj #we dont know the exact time of market start +- 3 seconds thus we find mark first row after 9:30 # Resample the data to daily frequency and get the first entry of each day - first_row_indexes = target_data.resample('D').apply(lambda x: x.index[0]) + first_row_indexes = target_data.resample('D').apply(lambda x: x.index[0] if not x.empty else None).dropna() # Convert the indexes to a list session_starts = first_row_indexes.to_list() diff --git a/setup.py b/setup.py index 2775ead..9ca1acf 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ with open('README.md', 'r', encoding='utf-8') as f: setup( name='lightweight_charts', - version='2.2.2', + version='2.2.3', packages=find_packages(), python_requires='>=3.8', install_requires=[