Compare commits

...

22 Commits

Author SHA1 Message Date
68ea25d963 class weights for multiclass for training 2024-12-05 11:12:35 +01:00
277de3f73e fix 2024-11-29 13:36:38 +01:00
d99df2402c fix 2024-11-29 13:32:51 +01:00
f52fb2649c fix 2024-11-28 11:23:12 +01:00
30885171c3 fix 2024-11-27 10:27:03 +01:00
623412406e fix 2024-11-27 10:15:22 +01:00
1b3f5b1b79 update 2024-11-27 09:20:24 +01:00
0705b45351 gpu support added 2024-11-26 18:20:29 +01:00
c5e4c03af7 fix 2024-11-26 18:06:41 +01:00
191b58d11d thr conf matrix added 2024-11-21 16:08:07 +01:00
f2066f419e daily 2024-11-21 13:27:17 +01:00
519163efb5 fix 2024-11-21 09:32:20 +01:00
491bfc9feb ml support added 2024-11-21 09:20:36 +01:00
0ff42e2345 ordering of remote files fix 2024-11-21 05:28:08 +01:00
e22fda2f35 readme change 2024-11-19 10:28:32 +01:00
25b5a53774 Lee-Ready method updated - trades_buy_count and trades_sell_count added 2024-11-19 10:24:31 +01:00
169f07563e fix vol bars 2024-11-14 13:48:20 +01:00
fb5b2369e1 fix 2024-11-14 12:55:47 +01:00
b23a772836 remote fetch 2024-11-10 14:08:41 +01:00
cf6bcede48 remote range in utc 2024-11-01 15:41:23 +01:00
2116679dba optimalizations 2024-11-01 11:18:10 +01:00
c3faa53eff fix 2024-10-31 13:20:56 +01:00
11 changed files with 3722 additions and 552 deletions

View File

@ -3,9 +3,11 @@ A Python library for tools, utilities, and helpers for my trading research workf
## Installation
```python
```bash
pip install git+https://github.com/drew2323/ttools.git
```
or
```bash
pip install git+https://gitea.stratlab.dev/dwker/ttools.git
```
Modules:
@ -14,6 +16,9 @@ Modules:
- remotely fetches daily trade data
- manages trade cache (daily trade files per symbol) and aggregation cache (per symbol and requested period)
- numba compiled aggregator for the required output (time-based, dollar, volume bars, renkos...)
- additional columns calculated from tick data and included in bars
  - buyvolume, sellvolume - total volume triggered by aggressive orders (estimated by the Lee-Ready algorithm; see the sketch below)
  - buytrades, selltrades - total number of trades in each bar, grouped by the side of the aggressive order
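A minimal sketch of the tick-rule side classification behind these columns (illustrative only; the actual logic lives in the numba aggregators in this diff):
```python
def classify_side(price, prev_price, last_side):
    """Classify a trade as buyer- or seller-initiated using the simple tick rule."""
    if price > prev_price:
        return "buy"    # uptick -> aggressive buy
    if price < prev_price:
        return "sell"   # downtick -> aggressive sell
    return last_side    # unchanged price -> reuse the previous direction
```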
Detailed examples in [tests/data_loader_tryme.ipynb](tests/data_loader_tryme.ipynb)
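A minimal usage sketch for the tick-data loader (the signature of `fetch_trades_parallel` is taken from the loaders module further down in this diff; the import path and the symbol/date values are assumptions):
```python
from datetime import datetime
from ttools.config import zoneNY
from ttools.loaders import fetch_trades_parallel  # assumed import path

start_date = zoneNY.localize(datetime(2024, 10, 14, 9, 30))
end_date = zoneNY.localize(datetime(2024, 10, 16, 16, 0))

# Trades are served from the daily cache when present, otherwise fetched remotely.
trades = fetch_trades_parallel("BAC", start_date, end_date,
                               minsize=100, main_session_only=True)
```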
@ -95,6 +100,28 @@ python3 prepare_cache.py --symbols BAC AAPL --day_start 2024-10-14 --day_stop 20
```
## remote loaders
Remote bars of a given resolution fetched from Alpaca.
Available resolutions: Minute, Hour, Day. It is not possible to limit the included trades (e.g. by condition or size).
Use these loaders only when tick-level precision is not required.
```python
from ttools.external_loaders import load_history_bars
from ttools.config import zoneNY
from datetime import datetime, time
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
symbol = "AAPL"
start_date = zoneNY.localize(datetime(2023, 2, 27, 18, 51, 38))
end_date = zoneNY.localize(datetime(2023, 4, 27, 21, 51, 39))
timeframe = TimeFrame(amount=1,unit=TimeFrameUnit.Minute)
df = load_history_bars(symbol, start_date, end_date, timeframe, main_session_only=True)
df.loc[('AAPL',)]
```
# vbtutils
Contains helpers for vbtpro

View File

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(
name='ttools',
version='0.6.4',
version='0.7.99',
packages=find_packages(),
install_requires=[
# list your dependencies here

460
tests/alpaca_loader.ipynb Normal file
View File

@ -0,0 +1,460 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TTOOLS: Loaded env variables from file /Users/davidbrazda/Documents/Development/python/.env\n"
]
}
],
"source": [
"from ttools.external_loaders import load_history_bars\n",
"from ttools.config import zoneNY\n",
"from datetime import datetime, time\n",
"from alpaca.data.timeframe import TimeFrame, TimeFrameUnit\n",
"\n",
"symbol = \"AAPL\"\n",
"start_date = zoneNY.localize(datetime(2023, 2, 27, 18, 51, 38))\n",
"end_date = zoneNY.localize(datetime(2023, 4, 27, 21, 51, 39))\n",
"timeframe = TimeFrame(amount=1,unit=TimeFrameUnit.Minute)\n",
"\n",
"df = load_history_bars(symbol, start_date, end_date, timeframe, True)\n",
"df.loc[('AAPL',)]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>open</th>\n",
" <th>high</th>\n",
" <th>low</th>\n",
" <th>close</th>\n",
" <th>volume</th>\n",
" <th>trade_count</th>\n",
" <th>vwap</th>\n",
" </tr>\n",
" <tr>\n",
" <th>timestamp</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2023-02-28 09:30:00-05:00</th>\n",
" <td>147.050</td>\n",
" <td>147.380</td>\n",
" <td>146.830</td>\n",
" <td>147.2700</td>\n",
" <td>1554100.0</td>\n",
" <td>6447.0</td>\n",
" <td>146.914560</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-02-28 09:31:00-05:00</th>\n",
" <td>147.250</td>\n",
" <td>147.320</td>\n",
" <td>147.180</td>\n",
" <td>147.2942</td>\n",
" <td>159387.0</td>\n",
" <td>6855.0</td>\n",
" <td>147.252171</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-02-28 09:32:00-05:00</th>\n",
" <td>147.305</td>\n",
" <td>147.330</td>\n",
" <td>147.090</td>\n",
" <td>147.1600</td>\n",
" <td>214536.0</td>\n",
" <td>7435.0</td>\n",
" <td>147.210128</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-02-28 09:33:00-05:00</th>\n",
" <td>147.140</td>\n",
" <td>147.230</td>\n",
" <td>147.090</td>\n",
" <td>147.1500</td>\n",
" <td>171487.0</td>\n",
" <td>7235.0</td>\n",
" <td>147.154832</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-02-28 09:34:00-05:00</th>\n",
" <td>147.160</td>\n",
" <td>147.160</td>\n",
" <td>146.880</td>\n",
" <td>146.9850</td>\n",
" <td>235915.0</td>\n",
" <td>4965.0</td>\n",
" <td>147.001762</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 15:26:00-04:00</th>\n",
" <td>168.400</td>\n",
" <td>168.415</td>\n",
" <td>168.340</td>\n",
" <td>168.3601</td>\n",
" <td>163973.0</td>\n",
" <td>1398.0</td>\n",
" <td>168.368809</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 15:27:00-04:00</th>\n",
" <td>168.360</td>\n",
" <td>168.400</td>\n",
" <td>168.330</td>\n",
" <td>168.3800</td>\n",
" <td>130968.0</td>\n",
" <td>1420.0</td>\n",
" <td>168.364799</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 15:28:00-04:00</th>\n",
" <td>168.380</td>\n",
" <td>168.430</td>\n",
" <td>168.320</td>\n",
" <td>168.3285</td>\n",
" <td>152193.0</td>\n",
" <td>1361.0</td>\n",
" <td>168.372671</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 15:29:00-04:00</th>\n",
" <td>168.325</td>\n",
" <td>168.330</td>\n",
" <td>168.260</td>\n",
" <td>168.2850</td>\n",
" <td>208426.0</td>\n",
" <td>1736.0</td>\n",
" <td>168.297379</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 15:30:00-04:00</th>\n",
" <td>168.280</td>\n",
" <td>168.350</td>\n",
" <td>168.255</td>\n",
" <td>168.3450</td>\n",
" <td>218077.0</td>\n",
" <td>1694.0</td>\n",
" <td>168.308873</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>15162 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" open high low close volume \\\n",
"timestamp \n",
"2023-02-28 09:30:00-05:00 147.050 147.380 146.830 147.2700 1554100.0 \n",
"2023-02-28 09:31:00-05:00 147.250 147.320 147.180 147.2942 159387.0 \n",
"2023-02-28 09:32:00-05:00 147.305 147.330 147.090 147.1600 214536.0 \n",
"2023-02-28 09:33:00-05:00 147.140 147.230 147.090 147.1500 171487.0 \n",
"2023-02-28 09:34:00-05:00 147.160 147.160 146.880 146.9850 235915.0 \n",
"... ... ... ... ... ... \n",
"2023-04-27 15:26:00-04:00 168.400 168.415 168.340 168.3601 163973.0 \n",
"2023-04-27 15:27:00-04:00 168.360 168.400 168.330 168.3800 130968.0 \n",
"2023-04-27 15:28:00-04:00 168.380 168.430 168.320 168.3285 152193.0 \n",
"2023-04-27 15:29:00-04:00 168.325 168.330 168.260 168.2850 208426.0 \n",
"2023-04-27 15:30:00-04:00 168.280 168.350 168.255 168.3450 218077.0 \n",
"\n",
" trade_count vwap \n",
"timestamp \n",
"2023-02-28 09:30:00-05:00 6447.0 146.914560 \n",
"2023-02-28 09:31:00-05:00 6855.0 147.252171 \n",
"2023-02-28 09:32:00-05:00 7435.0 147.210128 \n",
"2023-02-28 09:33:00-05:00 7235.0 147.154832 \n",
"2023-02-28 09:34:00-05:00 4965.0 147.001762 \n",
"... ... ... \n",
"2023-04-27 15:26:00-04:00 1398.0 168.368809 \n",
"2023-04-27 15:27:00-04:00 1420.0 168.364799 \n",
"2023-04-27 15:28:00-04:00 1361.0 168.372671 \n",
"2023-04-27 15:29:00-04:00 1736.0 168.297379 \n",
"2023-04-27 15:30:00-04:00 1694.0 168.308873 \n",
"\n",
"[15162 rows x 7 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.loc[('AAPL',)]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>open</th>\n",
" <th>high</th>\n",
" <th>low</th>\n",
" <th>close</th>\n",
" <th>volume</th>\n",
" <th>trade_count</th>\n",
" <th>vwap</th>\n",
" </tr>\n",
" <tr>\n",
" <th>symbol</th>\n",
" <th>timestamp</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"11\" valign=\"top\">AAPL</th>\n",
" <th>2023-02-27 18:52:00-05:00</th>\n",
" <td>148.0200</td>\n",
" <td>148.02</td>\n",
" <td>148.0200</td>\n",
" <td>148.02</td>\n",
" <td>112.0</td>\n",
" <td>7.0</td>\n",
" <td>148.020000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-02-27 18:56:00-05:00</th>\n",
" <td>148.0200</td>\n",
" <td>148.02</td>\n",
" <td>148.0200</td>\n",
" <td>148.02</td>\n",
" <td>175.0</td>\n",
" <td>10.0</td>\n",
" <td>148.020000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-02-27 19:00:00-05:00</th>\n",
" <td>148.0299</td>\n",
" <td>148.03</td>\n",
" <td>148.0299</td>\n",
" <td>148.03</td>\n",
" <td>1957.0</td>\n",
" <td>10.0</td>\n",
" <td>148.029993</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-02-27 19:06:00-05:00</th>\n",
" <td>148.0600</td>\n",
" <td>148.06</td>\n",
" <td>148.0600</td>\n",
" <td>148.06</td>\n",
" <td>122.0</td>\n",
" <td>7.0</td>\n",
" <td>148.060000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-02-27 19:09:00-05:00</th>\n",
" <td>148.0500</td>\n",
" <td>148.10</td>\n",
" <td>148.0500</td>\n",
" <td>148.10</td>\n",
" <td>1604.0</td>\n",
" <td>33.0</td>\n",
" <td>148.075109</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 19:54:00-04:00</th>\n",
" <td>167.8000</td>\n",
" <td>167.80</td>\n",
" <td>167.8000</td>\n",
" <td>167.80</td>\n",
" <td>534.0</td>\n",
" <td>15.0</td>\n",
" <td>167.800000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 19:56:00-04:00</th>\n",
" <td>167.8800</td>\n",
" <td>167.88</td>\n",
" <td>167.8800</td>\n",
" <td>167.88</td>\n",
" <td>1386.0</td>\n",
" <td>28.0</td>\n",
" <td>167.880000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 19:57:00-04:00</th>\n",
" <td>167.8000</td>\n",
" <td>167.80</td>\n",
" <td>167.8000</td>\n",
" <td>167.80</td>\n",
" <td>912.0</td>\n",
" <td>60.0</td>\n",
" <td>167.800000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 19:58:00-04:00</th>\n",
" <td>167.8000</td>\n",
" <td>167.88</td>\n",
" <td>167.8000</td>\n",
" <td>167.88</td>\n",
" <td>3311.0</td>\n",
" <td>22.0</td>\n",
" <td>167.877333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2023-04-27 19:59:00-04:00</th>\n",
" <td>167.9000</td>\n",
" <td>167.94</td>\n",
" <td>167.9000</td>\n",
" <td>167.94</td>\n",
" <td>1969.0</td>\n",
" <td>64.0</td>\n",
" <td>167.918150</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>31217 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" open high low close volume \\\n",
"symbol timestamp \n",
"AAPL 2023-02-27 18:52:00-05:00 148.0200 148.02 148.0200 148.02 112.0 \n",
" 2023-02-27 18:56:00-05:00 148.0200 148.02 148.0200 148.02 175.0 \n",
" 2023-02-27 19:00:00-05:00 148.0299 148.03 148.0299 148.03 1957.0 \n",
" 2023-02-27 19:06:00-05:00 148.0600 148.06 148.0600 148.06 122.0 \n",
" 2023-02-27 19:09:00-05:00 148.0500 148.10 148.0500 148.10 1604.0 \n",
"... ... ... ... ... ... \n",
" 2023-04-27 19:54:00-04:00 167.8000 167.80 167.8000 167.80 534.0 \n",
" 2023-04-27 19:56:00-04:00 167.8800 167.88 167.8800 167.88 1386.0 \n",
" 2023-04-27 19:57:00-04:00 167.8000 167.80 167.8000 167.80 912.0 \n",
" 2023-04-27 19:58:00-04:00 167.8000 167.88 167.8000 167.88 3311.0 \n",
" 2023-04-27 19:59:00-04:00 167.9000 167.94 167.9000 167.94 1969.0 \n",
"\n",
" trade_count vwap \n",
"symbol timestamp \n",
"AAPL 2023-02-27 18:52:00-05:00 7.0 148.020000 \n",
" 2023-02-27 18:56:00-05:00 10.0 148.020000 \n",
" 2023-02-27 19:00:00-05:00 10.0 148.029993 \n",
" 2023-02-27 19:06:00-05:00 7.0 148.060000 \n",
" 2023-02-27 19:09:00-05:00 33.0 148.075109 \n",
"... ... ... \n",
" 2023-04-27 19:54:00-04:00 15.0 167.800000 \n",
" 2023-04-27 19:56:00-04:00 28.0 167.880000 \n",
" 2023-04-27 19:57:00-04:00 60.0 167.800000 \n",
" 2023-04-27 19:58:00-04:00 22.0 167.877333 \n",
" 2023-04-27 19:59:00-04:00 64.0 167.918150 \n",
"\n",
"[31217 rows x 7 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because it is too large

View File

@ -1,69 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
},
{
"data": {
"text/plain": [
"['CUVWAP', 'DIVRELN']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import vectorbtpro as vbt\n",
"from ttools.vbtindicators import register_custom_inds\n",
"from ttools.indicators import CUVWAP\n",
"\n",
"\n",
"register_custom_inds(None, \"override\")\n",
"#chopiness = vbt.indicator(\"technical:CHOPINESS\").run(s12_data.open, s12_data.high, s12_data.low, s12_data.close, s12_data.volume, window = 100)\n",
"#vwap_cum_roll = vbt.indicator(\"technical:ROLLING_VWAP\").run(s12_data.open, s12_data.high, s12_data.low, s12_data.close, s12_data.volume, window = 100, min_periods = 5)\n",
"#vwap_cum_d = vbt.indicator(\"ttools:CUVWAP\").run(s12_data.high, s12_data.low, s12_data.close, s12_data.volume, anchor=\"D\", drag=50)\n",
"#vwap_lin_angle = vbt.indicator(\"talib:LINEARREG_ANGLE\").run(vwap_cum_d.vwap, timeperiod=2)\n",
"\n",
"vbt.IF.list_indicators(\"ttools\")\n",
"\n",
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,4 +1,5 @@
from .vbtutils import AnchoredIndicator, create_mask_from_window, isrising, isfalling, isrisingc, isfallingc, trades2entries_exits, figs2cell
from .vbtindicators import register_custom_inds
from .utils import AggType, zoneNY, zonePRG, zoneUTC
from .loaders import load_data, prepare_trade_cache
from .loaders import load_data, prepare_trade_cache
from .external_loaders import load_history_bars

View File

@ -10,8 +10,80 @@ Includes fetch (remote/cached) methods and numba aggregator function for TIME BA
"""""
def aggregate_trades_optimized(symbol: str, trades_df: pd.DataFrame, resolution: int, type: AggType = AggType.OHLCV, clear_input: bool = False):
"""
Optimized version of trade aggregation function with reduced memory footprint.
"""
# 1. Get timestamps from index if 't' is not in columns
if 't' not in trades_df.columns:
timestamps = trades_df.index.values
else:
timestamps = trades_df['t'].values
# 2. Select only needed columns for prices and sizes
prices = trades_df['p'].values
sizes = trades_df['s'].values
# Clear the input to free up memory
if clear_input:
del trades_df
# 3. Convert timestamps to float seconds while maintaining exact precision
# The direct int64-nanosecond view below gave wrong results, so it stays disabled:
#unix_timestamps_s = timestamps.view('int64').astype(np.float64) / 1e6
# Original (non-optimized, ~5x slower) conversion kept as the safe fallback:
unix_timestamps_s = timestamps.astype('datetime64[ns]').astype(np.float64) / 1e9
# 4. Pre-allocate the ticks array for better memory efficiency
ticks = np.empty((len(timestamps), 3), dtype=np.float64)
ticks[:, 0] = unix_timestamps_s
ticks[:, 1] = prices
ticks[:, 2] = sizes
# 5. Clear memory of intermediate objects
del timestamps, prices, sizes, unix_timestamps_s
# 6. Process based on type using existing pattern
try:
match type:
case AggType.OHLCV:
ohlcv_bars = generate_time_bars_nb(ticks, resolution)
columns = ['time', 'open', 'high', 'low', 'close', 'volume', 'trades',
'updated', 'vwap', 'buyvolume', 'sellvolume', 'buytrades', 'selltrades']
case AggType.OHLCV_VOL:
ohlcv_bars = generate_volume_bars_nb(ticks, resolution)
columns = ['time', 'open', 'high', 'low', 'close', 'volume', 'trades',
'updated', 'buyvolume', 'sellvolume', 'buytrades', 'selltrades']
case AggType.OHLCV_DOL:
ohlcv_bars = generate_dollar_bars_nb(ticks, resolution)
columns = ['time', 'open', 'high', 'low', 'close', 'volume', 'trades',
'amount', 'updated']
case _:
raise ValueError("Invalid AggType type. Supported types are 'time', 'volume' and 'dollar'.")
finally:
# 7. Clear large numpy array as soon as possible
del ticks
# 8. Create DataFrame and handle timestamps - keeping original working approach
ohlcv_df = pd.DataFrame(ohlcv_bars, columns=columns)
del ohlcv_bars
# 9. Use the original timestamp handling that we know works
ohlcv_df['time'] = pd.to_datetime(ohlcv_df['time'], unit='s').dt.tz_localize('UTC').dt.tz_convert(zoneNY)
ohlcv_df['updated'] = pd.to_datetime(ohlcv_df['updated'], unit="s").dt.tz_localize('UTC').dt.tz_convert(zoneNY)
# 10. Round microseconds as in original
ohlcv_df['updated'] = ohlcv_df['updated'].dt.round('us')
# 11. Set index last, as in original
ohlcv_df.set_index('time', inplace=True)
return ohlcv_df
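# --- Illustrative usage (sketch, not part of this diff) --------------------------
# The aggregator expects a trades DataFrame with price column 'p', size column 's'
# and timestamps either in a 't' column or in the index; for time bars, resolution
# is in seconds. A hypothetical call producing 1-minute OHLCV bars:
#
#     bars = aggregate_trades_optimized("AAPL", trades_df, resolution=60,
#                                       type=AggType.OHLCV, clear_input=False)
#     bars[["open", "close", "volume", "buyvolume", "sellvolume", "buytrades", "selltrades"]]
# ----------------------------------------------------------------------------------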
def aggregate_trades(symbol: str, trades_df: pd.DataFrame, resolution: int, type: AggType = AggType.OHLCV):
""""
Original replaced by optimized version
Accepts dataframe with trades keyed by symbol. Preparess dataframe to
numpy and calls Numba optimized aggregator for given bar type. (time/volume/dollar)
"""""
@ -44,9 +116,13 @@ def aggregate_trades(symbol: str, trades_df: pd.DataFrame, resolution: int, type
columns.append('vwap')
columns.append('buyvolume')
columns.append('sellvolume')
columns.append('buytrades')
columns.append('selltrades')
if type == AggType.OHLCV_VOL:
columns.append('buyvolume')
columns.append('sellvolume')
columns.append('buytrades')
columns.append('selltrades')
ohlcv_df = pd.DataFrame(ohlcv_bars, columns=columns)
ohlcv_df['time'] = pd.to_datetime(ohlcv_df['time'], unit='s').dt.tz_localize('UTC').dt.tz_convert(zoneNY)
#print(ohlcv_df['updated'])
@ -174,22 +250,24 @@ def generate_volume_bars_nb(ticks, volume_per_bar):
close_price = ticks[0, 1]
volume = 0
trades_count = 0
trades_buy_count = 0
trades_sell_count = 0
current_day = np.floor(ticks[0, 0] / 86400) # Calculate the initial day from the first tick timestamp
bar_time = ticks[0, 0] # Initialize bar time with the time of the first tick
buy_volume = 0 # Volume of buy trades
sell_volume = 0 # Volume of sell trades
prev_price = ticks[0, 1] # Initialize previous price for the first tick
last_tick_up = None
for tick in ticks:
tick_time = tick[0]
price = tick[1]
tick_volume = tick[2]
tick_day = np.floor(tick_time / 86400) # Calculate the day of the current tick
splitted = False
# Check if the new tick is from a different day, then close the current bar
if tick_day != current_day:
if trades_count > 0:
ohlcv_bars.append([bar_time, open_price, high_price, low_price, close_price, volume, trades_count, tick_time, buy_volume, sell_volume])
ohlcv_bars.append([bar_time, open_price, high_price, low_price, close_price, volume, trades_count, tick_time, buy_volume, sell_volume, trades_buy_count, trades_sell_count])
# Reset for the new day using the current tick data
open_price = price
high_price = price
@ -197,6 +275,8 @@ def generate_volume_bars_nb(ticks, volume_per_bar):
close_price = price
volume = 0
trades_count = 0
trades_buy_count = 0
trades_sell_count = 0
remaining_volume = volume_per_bar
current_day = tick_day
bar_time = tick_time # Update bar time to the current tick time
@ -219,8 +299,21 @@ def generate_volume_bars_nb(ticks, volume_per_bar):
# Update buy and sell volumes
if price > prev_price:
buy_volume += tick_volume
trades_buy_count += 1
last_tick_up = True
elif price < prev_price:
sell_volume += tick_volume
trades_sell_count += 1
last_tick_up = False
else: #same price, use last direction
if last_tick_up is None:
pass
elif last_tick_up:
buy_volume += tick_volume
trades_buy_count += 1
else:
sell_volume += tick_volume
trades_sell_count += 1
tick_volume = 0
else:
@ -233,11 +326,24 @@ def generate_volume_bars_nb(ticks, volume_per_bar):
# Update buy and sell volumes
if price > prev_price:
buy_volume += volume_to_add
trades_buy_count += 1
last_tick_up = True
elif price < prev_price:
sell_volume += volume_to_add
trades_sell_count += 1
last_tick_up = False
else: #same price, use last direction
if last_tick_up is None:
pass
elif last_tick_up:
buy_volume += volume_to_add
trades_buy_count += 1
else:
sell_volume += volume_to_add
trades_sell_count += 1
# Append the completed bar to the list
ohlcv_bars.append([bar_time, open_price, high_price, low_price, close_price, volume, trades_count, tick_time, buy_volume, sell_volume])
ohlcv_bars.append([bar_time, open_price, high_price, low_price, close_price, volume, trades_count, tick_time, buy_volume, sell_volume, trades_buy_count, trades_sell_count])
# Reset bar values for the new bar using the current tick data
open_price = price
@ -246,21 +352,26 @@ def generate_volume_bars_nb(ticks, volume_per_bar):
close_price = price
volume = 0
trades_count = 0
trades_buy_count = 0
trades_sell_count = 0
remaining_volume = volume_per_bar
buy_volume = 0
sell_volume = 0
# Increment bar time if splitting a trade
if tick_volume > 0: # If there's remaining volume in the trade, set bar time slightly later
bar_time = tick_time + 1e-6
#if the same trade opened the bar (we are splitting one trade across multiple bars)
#the first split bar is identified by time, the following ones by the flag
if bar_time == tick_time or splitted:
bar_time = bar_time + 1e-6
splitted = True
else:
bar_time = tick_time # Otherwise, set bar time to the tick time
bar_time = tick_time
splitted = False
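# Illustrative example of the split handling above: a single 5,000-share trade arriving
# at time T while building 1,000-share volume bars yields several bars whose times step
# T, T + 1e-6, T + 2e-6, ... so that bars split out of one trade keep unique, strictly
# increasing timestamps.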
prev_price = price
# Add the last bar if it contains any trades
if trades_count > 0:
ohlcv_bars.append([bar_time, open_price, high_price, low_price, close_price, volume, trades_count, tick_time, buy_volume, sell_volume])
ohlcv_bars.append([bar_time, open_price, high_price, low_price, close_price, volume, trades_count, tick_time, buy_volume, sell_volume, trades_buy_count, trades_sell_count])
return np.array(ohlcv_bars)
@ -284,13 +395,15 @@ def generate_time_bars_nb(ticks, resolution):
close_price = 0
volume = 0
trades_count = 0
trades_buy_count = 0
trades_sell_count = 0
vwap_cum_volume_price = 0 # Cumulative volume * price
cum_volume = 0 # Cumulative volume for VWAP
buy_volume = 0 # Volume of buy trades
sell_volume = 0 # Volume of sell trades
prev_price = ticks[0, 1] # Initialize previous price for the first tick
prev_day = np.floor(ticks[0, 0] / 86400) # Calculate the initial day from the first tick timestamp
last_tick_up = None
for tick in ticks:
curr_time = tick[0] #updated time
tick_time = np.floor(tick[0] / resolution) * resolution
@ -307,7 +420,7 @@ def generate_time_bars_nb(ticks, resolution):
if tick_time != start_time + current_bar_index * resolution:
if current_bar_index >= 0 and trades_count > 0: # Save the previous bar if trades happened
vwap = vwap_cum_volume_price / cum_volume if cum_volume > 0 else 0
ohlcv_bars.append([start_time + current_bar_index * resolution, open_price, high_price, low_price, close_price, volume, trades_count, curr_time, vwap, buy_volume, sell_volume])
ohlcv_bars.append([start_time + current_bar_index * resolution, open_price, high_price, low_price, close_price, volume, trades_count, curr_time, vwap, buy_volume, sell_volume, trades_buy_count, trades_sell_count])
# Reset bar values
current_bar_index = int((tick_time - start_time) / resolution)
@ -316,6 +429,8 @@ def generate_time_bars_nb(ticks, resolution):
low_price = price
volume = 0
trades_count = 0
trades_buy_count = 0
trades_sell_count = 0
vwap_cum_volume_price = 0
cum_volume = 0
buy_volume = 0
@ -333,15 +448,28 @@ def generate_time_bars_nb(ticks, resolution):
# Update buy and sell volumes
if price > prev_price:
buy_volume += tick_volume
trades_buy_count += 1
last_tick_up = True
elif price < prev_price:
sell_volume += tick_volume
trades_sell_count += 1
last_tick_up = False
else: #same price, use last direction
if last_tick_up is None:
pass
elif last_tick_up:
buy_volume += tick_volume
trades_buy_count += 1
else:
sell_volume += tick_volume
trades_sell_count += 1
prev_price = price
# Save the last processed bar
if trades_count > 0:
vwap = vwap_cum_volume_price / cum_volume if cum_volume > 0 else 0
ohlcv_bars.append([start_time + current_bar_index * resolution, open_price, high_price, low_price, close_price, volume, trades_count, curr_time, vwap, buy_volume, sell_volume])
ohlcv_bars.append([start_time + current_bar_index * resolution, open_price, high_price, low_price, close_price, volume, trades_count, curr_time, vwap, buy_volume, sell_volume, trades_buy_count, trades_sell_count])
return np.array(ohlcv_bars)

View File

@ -0,0 +1,55 @@
from ctypes import Union
from ttools import zoneUTC
from ttools.config import *
from datetime import datetime
from alpaca.data.historical import StockHistoricalDataClient
from ttools.config import ACCOUNT1_LIVE_API_KEY, ACCOUNT1_LIVE_SECRET_KEY
from datetime import timedelta, datetime, time
from alpaca.data.enums import DataFeed
from typing import List, Union
import pandas as pd
from alpaca.data.historical import StockHistoricalDataClient
from alpaca.data.requests import StockBarsRequest
from alpaca.data.enums import DataFeed
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
def load_history_bars(symbol: Union[str, List[str]], datetime_object_from: datetime, datetime_object_to: datetime, timeframe: TimeFrame, main_session_only: bool = True):
"""Returns dataframe fetched remotely from Alpaca.
Args:
symbol: symbol or list of symbols
datetime_object_from: datetime in zoneNY
datetime_object_to: datetime in zoneNY
timeframe: timeframe
main_session_only: boolean to fetch only main session data
Returns:
dataframe
Example:
```python
from ttools.external_loaders import load_history_bars
from ttools.config import zoneNY
from datetime import datetime
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
symbol = "AAPL"
start_date = zoneNY.localize(datetime(2023, 2, 27, 18, 51, 38))
end_date = zoneNY.localize(datetime(2023, 4, 27, 21, 51, 39))
timeframe = TimeFrame(amount=1,unit=TimeFrameUnit.Minute)
df = load_history_bars(symbol, start_date, end_date, timeframe)
```
"""
client = StockHistoricalDataClient(ACCOUNT1_LIVE_API_KEY, ACCOUNT1_LIVE_SECRET_KEY, raw_data=False)
#datetime_object_from = datetime(2023, 2, 27, 18, 51, 38, tzinfo=datetime.timezone.utc)
#datetime_object_to = datetime(2023, 2, 27, 21, 51, 39, tzinfo=datetime.timezone.utc)
bar_request = StockBarsRequest(symbol_or_symbols=symbol,timeframe=timeframe, start=datetime_object_from, end=datetime_object_to, feed=DataFeed.SIP)
#print("before df")
df = client.get_stock_bars(bar_request).df
df.index = df.index.set_levels(df.index.get_level_values(1).tz_convert(zoneNY), level=1)
if main_session_only:
start_time = time(9, 30, 0)
end_time = time(15, 30, 0)
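# NOTE: the upper bound here is 15:30 New York time, half an hour before the 16:00
# session close; widen it if full-session bars are needed.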
df = df.loc[(df.index.get_level_values(1).time >= start_time) & (df.index.get_level_values(1).time <= end_time)]
return df

View File

@ -1,5 +1,6 @@
from ctypes import Union
from ttools import zoneUTC
from ttools.config import *
from datetime import datetime
from alpaca.data.historical import StockHistoricalDataClient
@ -17,8 +18,14 @@ from ttools.utils import AggType, fetch_calendar_data, print, print_matching_fil
from tqdm import tqdm
import threading
from typing import List, Union
from ttools.aggregator_vectorized import aggregate_trades
from ttools.aggregator_vectorized import aggregate_trades, aggregate_trades_optimized
import numpy as np
import pandas as pd
import pyarrow.dataset as ds
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import math
import os
"""
Module for fetching stock data. Supports
1) cache management
@ -87,6 +94,8 @@ def convert_dict_to_multiindex_df(tradesResponse, rename_labels = True, keep_sym
final_df.reset_index(inplace=True) # Reset index to remove MultiIndex levels, making them columns
final_df.drop(columns=['symbol'], inplace=True) #remove symbol column
final_df.set_index(timestamp_col, inplace=True) #reindex by timestamp
#print index datetime resolution
#print(final_df.index.dtype)
return final_df
@ -106,6 +115,28 @@ def filter_trade_df(df: pd.DataFrame, start: datetime = None, end: datetime = No
Returns:
df: pd.DataFrame
"""
def fast_filter(df, exclude_conditions):
# Convert arrays to strings once
str_series = df['c'].apply(lambda x: ','.join(x))
# Create mask using vectorized string operations
mask = np.zeros(len(df), dtype=bool)
for cond in exclude_conditions:
mask |= str_series.str.contains(cond, regex=False)
# Apply filter
return df[~mask]
def vectorized_string_sets(df, exclude_conditions):
# Convert exclude_conditions to set for O(1) lookup
exclude_set = set(exclude_conditions)
# Vectorized operation using sets intersection
arrays = df['c'].values
mask = np.array([bool(set(arr) & exclude_set) for arr in arrays])
return df[~mask]
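# Illustrative example: with exclude_conditions = ['C', 'O', '4'], a trade whose
# condition list is ['@', 'I'] is kept, while one tagged ['@', 'C'] is dropped
# because the intersection {'@', 'C'} & {'C', 'O', '4'} is non-empty.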
# 9:30 to 16:00
if main_session_only:
@ -120,30 +151,50 @@ def filter_trade_df(df: pd.DataFrame, start: datetime = None, end: datetime = No
#REQUIRED FILTERING
# Create a mask to filter rows within the specified time range
if start is not None and end is not None:
print(f"filtering {start.time()} {end.time()}")
print(f"Trimming {start} {end}")
if symbol_included:
mask = (df.index.get_level_values('t') >= start) & \
(df.index.get_level_values('t') <= end)
df = df[mask]
else:
mask = (df.index >= start) & (df.index <= end)
# Apply the mask to the DataFrame
df = df[mask]
df = df.loc[start:end]
if exclude_conditions is not None:
print(f"excluding {exclude_conditions}")
# Create a mask to exclude rows with any of the specified conditions
mask = df['c'].apply(lambda x: any(cond in exclude_conditions for cond in x))
# Filter out the rows with specified conditions
df = df[~mask]
df = vectorized_string_sets(df, exclude_conditions)
print("exclude done")
if minsize is not None:
print(f"minsize {minsize}")
#filter by minimum trade size
df = df[df['s'] >= minsize]
print("minsize done")
return df
def calculate_optimal_workers(file_count, min_workers=4, max_workers=32):
"""
Calculate optimal number of workers based on file count and system resources
Rules of thumb:
- Minimum of 4 workers to ensure parallelization
- Maximum of 32 workers to avoid thread overhead
- For 100 files, aim for around 16-24 workers
- Scale with CPU count but don't exceed max_workers
"""
cpu_count = os.cpu_count() or 4
# Base calculation: 2-4x CPU count for I/O bound tasks
suggested_workers = cpu_count * 3
# Scale based on file count (1 worker per 4-6 files is a good ratio)
files_based_workers = math.ceil(file_count / 5)
# Take the smaller of the two suggestions
optimal_workers = min(suggested_workers, files_based_workers)
# Clamp between min and max workers
return max(min_workers, min(optimal_workers, max_workers))
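# Worked example: with 100 cached files on an 8-core machine, suggested_workers = 8 * 3 = 24
# and files_based_workers = ceil(100 / 5) = 20, so the function returns min(24, 20) = 20,
# which already lies within the [4, 32] clamp.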
def fetch_daily_stock_trades(symbol, start, end, exclude_conditions=None, minsize=None, main_session_only=True, no_return=False,force_remote=False, rename_labels = False, keep_symbols=False, max_retries=5, backoff_factor=1, data_feed: DataFeed = DataFeed.SIP, verbose = None):
#doc for this function
"""
@ -152,8 +203,8 @@ def fetch_daily_stock_trades(symbol, start, end, exclude_conditions=None, minsiz
by using force_remote - forces using remote data always and thus refreshes the cache for these dates
Attributes:
:param symbol: The stock symbol to fetch trades for.
:param start: The start time for the trade data.
:param end: The end time for the trade data.
:param start: The start time for the trade data, in market timezone.
:param end: The end time for the trade data, in market timezone.
:exclude_conditions: list of string conditions to exclude from the data
:minsize minimum size of trade to be included in the data
:no_return: If True, do not return the DataFrame. Used to prepare cached files.
@ -181,24 +232,34 @@ def fetch_daily_stock_trades(symbol, start, end, exclude_conditions=None, minsiz
#exists in cache?
daily_file = f"{symbol}-{str(start.date())}.parquet"
file_path = TRADE_CACHE / daily_file
if file_path.exists() and (not force_remote or not no_return):
if file_path.exists() and (not force_remote and not no_return):
with trade_cache_lock:
df = pd.read_parquet(file_path)
print("Loaded from CACHE", file_path)
df = filter_trade_df(df, start, end, exclude_conditions, minsize, symbol_included=False, main_session_only=main_session_only)
return df
day_next = start.date() + timedelta(days=1)
#create the day borders in UTC, since Alpaca expects UTC dates
start_date = start.date()
# Create min/max times in NY timezone
ny_day_min = zoneNY.localize(datetime.combine(start_date, time.min))
ny_day_max = zoneNY.localize(datetime.combine(start_date, time.max))
# Convert both to UTC
utc_day_min = ny_day_min.astimezone(zoneUTC)
utc_day_max = ny_day_max.astimezone(zoneUTC)
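# Example: for 2024-10-14, when New York is on EDT (UTC-4), ny_day_min 2024-10-14 00:00:00-04:00
# becomes utc_day_min 2024-10-14 04:00:00+00:00 and ny_day_max 2024-10-14 23:59:59.999999-04:00
# becomes utc_day_max 2024-10-15 03:59:59.999999+00:00.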
print("Fetching from remote.")
client = StockHistoricalDataClient(ACCOUNT1_LIVE_API_KEY, ACCOUNT1_LIVE_SECRET_KEY, raw_data=True)
stockTradeRequest = StockTradesRequest(symbol_or_symbols=symbol, start=start.date(), end=day_next, feed=data_feed)
stockTradeRequest = StockTradesRequest(symbol_or_symbols=symbol, start=utc_day_min, end=utc_day_max, feed=data_feed)
last_exception = None
for attempt in range(max_retries):
try:
tradesResponse = client.get_stock_trades(stockTradeRequest)
print(f"Remote fetched completed.", start.date(), day_next)
print(f"Remote fetched completed whole day", start.date())
print(f"Exact UTC range fetched: {utc_day_min} - {utc_day_max}")
if not tradesResponse[symbol]:
print(f"EMPTY")
return pd.DataFrame()
@ -206,7 +267,7 @@ def fetch_daily_stock_trades(symbol, start, end, exclude_conditions=None, minsiz
df = convert_dict_to_multiindex_df(tradesResponse, rename_labels=rename_labels, keep_symbols=keep_symbols)
#if the market is still open today, don't cache - also don't cache for the IEX feed
if datetime.now().astimezone(zoneNY).date() < day_next or data_feed == DataFeed.IEX:
if datetime.now().astimezone(zoneNY).date() < start_date + timedelta(days=1) or data_feed == DataFeed.IEX:
print("not saving trade cache, market still open today or IEX datapoint")
#ic(datetime.now().astimezone(zoneNY))
#ic(day.open, day.close)
@ -227,7 +288,7 @@ def fetch_daily_stock_trades(symbol, start, end, exclude_conditions=None, minsiz
print("All attempts to fetch data failed.")
raise ConnectionError(f"Failed to fetch stock trades after {max_retries} retries. Last exception: {str(last_exception)} and {format_exc()}")
def fetch_trades_parallel(symbol, start_date, end_date, exclude_conditions = EXCLUDE_CONDITIONS, minsize = 100, main_session_only = True, force_remote = False, max_workers=None, no_return = False, verbose = None):
def fetch_trades_parallel(symbol, start_date, end_date, exclude_conditions = EXCLUDE_CONDITIONS, minsize = None, main_session_only = True, force_remote = False, max_workers=None, no_return = False, verbose = None):
"""
Fetch trades over the given date range.
@ -281,7 +342,12 @@ def fetch_trades_parallel(symbol, start_date, end_date, exclude_conditions = EXC
#speed it up , locals first and then fetches
s_time = timetime()
with trade_cache_lock:
local_df = pd.concat([pd.read_parquet(f) for _,f in days_from_cache])
file_paths = [f for _, f in days_from_cache]
dataset = ds.dataset(file_paths, format='parquet')
local_df = dataset.to_table().to_pandas()
del dataset
#original version
#local_df = pd.concat([pd.read_parquet(f) for _,f in days_from_cache])
final_time = timetime() - s_time
print(f"{symbol} All {len(days_from_cache)} split files loaded in", final_time, "seconds")
#the filter is required
@ -291,6 +357,7 @@ def fetch_trades_parallel(symbol, start_date, end_date, exclude_conditions = EXC
#do this only for remotes
if len(days_from_remote) > 0:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures_with_date = []
#for single_date in (start_date + timedelta(days=i) for i in range((end_date - start_date).days + 1)):
for market_day in tqdm(days_from_remote, desc=f"{symbol} Remote fetching"):
#start = datetime.combine(single_date, time(9, 30)) # Market opens at 9:30 AM
@ -313,14 +380,19 @@ def fetch_trades_parallel(symbol, start_date, end_date, exclude_conditions = EXC
end = min(end_date, max_day_time)
future = executor.submit(fetch_daily_stock_trades, symbol, start, end, exclude_conditions, minsize, main_session_only, no_return, force_remote)
futures.append(future)
futures_with_date.append((future,start))
for future in tqdm(futures, desc=f"{symbol} Receiving trades"):
results_with_dates = []
for future, date in tqdm(futures_with_date, desc=f"{symbol} Receiving trades"):
try:
result = future.result()
results.append(result)
if result is not None:
results_with_dates.append((result,date))
except Exception as e:
print(f"Error fetching data for a day: {e}")
# Sort by date before concatenating
results_with_dates.sort(key=lambda x: x[1])
results = [r for r, _ in results_with_dates]
if not no_return:
# Batch concatenation to improve speed
@ -413,7 +485,7 @@ def load_data(symbol: Union[str, List[str]],
else:
#could this be sped up? "Searching cache" displays slowly - is there a bottleneck?
df = fetch_trades_parallel(symbol, start_date, end_date, minsize=minsize, exclude_conditions=exclude_conditions, main_session_only=main_session_only, force_remote=force_remote) #exclude_conditions=['C','O','4','B','7','V','P','W','U','Z','F'])
ohlcv_df = aggregate_trades(symbol=symbol, trades_df=df, resolution=resolution, type=agg_type)
ohlcv_df = aggregate_trades_optimized(symbol=symbol, trades_df=df, resolution=resolution, type=agg_type, clear_input = True)
ohlcv_df.to_parquet(file_ohlcv, engine='pyarrow')
print(f"{symbol} Saved to agg_cache", file_ohlcv)

2226
ttools/models.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -273,4 +273,147 @@ class StartBarAlign(str, Enum):
RANDOM = first bar starts when first trade occurs
"""
ROUND = "round"
RANDOM = "random"
RANDOM = "random"
def compare_dataframes(df1, df2, name1="DataFrame 1", name2="DataFrame 2", check_dtype=True):
"""
Compare two DataFrames and provide detailed analysis of their differences.
Parameters:
-----------
df1, df2 : pandas.DataFrame
The DataFrames to compare
name1, name2 : str
Names to identify the DataFrames in the output
check_dtype : bool
Whether to check if dtypes match for columns
Returns:
--------
bool
True if DataFrames are identical (based on check_dtype parameter)
dict
Detailed comparison results
"""
results = {
'are_equal': False,
'shape_match': False,
'column_match': False,
'index_match': False,
'dtype_match': False,
'content_match': False,
'differences': {}
}
# Shape comparison
if df1.shape != df2.shape:
results['differences']['shape'] = {
name1: df1.shape,
name2: df2.shape
}
else:
results['shape_match'] = True
# Column comparison
cols1 = set(df1.columns)
cols2 = set(df2.columns)
if cols1 != cols2:
results['differences']['columns'] = {
f'unique_to_{name1}': list(cols1 - cols2),
f'unique_to_{name2}': list(cols2 - cols1),
'common': list(cols1 & cols2)
}
else:
results['column_match'] = True
# Index comparison
idx1 = set(df1.index)
idx2 = set(df2.index)
if idx1 != idx2:
results['differences']['index'] = {
f'unique_to_{name1}': list(idx1 - idx2),
f'unique_to_{name2}': list(idx2 - idx1),
'common': list(idx1 & idx2)
}
else:
results['index_match'] = True
# dtype comparison
if check_dtype and results['column_match']:
dtype_diff = {}
for col in cols1:
if df1[col].dtype != df2[col].dtype:
dtype_diff[col] = {
name1: str(df1[col].dtype),
name2: str(df2[col].dtype)
}
if dtype_diff:
results['differences']['dtypes'] = dtype_diff
else:
results['dtype_match'] = True
# Content comparison (only for matching columns and indices)
if results['column_match'] and results['index_match']:
common_cols = list(cols1)
common_idx = list(idx1)
value_diff = {}
for col in common_cols:
# Compare values
if not df1[col].equals(df2[col]):
# Find specific differences
mask = df1[col] != df2[col]
if any(mask):
diff_indices = df1.index[mask]
value_diff[col] = {
'different_at_indices': list(diff_indices),
'sample_differences': {
str(idx): {
name1: df1.loc[idx, col],
name2: df2.loc[idx, col]
} for idx in list(diff_indices)[:5] # Show first 5 differences
}
}
if value_diff:
results['differences']['values'] = value_diff
else:
results['content_match'] = True
# Overall equality
results['are_equal'] = all([
results['shape_match'],
results['column_match'],
results['index_match'],
results['content_match'],
(results['dtype_match'] if check_dtype else True)
])
# Print summary
print(f"\nComparison Summary of {name1} vs {name2}:")
print(f"Shape Match: {results['shape_match']} ({df1.shape} vs {df2.shape})")
print(f"Column Match: {results['column_match']}")
print(f"Index Match: {results['index_match']}")
print(f"Dtype Match: {results['dtype_match']}" if check_dtype else "Dtype Check: Skipped")
print(f"Content Match: {results['content_match']}")
print(f"\nOverall Equal: {results['are_equal']}")
# Print detailed differences if any
if not results['are_equal']:
print("\nDetailed Differences:")
for diff_type, diff_content in results['differences'].items():
print(f"\n{diff_type.upper()}:")
if diff_type == 'values':
print(f"Number of columns with differences: {len(diff_content)}")
for col, details in diff_content.items():
print(f"\nColumn '{col}':")
print(f"Number of different values: {len(details['different_at_indices'])}")
print("First few differences:")
for idx, vals in details['sample_differences'].items():
print(f" At index {idx}:")
print(f" {name1}: {vals[name1]}")
print(f" {name2}: {vals[name2]}")
else:
print(diff_content)
return results['are_equal'], results
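# A quick usage sketch (illustrative; `bars_original` and `bars_optimized` are hypothetical):
#
#     equal, report = compare_dataframes(bars_original, bars_optimized,
#                                        name1="aggregate_trades",
#                                        name2="aggregate_trades_optimized",
#                                        check_dtype=False)
#     if not equal:
#         print(report["differences"].keys())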