gpu support added
This commit is contained in:
2
setup.py
2
setup.py
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='ttools',
|
name='ttools',
|
||||||
version='0.7.91',
|
version='0.7.92',
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=[
|
install_requires=[
|
||||||
# list your dependencies here
|
# list your dependencies here
|
||||||
|
|||||||
@ -27,6 +27,31 @@ from traceback import format_exc
|
|||||||
from scipy.stats import entropy
|
from scipy.stats import entropy
|
||||||
import pickle
|
import pickle
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
|
def set_gpu_params(model):
|
||||||
|
# Check GPU and create classifier
|
||||||
|
has_gpu = check_gpu()
|
||||||
|
if has_gpu:
|
||||||
|
print("GPU is available! Using GPU acceleration.")
|
||||||
|
gpu_params = {
|
||||||
|
'tree_method': 'gpu_hist',
|
||||||
|
'predictor': 'gpu_predictor',
|
||||||
|
'gpu_id': 0
|
||||||
|
}
|
||||||
|
model.set_params(**gpu_params)
|
||||||
|
else:
|
||||||
|
print("GPU not found. Using CPU.")
|
||||||
|
return model
|
||||||
|
|
||||||
|
def check_gpu() -> bool:
|
||||||
|
"""Check if GPU is available via nvidia-smi"""
|
||||||
|
try:
|
||||||
|
subprocess.check_output('nvidia-smi')
|
||||||
|
return True
|
||||||
|
except (subprocess.SubprocessError, FileNotFoundError):
|
||||||
|
return False
|
||||||
|
|
||||||
#https://claude.ai/chat/dc62f18b-f293-4c7e-890d-1e591ce78763
|
#https://claude.ai/chat/dc62f18b-f293-4c7e-890d-1e591ce78763
|
||||||
#skew of return prediction
|
#skew of return prediction
|
||||||
@ -996,6 +1021,8 @@ class LibraryTradingModel:
|
|||||||
else:
|
else:
|
||||||
print("Using default hyperparameters",self.def_params)
|
print("Using default hyperparameters",self.def_params)
|
||||||
model.set_params(**self.def_params)
|
model.set_params(**self.def_params)
|
||||||
|
|
||||||
|
model = set_gpu_params(model)
|
||||||
|
|
||||||
#balance unbalanced classes, works for binary:logistics
|
#balance unbalanced classes, works for binary:logistics
|
||||||
if self.config.n_classes == 2:
|
if self.config.n_classes == 2:
|
||||||
@ -1009,6 +1036,8 @@ class LibraryTradingModel:
|
|||||||
else:
|
else:
|
||||||
print("Using PROVIDED MODEL and SCALER")
|
print("Using PROVIDED MODEL and SCALER")
|
||||||
model = self.use_model
|
model = self.use_model
|
||||||
|
|
||||||
|
model = set_gpu_params(model)
|
||||||
|
|
||||||
print("TEST-----")
|
print("TEST-----")
|
||||||
if self.config.test_on_train:
|
if self.config.test_on_train:
|
||||||
|
|||||||
Reference in New Issue
Block a user