File size: 8,380 Bytes
f641099
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44c2b25
bd87526
f641099
bd87526
 
 
 
 
f641099
 
 
44c2b25
bd87526
44c2b25
bd87526
 
44c2b25
bd87526
f641099
bd87526
 
 
 
 
 
 
 
44c2b25
 
f641099
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44c2b25
bd87526
f641099
bd87526
 
 
 
 
f641099
 
 
44c2b25
 
bd87526
 
 
f641099
44c2b25
bd87526
f641099
bd87526
 
 
 
 
 
 
 
44c2b25
 
f641099
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd87526
f641099
 
 
 
 
 
 
bd87526
f641099
bd87526
 
 
 
 
f641099
 
 
 
bd87526
 
 
f641099
bd87526
 
f641099
bd87526
f641099
bd87526
 
 
 
 
 
 
 
f641099
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from fastapi import APIRouter, Request, HTTPException, File, UploadFile, Body, Form, Query
from slowapi import Limiter
from slowapi.util import get_remote_address
from pydantic import BaseModel, field_validator
from models.gemma_llm import LLMManager
from utils.translate import perform_internal_translation
from config import settings, SUPPORTED_LANGUAGES
from logging_config import logger
from PIL import Image
import io

# Router that collects the model-management, chat and visual-query endpoints
# defined below.
router = APIRouter()
# Rate limiter keyed on the client's remote address (get_remote_address);
# applied to the chat endpoints via @limiter.limit(settings.chat_rate_limit).
limiter = Limiter(key_func=get_remote_address)
# Single LLMManager shared by every endpoint in this module, constructed at
# import time from settings.llm_model_name. Whether construction itself loads
# weights is not visible here — NOTE(review): confirm in models.gemma_llm.
llm_manager = LLMManager(settings.llm_model_name)

class ChatRequest(BaseModel):
    """Request body for ``POST /chat``.

    Attributes:
        prompt: User prompt. Surrounding whitespace is stripped, and the
            stripped text may be at most 1000 characters long.
        src_lang: Language code of the prompt; must be in SUPPORTED_LANGUAGES.
        tgt_lang: Language code requested for the reply; must be in
            SUPPORTED_LANGUAGES.
    """

    prompt: str
    src_lang: str = "kan_Knda"
    tgt_lang: str = "kan_Knda"

    @field_validator("prompt")
    @classmethod
    def prompt_must_be_valid(cls, v: str) -> str:
        """Strip surrounding whitespace and enforce the 1000-character limit.

        The limit is applied to the stripped text. The whitespace is
        discarded anyway, so it should not be able to push a prompt over the
        limit. An empty result is accepted here; the /chat endpoint rejects
        empty prompts with a 400.

        Raises:
            ValueError: if the stripped prompt exceeds 1000 characters.
        """
        stripped = v.strip()
        if len(stripped) > 1000:
            raise ValueError("Prompt cannot exceed 1000 characters")
        return stripped

    @field_validator("src_lang", "tgt_lang")
    @classmethod
    def validate_language(cls, v: str) -> str:
        """Reject language codes that are not in SUPPORTED_LANGUAGES.

        Raises:
            ValueError: if ``v`` is not a supported language code.
        """
        if v not in SUPPORTED_LANGUAGES:
            raise ValueError(f"Unsupported language code: {v}")
        return v

class ChatResponse(BaseModel):
    """Response body returned by the /chat and /chat_v2 endpoints."""

    # The (possibly translated) text produced for the user's prompt.
    response: str

@router.post("/unload_all_models")
async def unload_all_models():
    """Unload the shared LLM by calling ``llm_manager.unload()``.

    Returns:
        ``{"status": "success", "message": "All models unloaded"}`` on success.

    Raises:
        HTTPException(500): if unloading raises. The original error is logged
            with its traceback and chained as the cause.
    """
    try:
        logger.info("Starting to unload all models...")
        # NOTE(review): unload() is called synchronously on the event-loop
        # thread. If it is slow, consider starlette's run_in_threadpool.
        llm_manager.unload()
        logger.info("All models unloaded successfully")
        return {"status": "success", "message": "All models unloaded"}
    except Exception as e:
        # logger.exception (unlike logger.error) records the traceback.
        logger.exception(f"Error unloading models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}") from e

@router.post("/load_all_models")
async def load_all_models():
    """Load the shared LLM by calling ``llm_manager.load()``.

    Returns:
        ``{"status": "success", "message": "All models loaded"}`` on success.

    Raises:
        HTTPException(500): if loading raises. The original error is logged
            with its traceback and chained as the cause.
    """
    try:
        logger.info("Starting to load all models...")
        # NOTE(review): load() is called synchronously on the event-loop
        # thread. Model loading is likely slow and blocks all other requests
        # meanwhile; consider starlette's run_in_threadpool.
        llm_manager.load()
        logger.info("All models loaded successfully")
        return {"status": "success", "message": "All models loaded"}
    except Exception as e:
        # logger.exception (unlike logger.error) records the traceback.
        logger.exception(f"Error loading models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}") from e

@router.post("/chat", response_model=ChatResponse)
@limiter.limit(settings.chat_rate_limit)
async def chat(request: Request, chat_request: ChatRequest):
    """Generate an LLM reply for a text-only prompt.

    ``chat_request.src_lang`` and ``tgt_lang`` are validated by ``ChatRequest``
    but are currently NOT used. The prompt goes to the model unchanged, and
    the model output is returned untranslated. ``chat_v2`` is the variant
    that translates in and out of English.

    Args:
        request: Incoming request. slowapi's ``@limiter.limit`` requires it
            in the signature.
        chat_request: Validated request body.

    Returns:
        ChatResponse wrapping the generated text.

    Raises:
        HTTPException(400): if the (already stripped) prompt is empty.
        HTTPException(500): if generation fails.
    """
    if not chat_request.prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
    try:
        # Translation is intentionally bypassed: prompt in, raw model output out.
        response = await llm_manager.generate(chat_request.prompt, settings.max_tokens)
        logger.info(f"Generated English response: {response}")
        return ChatResponse(response=response)
    except Exception as e:
        # logger.exception (unlike logger.error) records the traceback.
        logger.exception(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e

@router.post("/visual_query/")
async def visual_query(
    file: UploadFile = File(...),
    query: str = Body(...),
    src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
    tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
):
    """Answer ``query`` about the uploaded image with the vision model.

    ``src_lang`` and ``tgt_lang`` appear in the API schema but are currently
    NOT used. The query is passed to ``llm_manager.vision_query`` unchanged,
    and its answer is returned untranslated.

    Returns:
        ``{"answer": <model output>}``.

    Raises:
        HTTPException(400): if the upload is empty, cannot be opened as an
            image, or reports a 0x0 size.
        HTTPException(500): if the vision query fails.
    """
    try:
        # Read through the async UploadFile API (as chat_v2 does) instead of
        # touching the underlying file synchronously on the event loop.
        image_data = await file.read()
        if not image_data:
            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
        try:
            image = Image.open(io.BytesIO(image_data))
        except OSError as e:
            # PIL.UnidentifiedImageError subclasses OSError: the bytes are
            # not a recognisable image, which is a client error, not a 500.
            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid") from e
        if image.size == (0, 0):
            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")

        # Translation is intentionally bypassed: query in, raw model answer out.
        answer = await llm_manager.vision_query(image, query)
        logger.info(f"Generated English answer: {answer}")
        return {"answer": answer}
    except HTTPException:
        # HTTPException subclasses Exception. Re-raise it so deliberate 4xx
        # responses are not converted into 500s by the handler below.
        raise
    except Exception as e:
        logger.exception(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e

@router.post("/chat_v2", response_model=ChatResponse)
@limiter.limit(settings.chat_rate_limit)
async def chat_v2(
    request: Request,
    prompt: str = Form(...),
    image: UploadFile = File(default=None),
    src_lang: str = Form("kan_Knda"),
    tgt_lang: str = Form("kan_Knda"),
):
    """Chat with an optional image, translating in and out of English.

    The endpoint works in three steps:

    1. Translate the prompt from ``src_lang`` to English, unless ``src_lang``
       is already ``eng_Latn``.
    2. Answer it with ``llm_manager.chat_v2`` if an image was uploaded, or
       with ``llm_manager.generate`` otherwise.
    3. Translate the answer from English to ``tgt_lang``, unless ``tgt_lang``
       is ``eng_Latn``.

    Args:
        request: Incoming request. slowapi's ``@limiter.limit`` requires it
            in the signature.
        prompt: Form field containing the user prompt.
        image: Optional uploaded image.
        src_lang: Language code of ``prompt``.
        tgt_lang: Language code for the returned response.

    Returns:
        ChatResponse wrapping the (translated) model output.

    Raises:
        HTTPException(400): if the prompt is empty, a language code is not
            supported, or the uploaded image is empty or cannot be opened.
        HTTPException(500): if translation or generation fails for any other
            reason.
    """
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    for lang_code in (src_lang, tgt_lang):
        if lang_code not in SUPPORTED_LANGUAGES:
            raise HTTPException(status_code=400, detail=f"Unsupported language code: {lang_code}")

    logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")

    try:
        # Step 1: Handle image if provided
        img = None
        if image is not None:
            image_data = await image.read()
            if not image_data:
                raise HTTPException(status_code=400, detail="Uploaded image is empty")
            try:
                img = Image.open(io.BytesIO(image_data))
            except OSError as e:
                # PIL.UnidentifiedImageError subclasses OSError: the bytes are
                # not a recognisable image, which is a client error.
                raise HTTPException(status_code=400, detail="Uploaded image is invalid") from e

        # Step 2: Translate prompt to English if needed
        if src_lang != "eng_Latn":
            translated_prompt = await perform_internal_translation(
                [prompt], src_lang, "eng_Latn"
            )
            prompt_to_process = translated_prompt[0]
            logger.info(f"Translated prompt to English: {prompt_to_process}")
        else:
            prompt_to_process = prompt
            logger.info("Prompt already in English, no translation needed")

        # Step 3: Generate response in English
        if img is not None:
            response = await llm_manager.chat_v2(img, prompt_to_process)
        else:
            response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
        logger.info(f"Generated English response: {response}")

        # Step 4: Translate response to target language if needed
        if tgt_lang != "eng_Latn":
            translated_response = await perform_internal_translation(
                [response], "eng_Latn", tgt_lang
            )
            final_response = translated_response[0]
            logger.info(f"Translated response to {tgt_lang}: {final_response}")
        else:
            final_response = response
            logger.info("Response kept in English, no translation needed")

        return ChatResponse(response=final_response)
    except HTTPException:
        # HTTPException subclasses Exception. Re-raise it so deliberate 4xx
        # responses are not converted into 500s by the handler below.
        raise
    except Exception as e:
        logger.exception(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e