santanus24 committed on
Commit 010c325 · verified · 1 Parent(s): 7d1da2f

Upload app.py

Files changed (1)
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
+ import torch
+ from transformers import LlavaForConditionalGeneration, BitsAndBytesConfig, AutoProcessor
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
+ import requests
+ from PIL import Image
+ import gradio as gr
+
+
+ # Load translation model and tokenizer
+ translate_model_name = "facebook/mbart-large-50-many-to-many-mmt"
+ translate_model = MBartForConditionalGeneration.from_pretrained(translate_model_name)
+ tokenizer = MBart50TokenizerFast.from_pretrained(translate_model_name)
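+ # NOTE: mBART-50 identifies languages by codes such as "hi_IN" (Hindi) and
+ # "en_XX" (English); translate() below relies on these codes.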
+
+ # 4-bit quantization config (requires a CUDA GPU and the bitsandbytes library)
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+ )
+
+ # fine-tuned model with the adapter merged in (Hugging Face Hub)
+ model_id = 'somnathsingh31/llava-1.5-7b-hf-ft-merged_model'
+
+ # load the fine-tuned merged model in 4-bit
+ merged_model = LlavaForConditionalGeneration.from_pretrained(
+     model_id,
+     quantization_config=quantization_config,
+     torch_dtype=torch.float16,
+ )
+
+ # create processor from the base model
+ processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
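+ # NOTE: the processor bundles the tokenizer and the image processor, so it
+ # can build joint text+image inputs for the model.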
+
+ # function to translate text with mBART-50
+ def translate(text, source_lang, target_lang):
+     # Set source language
+     tokenizer.src_lang = source_lang
+
+     # Encode the text
+     encoded_text = tokenizer(text, return_tensors="pt")
+
+     # Force the first generated token to be the target-language code
+     forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]
+
+     # Generate the translation
+     generated_tokens = translate_model.generate(**encoded_text, forced_bos_token_id=forced_bos_token_id)
+
+     # Decode the translation
+     translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+
+     return translation
+
+
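+ # Illustrative usage (exact wording may vary by checkpoint):
+ #   translate("नमस्ते दुनिया", "hi_IN", "en_XX")  ->  e.g. "Hello world"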
+ # function for making inference with the vision-language model
+ def ask_vlm(hindi_input_text, image):
+     # translate the question from Hindi to English
+     prompt_eng = translate(hindi_input_text, "hi_IN", "en_XX")
+     prompt = "USER: <image>\n" + prompt_eng + " ASSISTANT:"
+
+     # Accept a PIL image (from Gradio), a URL string, or a file-like object
+     if isinstance(image, str):
+         image = Image.open(requests.get(image, stream=True).raw)
+     elif not isinstance(image, Image.Image):
+         image = Image.open(image)
+
+     # move inputs to the model's device and cast floats to fp16 to match it
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to(merged_model.device, torch.float16)
+     generate_ids = merged_model.generate(**inputs, max_new_tokens=250)
+     decoded_response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     assistant_index = decoded_response.find("ASSISTANT:")
+
+     # Extract the text after "ASSISTANT:", stripping surrounding whitespace
+     if assistant_index != -1:
+         text_after_assistant = decoded_response[assistant_index + len("ASSISTANT:"):].strip()
+     else:
+         # marker not found; return an empty answer rather than translating None
+         return ""
+
+     # translate the answer back from English to Hindi
+     hindi_output_text = translate(text_after_assistant, "en_XX", "hi_IN")
+     return hindi_output_text
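+ # NOTE: "USER: <image>\n ... ASSISTANT:" is the llava-1.5 prompt template;
+ # the processor expands the <image> placeholder into image tokens.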
+
+ # Define Gradio interface components (gr.inputs/gr.outputs were removed in Gradio 4.x)
+ input_image = gr.Image(type="pil", label="Input Image (Upload or URL)")
+ input_question = gr.Textbox(lines=2, label="Question (Hindi)")
+ output_text = gr.Textbox(label="Response (Hindi)")
+
+ # Create Gradio app (launched in the __main__ block below)
+ demo = gr.Interface(
+     fn=ask_vlm,
+     inputs=[input_question, input_image],
+     outputs=output_text,
+     title="Image and Text-based Dialogue System",
+     description="Enter a question in Hindi and an image, either by uploading or providing a URL, and get a response in Hindi.",
+ )
+
+
+ if __name__ == '__main__':
+     # quick smoke test: "What kind of art is this? Explain in detail."
+     image_url = 'https://images.metmuseum.org/CRDImages/ad/original/138425.jpg'
+     user_query = 'यह किस प्रकार की कला है? विस्तार से बताइये'
+     output = ask_vlm(user_query, image_url)
+     print('Output:\n', output)
+
+     # start the Gradio UI (launch() blocks, so it runs after the smoke test)
+     demo.launch()