|
import json |
|
import traceback |
|
from datetime import datetime |
|
from typing import Literal, Optional, Union |
|
|
|
import httpx |
|
from openai.types.fine_tuning.fine_tuning_job import FineTuningJob |
|
|
|
import litellm |
|
from litellm._logging import verbose_logger |
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client |
|
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM |
|
from litellm.types.fine_tuning import OpenAIFineTuningHyperparameters |
|
from litellm.types.llms.openai import FineTuningJobCreate |
|
from litellm.types.llms.vertex_ai import ( |
|
FineTuneHyperparameters, |
|
FineTuneJobCreate, |
|
FineTunesupervisedTuningSpec, |
|
ResponseSupervisedTuningSpec, |
|
ResponseTuningJob, |
|
) |
|
|
|
|
|
class VertexFineTuningAPI(VertexLLM):
    """
    Vertex AI fine-tuning ("tuningJobs") support.

    Converts OpenAI-format fine-tuning requests/responses to and from the
    Vertex AI tuning-job format and issues the create / pass-through HTTP
    calls. Reference:
    https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
    """

    # Vertex job state -> OpenAI fine-tuning job status.
    # Unknown/unlisted states fall back to "queued" (matches previous behavior
    # where the status variable was initialized to "queued").
    _VERTEX_STATE_TO_OPENAI_STATUS: dict = {
        "JOB_STATE_PENDING": "queued",
        "JOB_STATE_RUNNING": "running",
        "JOB_STATE_SUCCEEDED": "succeeded",
        "JOB_STATE_FAILED": "failed",
        "JOB_STATE_CANCELLED": "cancelled",
    }

    def __init__(self) -> None:
        super().__init__()
        # Shared async client; generous 600s timeout since tuning-job
        # creation calls can be slow.
        self.async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
            params={"timeout": 600.0},
        )

    def convert_response_created_at(self, response: ResponseTuningJob) -> int:
        """
        Return the job's ``createTime`` as a unix timestamp (seconds).

        Vertex returns RFC-3339 strings (e.g. ``2024-01-01T00:00:00Z``).
        Returns 0 when the field is missing or unparseable.
        """
        try:
            create_time_str = response.get("createTime", "") or ""
            # datetime.fromisoformat (pre-3.11) rejects a trailing "Z";
            # normalize it to an explicit UTC offset first.
            create_time_datetime = datetime.fromisoformat(
                create_time_str.replace("Z", "+00:00")
            )
            return int(create_time_datetime.timestamp())
        except Exception:
            # Best-effort: an unparseable/missing timestamp should not fail
            # the whole response conversion.
            return 0

    def convert_openai_request_to_vertex(
        self,
        create_fine_tuning_job_data: FineTuningJobCreate,
        original_hyperparameters: Optional[dict] = None,
        kwargs: Optional[dict] = None,
    ) -> FineTuneJobCreate:
        """
        Convert a request from OpenAI format to Vertex format.

        https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning

        Args:
            create_fine_tuning_job_data: OpenAI-format fine-tuning request.
            original_hyperparameters: raw hyperparameters dict from the
                caller; used for Vertex-only fields such as ``adapter_size``.
                Defaults to an empty dict. (Previously a mutable ``{}``
                default — fixed to ``None`` to avoid shared-state bugs.)
            kwargs: extra request kwargs; forwarded unchanged.

        Returns:
            A Vertex ``FineTuneJobCreate`` payload.
        """
        supervised_tuning_spec = FineTunesupervisedTuningSpec(
            training_dataset_uri=create_fine_tuning_job_data.training_file,
        )

        if create_fine_tuning_job_data.validation_file:
            supervised_tuning_spec["validation_dataset"] = (
                create_fine_tuning_job_data.validation_file
            )

        _vertex_hyperparameters = (
            self._transform_openai_hyperparameters_to_vertex_hyperparameters(
                create_fine_tuning_job_data=create_fine_tuning_job_data,
                kwargs=kwargs,
                original_hyperparameters=original_hyperparameters or {},
            )
        )

        # Only attach hyperParameters when at least one field was set.
        if _vertex_hyperparameters and len(_vertex_hyperparameters) > 0:
            supervised_tuning_spec["hyperParameters"] = _vertex_hyperparameters

        fine_tune_job = FineTuneJobCreate(
            baseModel=create_fine_tuning_job_data.model,
            supervisedTuningSpec=supervised_tuning_spec,
            tunedModelDisplayName=create_fine_tuning_job_data.suffix,
        )

        return fine_tune_job

    def _transform_openai_hyperparameters_to_vertex_hyperparameters(
        self,
        create_fine_tuning_job_data: FineTuningJobCreate,
        original_hyperparameters: Optional[dict] = None,
        kwargs: Optional[dict] = None,
    ) -> FineTuneHyperparameters:
        """
        Map OpenAI hyperparameters (``n_epochs``, ``learning_rate_multiplier``)
        to Vertex hyperparameters.

        The Vertex-only ``adapter_size`` field is taken from the raw
        ``original_hyperparameters`` dict (it has no OpenAI equivalent).
        ``kwargs`` is accepted for interface compatibility but not read here.
        """
        original_hyperparameters = original_hyperparameters or {}
        _oai_hyperparameters = create_fine_tuning_job_data.hyperparameters
        _vertex_hyperparameters = FineTuneHyperparameters()
        if _oai_hyperparameters:
            if _oai_hyperparameters.n_epochs:
                _vertex_hyperparameters["epoch_count"] = int(
                    _oai_hyperparameters.n_epochs
                )
            if _oai_hyperparameters.learning_rate_multiplier:
                _vertex_hyperparameters["learning_rate_multiplier"] = float(
                    _oai_hyperparameters.learning_rate_multiplier
                )

        _adapter_size = original_hyperparameters.get("adapter_size", None)
        if _adapter_size:
            _vertex_hyperparameters["adapter_size"] = _adapter_size

        return _vertex_hyperparameters

    def convert_vertex_response_to_open_ai_response(
        self, response: ResponseTuningJob
    ) -> FineTuningJob:
        """
        Convert a Vertex tuning-job response into an OpenAI ``FineTuningJob``.

        Fields Vertex does not report (result_files, trained_tokens, seed,
        organization_id, ...) are filled with neutral defaults.
        """
        status: Literal[
            "validating_files", "queued", "running", "succeeded", "failed", "cancelled"
        ] = self._VERTEX_STATE_TO_OPENAI_STATUS.get(response["state"], "queued")

        created_at = self.convert_response_created_at(response)

        _supervisedTuningSpec: ResponseSupervisedTuningSpec = (
            response.get("supervisedTuningSpec", None) or {}
        )
        training_uri: str = _supervisedTuningSpec.get("trainingDatasetUri", "") or ""
        return FineTuningJob(
            id=response.get("name", "") or "",
            created_at=created_at,
            fine_tuned_model=response.get("tunedModelDisplayName", ""),
            finished_at=None,
            hyperparameters=self._translate_vertex_response_hyperparameters(
                vertex_hyper_parameters=_supervisedTuningSpec.get("hyperParameters", {})
                or {}
            ),
            model=response.get("baseModel", "") or "",
            object="fine_tuning.job",
            organization_id="",
            result_files=[],
            seed=0,
            status=status,
            trained_tokens=None,
            training_file=training_uri,
            validation_file=None,
            estimated_finish=None,
            integrations=[],
        )

    def _translate_vertex_response_hyperparameters(
        self, vertex_hyper_parameters: FineTuneHyperparameters
    ) -> OpenAIFineTuningHyperparameters:
        """
        Translate Vertex response hyperparameters to OpenAI hyperparameters.

        ``epoch_count`` maps to ``n_epochs``; every other Vertex field is
        passed through unchanged.
        """
        _dict_remaining_hyperparameters: dict = dict(vertex_hyper_parameters)
        return OpenAIFineTuningHyperparameters(
            n_epochs=_dict_remaining_hyperparameters.pop("epoch_count", 0),
            **_dict_remaining_hyperparameters,
        )

    async def acreate_fine_tuning_job(
        self,
        fine_tuning_url: str,
        headers: dict,
        request_data: FineTuneJobCreate,
    ):
        """
        Async: POST the Vertex tuning-job request and return the response
        converted to an OpenAI ``FineTuningJob``.

        Raises on non-200 responses; any failure is logged with a traceback
        and re-raised.
        """
        try:
            verbose_logger.debug(
                "about to create fine tuning job: %s, request_data: %s",
                fine_tuning_url,
                json.dumps(request_data, indent=4),
            )
            if self.async_handler is None:
                raise ValueError(
                    "VertexAI Fine Tuning - async_handler is not initialized"
                )
            response = await self.async_handler.post(
                headers=headers,
                url=fine_tuning_url,
                json=request_data,
            )

            if response.status_code != 200:
                raise Exception(
                    f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
                )

            verbose_logger.debug(
                "got response from creating fine tuning job: %s", response.json()
            )

            vertex_response = ResponseTuningJob(
                **response.json(),
            )

            verbose_logger.debug("vertex_response %s", vertex_response)
            open_ai_response = self.convert_vertex_response_to_open_ai_response(
                vertex_response
            )
            return open_ai_response

        except Exception as e:
            verbose_logger.error("asyncerror creating fine tuning job %s", e)
            trace_back_str = traceback.format_exc()
            verbose_logger.error(trace_back_str)
            raise e

    def create_fine_tuning_job(
        self,
        _is_async: bool,
        create_fine_tuning_job_data: FineTuningJobCreate,
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        kwargs: Optional[dict] = None,
        original_hyperparameters: Optional[dict] = None,
    ):
        """
        Create a Vertex AI tuning job from an OpenAI-format request.

        Resolves credentials, converts the request to Vertex format, and
        POSTs it. When ``_is_async`` is True, returns the (awaitable)
        coroutine from ``acreate_fine_tuning_job``; otherwise performs the
        request synchronously and returns an OpenAI ``FineTuningJob``.
        """
        verbose_logger.debug(
            "creating fine tuning job, args= %s", create_fine_tuning_job_data
        )
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )

        auth_header, _ = self._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base=api_base,
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }

        fine_tune_job = self.convert_openai_request_to_vertex(
            create_fine_tuning_job_data=create_fine_tuning_job_data,
            kwargs=kwargs,
            original_hyperparameters=original_hyperparameters or {},
        )

        fine_tuning_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
        if _is_async is True:
            return self.acreate_fine_tuning_job(
                fine_tuning_url=fine_tuning_url,
                headers=headers,
                request_data=fine_tune_job,
            )
        # NOTE(review): the `timeout` parameter is accepted but the sync
        # handler hardcodes 600s/5s — confirm whether it should be honored.
        sync_handler = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

        verbose_logger.debug(
            "about to create fine tuning job: %s, request_data: %s",
            fine_tuning_url,
            fine_tune_job,
        )
        response = sync_handler.post(
            headers=headers,
            url=fine_tuning_url,
            json=fine_tune_job,
        )

        if response.status_code != 200:
            raise Exception(
                f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
            )

        verbose_logger.debug(
            "got response from creating fine tuning job: %s", response.json()
        )
        vertex_response = ResponseTuningJob(
            **response.json(),
        )

        verbose_logger.debug("vertex_response %s", vertex_response)
        open_ai_response = self.convert_vertex_response_to_open_ai_response(
            vertex_response
        )
        return open_ai_response

    async def pass_through_vertex_ai_POST_request(
        self,
        request_data: dict,
        vertex_project: str,
        vertex_location: str,
        vertex_credentials: str,
        request_route: str,
    ):
        """
        Forward an arbitrary supported POST request to the Vertex AI API.

        Builds the regional Vertex URL for the given ``request_route``,
        attaches auth headers, and returns the decoded JSON response.

        Raises:
            ValueError: for unsupported routes or an uninitialized handler.
            Exception: on any non-200 response.
        """
        _auth_header, vertex_project = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )
        auth_header, _ = self._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }

        # Regional v1 base path shared by most routes below.
        _v1_base = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}"

        url = None
        if request_route == "/tuningJobs":
            url = f"{_v1_base}/tuningJobs"
        elif "/tuningJobs/" in request_route and "cancel" in request_route:
            # NOTE(review): this appends the full route after "tuningJobs",
            # yielding ".../tuningJobs/tuningJobs/{id}:cancel" — looks like a
            # doubled path segment; preserved as-is, verify against callers.
            url = f"{_v1_base}/tuningJobs{request_route}"
        elif any(
            _marker in request_route
            for _marker in (
                "generateContent",
                "predict",
                "/batchPredictionJobs",
                "countTokens",
            )
        ):
            url = f"{_v1_base}{request_route}"
        elif "cachedContents" in request_route:
            # Context caching requires a fully-qualified model resource name;
            # expand bare model ids before forwarding.
            _model = request_data.get("model")
            if _model is not None and "/publishers/google/models/" not in _model:
                request_data["model"] = (
                    f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}"
                )

            # cachedContents lives under the v1beta1 API surface.
            url = f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
        else:
            raise ValueError(f"Unsupported Vertex AI request route: {request_route}")
        if self.async_handler is None:
            raise ValueError("VertexAI Fine Tuning - async_handler is not initialized")

        response = await self.async_handler.post(
            headers=headers,
            url=url,
            json=request_data,
        )

        if response.status_code != 200:
            raise Exception(
                f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
            )

        response_json = response.json()
        return response_json
|
|