linear_regression_tool / tools /linear_regression.py
marcusinect's picture
Upload agent
f5550b6 verified
raw
history blame
1.02 kB
from typing import Any, Optional
from smolagents.tools import Tool
import numpy
class LinearRegressionTool(Tool):
    """smolagents tool that fits an ordinary-least-squares linear regression.

    The caller supplies the outcome values and a mapping of covariate name to
    covariate values. The tool returns one fitted coefficient per covariate,
    keyed by covariate name and rounded to 4 decimal places.

    No intercept column is added automatically: the design matrix contains
    only the supplied covariates. If the model should have an intercept, the
    caller must include it as a covariate (e.g. a column of all ones).
    """

    name = "linear_regression"
    description = "Gives the regression coefficient for a linear regression model. It may be a good idea to add an intercept to the input dictionary. Output is dictionary of (str, float)."
    inputs = {
        'outcome': {'type': 'array', 'description': 'Values of the outcome variable'},
        'covariates': {'type': 'object', 'description': 'Dictionary with name and values for each covariate'},
    }
    output_type = "object"

    def forward(self, outcome, covariates):
        """Fit ``outcome ~ covariates`` by least squares and return the coefficients.

        Args:
            outcome: Sequence of n numeric outcome values.
            covariates: Mapping of covariate name to a sequence of n numeric
                values. Each entry becomes one column of the design matrix,
                in the mapping's iteration order.

        Returns:
            Dict mapping each covariate name to its fitted coefficient as a
            Python float rounded to 4 decimal places.

        Raises:
            ValueError: If ``covariates`` is empty, a value is not numeric, or
                the covariate columns do not all have the same length.
            numpy.linalg.LinAlgError: If the number of outcome values does not
                match the number of rows in the design matrix.
        """
        # NOTE(review): numpy is imported inside the method (the module also
        # imports it), presumably so the method stays self-contained for
        # smolagents tool serialization — confirm before moving it.
        import numpy as np

        if not covariates:
            raise ValueError("covariates must contain at least one covariate")

        names = list(covariates.keys())
        # One 1-D column per covariate -> design matrix of shape (n, k).
        design = np.column_stack(
            [np.asarray(covariates[cov_name], dtype=float) for cov_name in names]
        )
        target = np.asarray(outcome, dtype=float)

        # rcond=None selects NumPy's modern singular-value cutoff on every
        # version (and silences the FutureWarning older NumPy emits when
        # rcond is omitted), so results do not depend on the installed NumPy.
        coefs = np.linalg.lstsq(design, target, rcond=None)[0]
        return {cov_name: round(float(coef), 4) for cov_name, coef in zip(names, coefs)}

    def __init__(self, *args, **kwargs):
        self.is_initialized = False