# Chatbot_1 / app.py

import base64
import io  # used to encode uploaded PIL images as base64

import torch
import gradio as gr
from deep_translator import GoogleTranslator
from peft import PeftModel
from PIL import Image
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

model_id = "HuggingFaceH4/vsft-llava-1.5-7b-hf-trl"

# Load the base LLaVA model in 4-bit precision (bitsandbytes) to keep GPU memory usage low.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
base_model = LlavaForConditionalGeneration.from_pretrained(
    model_id, quantization_config=quantization_config, torch_dtype=torch.float16
)

# Load the fine-tuned PEFT LoRA adapter on top of the base model.
peft_lora_adapter_path = "Praveen0309/llava-1.5-7b-hf-ft-mix-vsft-3"
peft_lora_adapter = PeftModel.from_pretrained(base_model, peft_lora_adapter_path, adapter_name="lora_adapter")
base_model.load_adapter(peft_lora_adapter_path, adapter_name="lora_adapter")

processor = AutoProcessor.from_pretrained(model_id)
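
# Note: loading in 4-bit via bitsandbytes requires a CUDA GPU; the inputs therefore need
# to be moved onto the model's device before calling generate() (see inference() below).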


# Function to translate text from Bengali to English
def deep_translator_bn_en(input_sentence):
    english_translation = GoogleTranslator(source="bn", target="en").translate(input_sentence)
    return english_translation


# Function to translate text from English to Bengali
def deep_translator_en_bn(input_sentence):
    bengali_translation = GoogleTranslator(source="en", target="bn").translate(input_sentence)
    return bengali_translation
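
# Comment-only example (GoogleTranslator needs network access; outputs are approximate):
#   deep_translator_bn_en("আপনি কেমন আছেন?")   # -> roughly "How are you?"
#   deep_translator_en_bn("How are you?")       # -> the Bengali rendering of the question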


def inference(image, image_prompt):
    # LLaVA-style chat prompt: the <image> token marks where the image features are injected.
    prompt = f"USER: <image>\n{image_prompt} ASSISTANT:"

    # The model works on PIL images; make sure the image is in RGB mode.
    image = image.convert("RGB")

    # Move the inputs onto the model's device and match its float16 dtype.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(base_model.device, torch.float16)

    generate_ids = base_model.generate(**inputs, max_new_tokens=15)
    decoded_response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return decoded_response
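
# Quick comment-only smoke test ("sample.jpg" is a hypothetical local file):
#   img = Image.open("sample.jpg")
#   print(inference(img, "What is shown in this image?"))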


def image_to_base64(image):
    # gr.Image(type="pil") passes a PIL image (not a file path); encode it as a base64 JPEG.
    buffer = io.BytesIO()
    image.convert("RGB").save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Function that takes the user inputs and displays them on the chat UI
def query_message(history, txt, img):
    image_prompt = deep_translator_bn_en(txt)
    # If no image was uploaded, show only the translated text.
    if img is None:
        history += [(image_prompt, None)]
        return history
    encoded_image = image_to_base64(img)
    data_url = f"data:image/jpeg;base64,{encoded_image}"
    history += [(f"{image_prompt} ![]({data_url})", None)]
    return history


# Function that takes the user inputs, generates a response, and displays it on the chat UI
def llm_response(history, text, img):
    if img is None:  # nothing to describe without an image
        return history
    image_prompt = deep_translator_bn_en(text)
    response = inference(img, image_prompt)
    # The decoded response contains the full prompt; keep only the part after "ASSISTANT:".
    assistant_index = response.find("ASSISTANT:")
    extracted_string = response[assistant_index + len("ASSISTANT:"):].strip()
    output = deep_translator_en_bn(extracted_string)
    history += [(text, output)]
    return history


# Interface code
with gr.Blocks() as app:
    with gr.Row():
        image_box = gr.Image(type="pil")
        chatbot = gr.Chatbot(
            scale=2,
            height=500,
        )
    text_box = gr.Textbox(
        placeholder="Enter text and press enter, or upload an image",
        container=False,
    )
    btn = gr.Button("Submit")
    # First show the user's message (and image) in the chat, then append the model's reply.
    clicked = btn.click(
        query_message,
        [chatbot, text_box, image_box],
        chatbot,
    ).then(
        llm_response,
        [chatbot, text_box, image_box],
        chatbot,
    )

app.queue()
app.launch(debug=True,share=True)
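
# To run locally (assuming the dependencies above are installed and a CUDA GPU is available):
#   python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default; share=True additionally prints
# a temporary public link.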