Class weights for multiclass training
This commit is contained in:
2
setup.py
2
setup.py
@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='ttools',
|
name='ttools',
|
||||||
version='0.7.98',
|
version='0.7.99',
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=[
|
install_requires=[
|
||||||
# list your dependencies here
|
# list your dependencies here
|
||||||
|
|||||||
@ -906,22 +906,42 @@ class LibraryTradingModel:
|
|||||||
|
|
||||||
model = set_gpu_params(model)
|
model = set_gpu_params(model)
|
||||||
|
|
||||||
# Handle class imbalance for binary classification
|
# Determine class weights dynamically
|
||||||
|
class_counts = np.bincount(y_fold_train) # Counts the occurrences of each class
|
||||||
|
class_weights = {i: sum(class_counts) / (len(class_counts) * count) for i, count in enumerate(class_counts)}
|
||||||
|
|
||||||
|
# Check number of classes
|
||||||
if self.config.n_classes == 2:
|
if self.config.n_classes == 2:
|
||||||
n_0 = sum(y_fold_train == 0)
|
# For binary classification, use scale_pos_weight
|
||||||
n_1 = sum(y_fold_train == 1)
|
n_0 = class_counts[0]
|
||||||
|
n_1 = class_counts[1]
|
||||||
scale_pos_weight = n_0 / n_1
|
scale_pos_weight = n_0 / n_1
|
||||||
model.set_params(scale_pos_weight=scale_pos_weight)
|
model.set_params(scale_pos_weight=scale_pos_weight)
|
||||||
|
else:
|
||||||
# Train with early stopping
|
# For multiclass classification, use sample_weight
|
||||||
model.fit(
|
sample_weights = [class_weights[label] for label in y_fold_train]
|
||||||
X_fold_train,
|
|
||||||
y_fold_train,
|
print("Model Training...")
|
||||||
eval_set=[(X_fold_val, y_fold_val)],
|
if self.config.n_classes == 2:
|
||||||
#early_stopping_rounds=50,
|
# Train with early stopping
|
||||||
#verbose=2
|
model.fit(
|
||||||
)
|
X_fold_train,
|
||||||
|
y_fold_train,
|
||||||
|
eval_set=[(X_fold_val, y_fold_val)],
|
||||||
|
#early_stopping_rounds=50,
|
||||||
|
#verbose=2
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Train with per-sample class weights (early stopping is currently commented out)
|
||||||
|
model.fit(
|
||||||
|
X_fold_train,
|
||||||
|
y_fold_train,
|
||||||
|
eval_set=[(X_fold_val, y_fold_val)],
|
||||||
|
sample_weight=sample_weights
|
||||||
|
#early_stopping_rounds=50,
|
||||||
|
#verbose=2
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate score
|
# Calculate score
|
||||||
if self.config.model_type == 'classifier':
|
if self.config.model_type == 'classifier':
|
||||||
pred = model.predict(X_fold_val)
|
pred = model.predict(X_fold_val)
|
||||||
@ -1029,15 +1049,26 @@ class LibraryTradingModel:
|
|||||||
|
|
||||||
model.set_params(verbosity=1)
|
model.set_params(verbosity=1)
|
||||||
|
|
||||||
#balance unbalanced classes, works for binary:logistic
|
# Determine class weights dynamically
|
||||||
|
class_counts = np.bincount(y_train) # Counts the occurrences of each class
|
||||||
|
class_weights = {i: sum(class_counts) / (len(class_counts) * count) for i, count in enumerate(class_counts)}
|
||||||
|
|
||||||
|
# Check number of classes
|
||||||
if self.config.n_classes == 2:
|
if self.config.n_classes == 2:
|
||||||
n_0 = sum(y_train == 0) # 900
|
# For binary classification, use scale_pos_weight
|
||||||
n_1 = sum(y_train == 1) # 100
|
n_0 = class_counts[0]
|
||||||
scale_pos_weight = n_0 / n_1 # 900/100 = 9
|
n_1 = class_counts[1]
|
||||||
model.set_params(scale_pos_weight=scale_pos_weight)
|
scale_pos_weight = n_0 / n_1
|
||||||
|
model.set_params(scale_pos_weight=scale_pos_weight)
|
||||||
|
else:
|
||||||
|
# For multiclass classification, use sample_weight
|
||||||
|
sample_weights = [class_weights[label] for label in y_train]
|
||||||
|
|
||||||
print("Model Training...")
|
print("Model Training...")
|
||||||
model.fit(X_train_scaled, y_train)
|
if self.config.n_classes == 2:
|
||||||
|
model.fit(X_train_scaled, y_train)
|
||||||
|
else:
|
||||||
|
model.fit(X_train_scaled, y_train, sample_weight=sample_weights)
|
||||||
else:
|
else:
|
||||||
print("Using PROVIDED MODEL and SCALER")
|
print("Using PROVIDED MODEL and SCALER")
|
||||||
model = self.use_model
|
model = self.use_model
|
||||||
|
|||||||
Reference in New Issue
Block a user