File size: 4,061 Bytes
449ce0a
e70ffc1
449ce0a
e70ffc1
 
1de075a
449ce0a
e70ffc1
 
 
d32424b
12d3e1a
1de075a
99fb68e
 
 
 
1de075a
 
12d3e1a
 
 
 
 
d32424b
 
e70ffc1
dc376b6
e70ffc1
 
 
dc376b6
e70ffc1
 
e5550d4
e70ffc1
 
dc376b6
e5550d4
e70ffc1
 
 
 
1de075a
 
 
 
 
e2aa91f
e5550d4
1de075a
e2aa91f
1de075a
e5550d4
1de075a
 
e2aa91f
 
 
 
1de075a
e2aa91f
 
 
1de075a
e2aa91f
 
 
 
 
 
 
 
 
 
 
449ce0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from typing import List, Literal, cast
from pydantic import SecretStr
from _utils.google_integration.google_cloud import GCP_PROJECT, upload_to_gcs
from setup.easy_imports import ChatOpenAI, ChatGoogleGenerativeAI
import os
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

# Credentials are read from the process environment once, at import time.
# os.getenv yields None for an unset variable; the cast only informs the type
# checker and performs no runtime validation.
# NOTE(review): "DEEPSEEKK_API_KEY" is spelled with a double K - confirm this
# matches the variable name actually set in the deployment environment.
deepseek_api_key = cast(str, os.getenv("DEEPSEEKK_API_KEY"))
google_api_key = cast(str, os.getenv("GOOGLE_API_KEY_PEIXE"))
open_ai_token = cast(str, os.getenv("OPENAI_API_KEY"))

# Gemini model identifiers accepted by the Gemini helpers defined in this module.
Google_llms = Literal[
    "gemini-2.5-pro-preview-05-06",
    "gemini-2.0-flash",
    "gemini-2.0-flash-lite",
    "gemini-2.5-flash-preview-04-17",
]


class LLM:
    """Factory for chat-model clients, plus retrying helpers for Gemini calls.

    The client factories use the API keys loaded at module level and build a
    new client object on every call; the class keeps no state of its own.
    """

    @staticmethod
    def _flatten_content(parts: list) -> str:
        """Join the textual pieces of a list-shaped message content.

        Plain strings are used as they are, and dict blocks contribute their
        ``"text"`` value when it is a string; any other entry is skipped.
        Joining the raw list directly would raise ``TypeError`` as soon as it
        contained a dict block.

        Args:
            parts: The list found in a message's ``content`` attribute.

        Returns:
            The collected text pieces joined with newlines ("" for no text).
        """
        texts = []
        for part in parts:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return "\n".join(texts)

    def open_ai(self, model="gpt-4o-mini"):
        """Return a ``ChatOpenAI`` client for ``model`` using the OpenAI key."""
        return ChatOpenAI(api_key=SecretStr(open_ai_token), model=model)

    def deepseek(self, model="deepseek-chat"):
        """Return a ``ChatOpenAI`` client pointed at the DeepSeek endpoint."""
        return ChatOpenAI(
            api_key=SecretStr(deepseek_api_key),
            base_url="https://api.deepseek.com/v1",
            model=model,
        )

    def google_gemini(self, model: Google_llms = "gemini-2.0-flash", temperature=0.4):
        """Return a ``ChatGoogleGenerativeAI`` client for ``model``.

        The client is created with the given ``temperature``, no token limit,
        no timeout and ``max_retries=2``.
        """
        return ChatGoogleGenerativeAI(
            api_key=SecretStr(google_api_key),
            model=model,
            temperature=temperature,
            max_tokens=None,
            timeout=None,
            max_retries=2,
        )

    async def google_gemini_ainvoke(
        self,
        prompt: str,
        model: Google_llms = "gemini-2.0-flash",
        max_retries: int = 3,
        temperature=0.4,
    ):
        """Send ``prompt`` to Gemini, retrying and finally falling back to OpenAI.

        Args:
            prompt: Text sent as a single human message.
            model: Gemini model used for the first attempt.
            max_retries: Number of Gemini attempts before the OpenAI fallback.
            temperature: Sampling temperature passed to the Gemini client.

        Returns:
            The response message. When a Gemini response carries list-shaped
            content, it is collapsed into one newline-joined string first.

        Raises:
            Exception: If every Gemini attempt and the OpenAI fallback fail;
                the fallback's error is chained as the cause.
        """
        for attempt in range(max_retries):
            try:
                response = await self.google_gemini(model, temperature).ainvoke(
                    [HumanMessage(content=prompt)]
                )

                if isinstance(response.content, list):
                    response.content = self._flatten_content(response.content)

                return response
            except Exception as e:
                # Any failure switches the remaining attempts to the flash model.
                model = "gemini-2.0-flash"
                print(f"Attempt {attempt + 1} failed with error: {e}")

        # Last resort: one OpenAI call. "gpt-4o-mini" is the same model id that
        # open_ai() uses as its default.
        try:
            print("Final attempt with fallback model...")
            return await self.open_ai("gpt-4o-mini").ainvoke(
                [HumanMessage(content=prompt)]
            )
        except Exception as e:
            raise Exception(
                f"Failed to generate the final document after {max_retries} retries "
                "and the fallback attempt with gpt-4o-mini."
            ) from e

    async def google_gemini_vertex_ainvoke(
        self,
        prompt: str,
        list_of_pdfs: List[str],
        model: Google_llms = "gemini-2.5-flash-preview-04-17",
        max_retries: int = 3,
    ) -> str | None:
        """Ask a Vertex AI Gemini model about ``prompt`` plus attached PDFs.

        Each entry of ``list_of_pdfs`` is passed to ``upload_to_gcs`` and the
        value it returns is attached to the message as a PDF file URI. The
        uploads happen once, before the retry loop, so an upload error
        propagates to the caller instead of being retried.

        Args:
            prompt: Text part of the message.
            list_of_pdfs: Items handed to ``upload_to_gcs`` one by one.
            model: Gemini model used for the first attempt.
            max_retries: Maximum number of model invocations.

        Returns:
            The response content (list-shaped content is collapsed into one
            newline-joined string), or ``None`` when every attempt failed.
        """
        message_parts = [
            {"type": "text", "text": prompt},
        ]
        for pdf in list_of_pdfs:
            pdf_gcs_uri = upload_to_gcs(pdf)
            message_parts.append(
                {
                    # File reference by URI: mime_type and file_uri sit at the
                    # top level of the part.
                    "type": "media",
                    "mime_type": "application/pdf",
                    "file_uri": pdf_gcs_uri,
                }
            )

        for attempt in range(max_retries):
            try:
                # Built inside the loop because ``model`` may change after a failure.
                llm = ChatVertexAI(
                    model_name=model,
                    project=GCP_PROJECT,
                    location="us-central1",
                    temperature=0,
                )
                response = await llm.ainvoke(
                    [HumanMessage(content=message_parts)]  # type: ignore
                )

                content = response.content
                if isinstance(content, list):
                    return self._flatten_content(content)

                return content  # type: ignore
            except Exception as e:
                # Any failure switches the remaining attempts to the flash model.
                model = "gemini-2.0-flash"
                print(f"Attempt {attempt + 1} failed with error: {e}")

        # Every attempt failed; the declared return type allows None here.
        return None