Spaces:
Paused
Paused
import os | |
import json | |
from typing import Any, Tuple | |
from deepeval.models.base_model import DeepEvalBaseLLM | |
from src.evaluation.writer.agent_write import create_workflow_sync, create_workflow_async | |
from src.utils.api_key_manager import with_api_manager | |
from src.helpers.helper import remove_markdown | |
from dotenv import load_dotenv | |
class LangChainWrapper(DeepEvalBaseLLM): | |
def __init__(self): | |
# Load environment variables from .env file | |
load_dotenv() | |
# Initialize model name from environment variable | |
self.model_name = os.getenv("MODEL_NAME") | |
# Method to invoke the LLM synchronously | |
def _invoke_llm_sync(self, prompt: Any) -> Tuple[str, float]: | |
def _inner_invoke_sync(*args, **kwargs): | |
response = kwargs['llm'].invoke(prompt) | |
raw_text = response.content.strip() | |
return raw_text | |
raw_text = _inner_invoke_sync() | |
return raw_text | |
# Method to invoke the LLM asynchronously | |
async def _invoke_llm_async(self, prompt: Any) -> Tuple[str, float]: | |
async def _inner_invoke_async(*args, **kwargs): | |
response = await kwargs['llm'].ainvoke(prompt) | |
raw_text = response.content.strip() | |
return raw_text | |
raw_text = await _inner_invoke_async() | |
return raw_text | |
# Method to parse text as a schema | |
def _parse_as_schema(self, raw_text: str, schema: Any) -> Any: | |
cleaned_text = remove_markdown(raw_text) | |
data = json.loads(cleaned_text) | |
# Try to parse data as schema | |
try: | |
return schema(**data) | |
except Exception: | |
print(f"Failed to parse data for schema: {schema}") | |
raise | |
# Method to generate text synchronously | |
def generate(self, prompt: Any, schema: Any = None) -> str: | |
raw_text = self._invoke_llm_sync(prompt) | |
if schema is not None: | |
try: | |
parsed_obj = self._parse_as_schema(raw_text, schema) | |
return parsed_obj | |
except json.JSONDecodeError as e: | |
print(f"Failed to parse JSON data: {e}\nUsing LangGraph fallback...") | |
input = { | |
"initial_prompt": prompt, | |
"plan": "", | |
"write_steps": [], | |
"final_json": "" | |
} | |
app = create_workflow_sync() | |
final_state = app.invoke(input) | |
output = remove_markdown(final_state['final_json']) | |
try: | |
data = json.loads(output) | |
return data | |
except json.JSONDecodeError as e: | |
raise Exception(f"Cannot parse JSON data: {e}") | |
else: | |
return raw_text | |
# Method to generate text asynchronously | |
async def a_generate(self, prompt: Any, schema: Any = None) -> str: | |
raw_text = await self._invoke_llm_async(prompt) | |
if schema is not None: | |
try: | |
parsed_obj = self._parse_as_schema(raw_text, schema) | |
return parsed_obj | |
except json.JSONDecodeError as e: | |
print(f"Failed to parse JSON data: {e}\nUsing LangGraph fallback...") | |
input = { | |
"initial_prompt": prompt, | |
"plan": "", | |
"write_steps": [], | |
"final_json": "" | |
} | |
app = create_workflow_async() | |
final_state = await app.ainvoke(input) | |
output = remove_markdown(final_state['final_json']) | |
try: | |
data = json.loads(output) | |
return data | |
except json.JSONDecodeError as e: | |
raise Exception(f"Cannot parse JSON data: {e}") | |
else: | |
return raw_text | |
# Method to get the model name | |
def get_model_name(self) -> str: | |
return f"LangChainWrapper for {self.model_name}" | |
# Method to load the model | |
def load_model(self, *, llm: Any): | |
def inner_load_model(*args, **kwargs): | |
return llm | |
return inner_load_model() |