import os
import json
from typing import Any
from dotenv import load_dotenv
from deepeval.models.base_model import DeepEvalBaseLLM
from src.evaluation.writer.agent_write import create_workflow_sync, create_workflow_async
from src.utils.api_key_manager import with_api_manager
from src.helpers.helper import remove_markdown

class LangChainWrapper(DeepEvalBaseLLM):
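    """DeepEval-compatible wrapper around a LangChain chat model.

    All generation calls are routed through the with_api_manager decorator,
    which injects a configured client. When a schema is requested and the raw
    output is not valid JSON, a LangGraph writing workflow is used as a
    fallback to regenerate parseable JSON.
    """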
    def __init__(self):
        # Load environment variables from .env file
        load_dotenv()

        # Initialize model name from environment variable
        self.model_name = os.getenv("MODEL_NAME")

    # Method to invoke the LLM synchronously; with_api_manager injects a
    # configured client as kwargs['llm']
    def _invoke_llm_sync(self, prompt: Any) -> str:
        @with_api_manager(temperature=0.0, top_p=1.0)
        def _inner_invoke_sync(*args, **kwargs):
            response = kwargs['llm'].invoke(prompt)
            raw_text = response.content.strip()
            return raw_text

        raw_text = _inner_invoke_sync()
        return raw_text

    # Method to invoke the LLM asynchronously; with_api_manager injects a
    # configured client as kwargs['llm']
    async def _invoke_llm_async(self, prompt: Any) -> str:
        @with_api_manager(temperature=0.0, top_p=1.0)
        async def _inner_invoke_async(*args, **kwargs):
            response = await kwargs['llm'].ainvoke(prompt)
            raw_text = response.content.strip()
            return raw_text

        raw_text = await _inner_invoke_async()
        return raw_text

    # Method to parse text as a schema
    def _parse_as_schema(self, raw_text: str, schema: Any) -> Any:
        cleaned_text = remove_markdown(raw_text)
        data = json.loads(cleaned_text)

        # Try to parse data as schema
        try:
            return schema(**data)
        except Exception:
            print(f"Failed to parse data for schema: {schema}")
            raise

    # Method to generate text synchronously
    def generate(self, prompt: Any, schema: Any = None) -> Any:
        raw_text = self._invoke_llm_sync(prompt)

        if schema is not None:
            try:
                parsed_obj = self._parse_as_schema(raw_text, schema)
                return parsed_obj
            except json.JSONDecodeError as e:
                print(f"Failed to parse JSON data: {e}\nUsing LangGraph fallback...")

                # Seed the LangGraph workflow state and regenerate the output
                initial_state = {
                    "initial_prompt": prompt,
                    "plan": "",
                    "write_steps": [],
                    "final_json": ""
                }
                app = create_workflow_sync()
                final_state = app.invoke(initial_state)
                output = remove_markdown(final_state['final_json'])

                try:
                    data = json.loads(output)
                    return data
                except json.JSONDecodeError as e:
                    raise ValueError(f"Cannot parse JSON data: {e}") from e
        else:
            return raw_text

    # Method to generate text asynchronously
    async def a_generate(self, prompt: Any, schema: Any = None) -> Any:
        raw_text = await self._invoke_llm_async(prompt)

        if schema is not None:
            try:
                parsed_obj = self._parse_as_schema(raw_text, schema)
                return parsed_obj
            except json.JSONDecodeError as e:
                print(f"Failed to parse JSON data: {e}\nUsing LangGraph fallback...")
                
                # Seed the LangGraph workflow state and regenerate the output
                initial_state = {
                    "initial_prompt": prompt,
                    "plan": "",
                    "write_steps": [],
                    "final_json": ""
                }
                app = create_workflow_async()
                final_state = await app.ainvoke(initial_state)
                output = remove_markdown(final_state['final_json'])

                try:
                    data = json.loads(output)
                    return data
                except json.JSONDecodeError as e:
                    raise ValueError(f"Cannot parse JSON data: {e}") from e
        else:
            return raw_text
        
    # Method to get the model name
    def get_model_name(self) -> str:
        return f"LangChainWrapper for {self.model_name}"
    
    # Method to load the model; returns the explicitly provided llm, or the
    # client injected by the API key manager when none is given
    def load_model(self, *, llm: Any = None):
        @with_api_manager(temperature=0.0, top_p=1.0)
        def inner_load_model(*args, **kwargs):
            return llm if llm is not None else kwargs['llm']

        return inner_load_model()
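
# A minimal usage sketch (not part of the original module). It assumes
# MODEL_NAME is set in .env and that with_api_manager can supply a client;
# AnswerSchema is a hypothetical pydantic model used only for illustration.
if __name__ == "__main__":
    from pydantic import BaseModel

    class AnswerSchema(BaseModel):
        answer: str

    wrapper = LangChainWrapper()

    # Plain-text generation returns the raw model output as a string
    print(wrapper.generate("Name the capital of France."))

    # Schema-constrained generation returns an AnswerSchema instance, or a
    # dict if the LangGraph fallback had to regenerate the JSON
    result = wrapper.generate(
        'Respond with JSON: {"answer": "<capital of France>"}',
        schema=AnswerSchema,
    )
    print(result)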