Source code for glmpynet.logistic_regression

"""
This module contains the LogisticRegression class, a scikit-learn compatible wrapper
for penalized logistic regression.
"""

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
from sklearn.utils._param_validation import InvalidParameterError
from sklearn.utils.multiclass import unique_labels, type_of_target
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted, validate_data

# Import our new binding interface and mock implementation
from .binding.base import GlmNetBinding
from .binding.mock import MockGlmNetBinding


[docs] class LogisticRegression(ClassifierMixin, BaseEstimator): """ A scikit-learn compatible estimator for penalized logistic regression. This class provides a user-friendly hybrid API. By default, it accepts scikit-learn style parameters like `C` and `penalty`. It also provides an "escape hatch" for advanced users to pass glmnet-native parameters like `alpha` and `nlambda` directly. Parameters ---------- penalty : {'l1', 'l2'}, default='l2' Specifies the norm of the penalty. C : float, default=1.0 Inverse of regularization strength; must be a positive float. alpha : float, optional The elastic net mixing parameter, with 0 <= alpha <= 1. If provided, this will override the `penalty` parameter. nlambda : int, default=100 The number of lambda values in the regularization path. """
[docs] def __init__(self, penalty: str = 'l2', C: float = 1.0, alpha: float = None, nlambda: int = 100, binding: GlmNetBinding = None): """ Initializes the LogisticRegression model. The constructor is "lean" and only stores parameters. All validation and translation happens in `fit`. """ self.penalty = penalty self.C = C self.alpha = alpha self.nlambda = nlambda self.binding = binding
def _validate_and_translate_params(self): """ Validates hyperparameters and translates sklearn-style params to glmnet-style. This is called at the beginning of fit(). """ if self.C <= 0: raise InvalidParameterError( f"The 'C' parameter of LogisticRegression must be a positive float. " f"Got {self.C} instead." ) # Determine the final alpha value if self.alpha is not None: # User provided alpha directly, it takes precedence if not 0 <= self.alpha <= 1: raise InvalidParameterError(f"alpha must be in [0, 1]; got (alpha={self.alpha})") final_alpha = self.alpha else: # Translate from penalty if self.penalty == 'l1': final_alpha = 1.0 elif self.penalty == 'l2': final_alpha = 0.0 else: # --- THIS IS THE FIX --- # Raise the specific error type that the scikit-learn ecosystem expects. raise InvalidParameterError( f"The 'penalty' parameter of LogisticRegression must be 'l1' or 'l2'. " f"Got '{self.penalty}' instead." ) return {"alpha": final_alpha, "nlambda": self.nlambda}
[docs] def fit(self, X, y): """ Fit the logistic regression model according to the given training data. """ # Step 1: Validate and translate hyperparameters glmnet_params = self._validate_and_translate_params() # Step 2: Validate input data X, y = check_X_y(X, y, accept_sparse=True) self.classes_ = unique_labels(y) self.n_features_in_ = X.shape[1] # Use the recommended scikit-learn utility to check the target type. # This ensures we raise the exact error message that check_estimator expects. y_type = type_of_target(y) if y_type != "binary": raise ValueError( "Only binary classification is supported. The type of the target " f"is {y_type}." ) # Step 3: Instantiate the binding if self.binding is None: self.binding_ = MockGlmNetBinding() else: self.binding_ = self.binding # Step 4: Call the binding's fit method with the translated parameters results = self.binding_.fit( x=X, y=y, alpha=glmnet_params['alpha'], nlambda=glmnet_params['nlambda'] ) # Step 5: Store the fitted coefficients lambda_idx_to_use = self.nlambda // 2 self.intercept_ = np.array([results['a0'][lambda_idx_to_use]]) self.coef_ = results['ca'][:, lambda_idx_to_use].T.reshape(1, -1) return self
def _decision_function(self, X): """Calculate the linear decision function.""" check_is_fitted(self) # Use validate_data to ensure n_features_in_ is checked correctly X = validate_data(self, X, accept_sparse=True, reset=False) return (X @ self.coef_.T) + self.intercept_
[docs] def predict(self, X): """ Predict class labels for samples in X. """ scores = self._decision_function(X).ravel() predictions = (scores > 0).astype(int) return self.classes_[predictions]
[docs] def predict_proba(self, X): """ Probability estimates for samples in X. """ scores = self._decision_function(X).ravel() prob_class_1 = 1 / (1 + np.exp(-scores)) prob_class_0 = 1 - prob_class_1 return np.vstack((prob_class_0, prob_class_1)).T
def _more_tags(self): """ Provides custom metadata to scikit-learn's tag system. This is the correct and documented way to override a default tag. BaseEstimator's __sklearn_tags__ method will find and use this. """ return {"binary_only": True} def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.classifier_tags.multi_class = False tags.classifier_tags.multi_label = False tags.classifier_tags.poor_score = True tags.estimator_type = 'classifier' tags.input_tags.sparse = True return tags