|
import os |
|
import gradio as gr |
|
from argparse import ArgumentParser |
|
from groq import Groq |
|
import base64 |
|
import io |
|
|
|
|
|
# Default value of the --revision command-line option.
REVISION = "v1.0.4"

# The Groq key is read from the environment; a missing GROQ_API_KEY raises
# KeyError as soon as this module is imported.
API_KEY = os.environ["GROQ_API_KEY"]

# Single shared client used by every request handler below.
client = Groq(api_key=API_KEY)
|
|
|
def _get_args():
    """Parse the command-line options of the demo.

    Returns:
        argparse.Namespace: ``revision`` (str, defaults to ``REVISION``) and
        ``share`` (bool, ``True`` only when ``--share`` is given).
    """
    cli = ArgumentParser()
    cli.add_argument("--revision", default=REVISION, type=str)
    cli.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Create a publicly shareable link for the interface.",
    )
    options = cli.parse_args()
    return options
|
|
|
def process_image(image):
    """Encode an image as JPEG and return the encoded bytes.

    JPEG cannot store an alpha channel or a colour palette, so saving an
    image in a mode such as ``RGBA``, ``LA`` or ``P`` raises ``OSError``.
    Such images are converted to ``RGB`` first; images already in a mode the
    JPEG writer accepts are saved untouched.

    Args:
        image: A PIL-style image exposing ``mode``, ``convert`` and ``save``.

    Returns:
        bytes: The JPEG-encoded image data.
    """
    # Modes Pillow's JPEG writer can save directly.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        # Drops transparency / expands the palette so the save cannot fail.
        image = image.convert("RGB")

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return buffered.getvalue()
|
|
|
def translate_audio(audio_file):
    """Send an audio file to Groq's audio translation endpoint.

    Args:
        audio_file: Path of the audio file to upload.

    Returns:
        The ``text`` attribute of the translation response.
    """
    with open(audio_file, "rb") as handle:
        # The API takes a (filename, content) pair.
        payload = (audio_file, handle.read())
        result = client.audio.translations.create(
            file=payload,
            model="whisper-large-v3",
            response_format="json",
            temperature=0.0,
        )
    return result.text
|
|
|
def transcribe_audio(audio_file):
    """Send an audio file to Groq's audio transcription endpoint.

    Args:
        audio_file: Path of the audio file to upload.

    Returns:
        The ``text`` attribute of the transcription response.
    """
    with open(audio_file, "rb") as handle:
        # The API takes a (filename, content) pair.
        payload = (audio_file, handle.read())
        result = client.audio.transcriptions.create(
            file=payload,
            model="whisper-large-v3",
            response_format="json",
            temperature=0.0,
        )
    return result.text
|
|
|
def predict(chat_history, query, image, audio, translate):
    """Handle one submission from the UI and return the updated chat history.

    The final query is built from the typed text and, when audio is given,
    from its transcription (appended to the typed text) or its translation
    (which replaces the typed text and is also echoed into the history as an
    assistant message). The query and the optional image are sent to the
    vision model, and the user/assistant exchange is appended to the history.

    Any failure while processing audio, encoding the image or calling the
    model is reported in the chat as an ``Error: ...`` assistant message
    instead of escaping the handler.

    Args:
        chat_history: List of ``{'role', 'content'}`` dicts; extended in place.
        query: Text typed by the user (may be empty or ``None``).
        image: Uploaded image, or ``None``.
        audio: Path of the uploaded audio file, or a falsy value.
        translate: When true, audio is translated rather than transcribed.

    Returns:
        The chat history including the new messages.
    """
    if chat_history is None:
        chat_history = []
    final_query = (query or "").strip()

    # One boundary for every step that can fail, so all failures are reported
    # the same way (as an assistant message).
    try:
        if audio:
            if translate:
                translation_text = translate_audio(audio)
                final_query = translation_text.strip()
                # Show the translation itself before the model's answer.
                chat_history.append({'role': 'assistant', 'content': translation_text})
            else:
                transcribed_text = transcribe_audio(audio)
                final_query = f"{final_query} {transcribed_text}".strip()

        # Explicit None check: presence, not truthiness, decides whether an
        # image was uploaded.
        image_data = process_image(image) if image is not None else None
        messages = create_messages(final_query, image_data)

        if not messages:
            error_message = "No valid input provided. Please enter a query or upload an image/audio."
            chat_history.append({'role': 'assistant', 'content': error_message})
            return chat_history

        completion = client.chat.completions.create(
            model="llama-3.2-90b-vision-preview",
            messages=messages,
            temperature=1,
            max_tokens=1500,
            top_p=1,
            stream=False,
        )
        # Guard against a response whose content is None.
        response_text = (completion.choices[0].message.content or "").strip()
    except Exception as e:
        response_text = f"Error: {str(e)}"

    chat_history.append({'role': 'user', 'content': final_query})
    chat_history.append({'role': 'assistant', 'content': response_text})
    return chat_history
|
|
|
def create_messages(query, image_data, image_prompt="Please analyze this image."):
    """Build the ``messages`` list for a chat-completion request.

    Args:
        query: User text. When empty, no text-only message is emitted.
        image_data: Encoded image bytes (embedded as a JPEG data URL), or a
            falsy value when there is no image.
        image_prompt: Instruction text sent in the same message as the image.
            Defaults to the instruction that was previously hard-coded.

    Returns:
        list: Zero, one or two ``user`` messages - the text query first, then
        the image message. Empty when neither input is provided.
    """
    messages = []

    if query:
        messages.append({'role': 'user', 'content': query})

    if image_data:
        encoded = base64.b64encode(image_data).decode("ascii")
        image_base64 = f"data:image/jpeg;base64,{encoded}"
        messages.append({
            'role': 'user',
            'content': [
                {"type": "text", "text": image_prompt},
                {"type": "image_url", "image_url": {"url": image_base64}},
            ],
        })

    return messages
|
|
|
def clear_history():
    """Return a new empty list, used to reset the chat display."""
    emptied = list()
    return emptied
|
|
|
def main():
    """Build the Gradio chat interface, wire its buttons and launch it."""
    args = _get_args()

    page_css = "#chatbox {height: 400px; background-color: #f9f9f9; padding: 20px; border-radius: 10px; }"
    title_html = "<h1 style='text-align: center; color: #4a4a4a;'>Llama-3.2-90b-vision-preview</h1>"

    with gr.Blocks(css=page_css) as demo:
        gr.Markdown(title_html)

        # Components are created in display order.
        chatbox = gr.Chatbot(type='messages', elem_id="chatbox")
        query = gr.Textbox(
            label="Type your query here...",
            placeholder="Enter your question or command...",
            lines=2,
        )
        image_input = gr.Image(type="pil", label="Upload Image")
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        translate_checkbox = gr.Checkbox(label="Translate Audio to English Text")

        with gr.Row():
            submit_btn = gr.Button("Submit", variant="primary", elem_id="submit-btn")
            clear_btn = gr.Button("Clear History", variant="secondary", elem_id="clear-btn")

        # Input order matches the positional parameters of predict().
        predict_inputs = [chatbox, query, image_input, audio_input, translate_checkbox]
        submit_btn.click(predict, inputs=predict_inputs, outputs=chatbox)
        clear_btn.click(clear_history, outputs=chatbox)

    demo.launch(share=args.share)
|
|
|
# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()