import gradio as gr
import json
import requests
import urllib.request
import os
import ssl
import base64
from PIL import Image
import soundfile as sf
import mimetypes
import logging
from io import BytesIO
import tempfile

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Azure ML endpoint configuration
url = os.getenv("AZURE_ENDPOINT")
api_key = os.getenv("AZURE_API_KEY")

# Initialize MIME types
mimetypes.init()
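# Configuration sketch (assumption, not from this file): AZURE_ENDPOINT is read as
# the full scoring URL of an Azure ML online endpoint and AZURE_API_KEY as its
# endpoint key. A typical (hypothetical) setup might look like:
#
#   export AZURE_ENDPOINT="https://<endpoint-name>.<region>.inference.ml.azure.com/score"
#   export AZURE_API_KEY="<endpoint-key>"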
def call_aml_endpoint(payload, url, api_key):
    """Call Azure ML endpoint with the given payload."""
    # Allow self-signed HTTPS certificates
    def allow_self_signed_https(allowed):
        if allowed and not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None):
            ssl._create_default_https_context = ssl._create_unverified_context

    allow_self_signed_https(True)

    # Set parameters (can be adjusted based on your needs)
    parameters = {"temperature": 0.7}
    if "parameters" not in payload["input_data"]:
        payload["input_data"]["parameters"] = parameters

    # Encode the request body
    body = str.encode(json.dumps(payload))

    if not api_key:
        raise Exception("A key should be provided to invoke the endpoint")

    # Set up headers
    headers = {'Content-Type': 'application/json', 'Authorization': ('Bearer ' + api_key)}

    # Create and send the request
    req = urllib.request.Request(url, body, headers)
    try:
        logger.info(f"Sending request to {url}")
        response = urllib.request.urlopen(req)
        result = response.read().decode('utf-8')
        logger.info("Received response successfully")
        return json.loads(result)
    except urllib.error.HTTPError as error:
        logger.error(f"Request failed with status code: {error.code}")
        logger.error(f"Headers: {error.info()}")
        error_message = error.read().decode("utf8", 'ignore')
        logger.error(f"Error message: {error_message}")
        return {"error": error_message}
def load_audio_from_url(url):
    """Load audio from a URL using soundfile.

    Args:
        url (str): URL of the audio file

    Returns:
        tuple: (sample_rate, audio_data) if successful, None otherwise
        str: file path to the temporary saved audio file
    """
    try:
        # Get the audio file from the URL
        response = requests.get(url)
        response.raise_for_status()  # Raise exception for bad status codes

        # Decode formats that soundfile supports directly (WAV, FLAC, etc.)
        audio_data, sample_rate = sf.read(BytesIO(response.content))

        # Save to a temporary file to be used by the chatbot
        file_extension = os.path.splitext(url)[1].lower()
        if not file_extension:
            file_extension = '.wav'  # Default to .wav if no extension
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
        sf.write(temp_file.name, audio_data, sample_rate)

        return (sample_rate, audio_data), temp_file.name
    except Exception as e:
        logger.error(f"Error loading audio from URL: {e}")
        return None, None
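# Usage sketch: this helper is not wired into the UI below, but it can prefetch
# clips such as the first audio-example URL used further down in this file:
#
#   (rate, data), tmp_path = load_audio_from_url(
#       "https://diamondfan.github.io/audio_files/english.weekend.plan.wav")
#   # tmp_path is None if the download or decode failed (see the log output).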
def encode_base64_from_file(file_path):
    """Encode file content to base64 string and determine MIME type."""
    file_extension = os.path.splitext(file_path)[1].lower()

    # Map file extensions to MIME types
    if file_extension in ['.jpg', '.jpeg']:
        mime_type = "image/jpeg"
    elif file_extension == '.png':
        mime_type = "image/png"
    elif file_extension == '.gif':
        mime_type = "image/gif"
    elif file_extension in ['.bmp', '.tiff', '.webp']:
        mime_type = f"image/{file_extension[1:]}"
    elif file_extension == '.flac':
        mime_type = "audio/flac"
    elif file_extension == '.wav':
        mime_type = "audio/wav"
    elif file_extension == '.mp3':
        mime_type = "audio/mpeg"
    elif file_extension in ['.m4a', '.aac']:
        mime_type = "audio/aac"
    elif file_extension == '.ogg':
        mime_type = "audio/ogg"
    else:
        mime_type = "application/octet-stream"

    # Read and encode file content
    with open(file_path, "rb") as file:
        encoded_string = base64.b64encode(file.read()).decode('utf-8')

    return encoded_string, mime_type
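# Example (hypothetical file name): the returned pair is what process_message()
# below combines into a data URL for the endpoint payload:
#
#   b64, mime = encode_base64_from_file("dog.jpg")   # -> ("<base64...>", "image/jpeg")
#   data_url = f"data:{mime};base64,{b64}"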
def process_message(history, message, conversation_state):
    """Process user message and update both history and internal state."""
    # Extract text and files
    text_content = message["text"] if message["text"] else ""
    image_files = []
    audio_files = []

    # Create content array for internal state
    content_items = []

    # Add text if available
    if text_content:
        content_items.append({"type": "text", "text": text_content})

    # Process and immediately convert files to base64
    if message["files"] and len(message["files"]) > 0:
        for file_path in message["files"]:
            file_extension = os.path.splitext(file_path)[1].lower()
            file_name = os.path.basename(file_path)

            # Convert the file to base64 immediately
            base64_content, mime_type = encode_base64_from_file(file_path)

            # Add to content items for the API
            if mime_type.startswith("image/"):
                content_items.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{mime_type};base64,{base64_content}"
                    }
                })
                image_files.append(file_path)
            elif mime_type.startswith("audio/"):
                content_items.append({
                    "type": "audio_url",
                    "audio_url": {
                        "url": f"data:{mime_type};base64,{base64_content}"
                    }
                })
                audio_files.append(file_path)

    # Only proceed if we have content
    if content_items:
        # Add to Gradio chatbot history (for display)
        history.append({"role": "user", "content": text_content})

        # Add file messages if present
        for file_path in image_files + audio_files:
            history.append({"role": "user", "content": {"path": file_path}})
        print(f"DEBUG: history = {history}")

        # Add to internal conversation state (with base64 data)
        conversation_state.append({
            "role": "user",
            "content": content_items
        })

    return history, gr.MultimodalTextbox(value=None, interactive=False), conversation_state
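# Shape sketch: gr.MultimodalTextbox submits a dict like
#   {"text": "What's in this image?", "files": ["/tmp/dog.jpg"]}   # path is hypothetical
# and for such a message process_message() appends one conversation_state entry
# roughly like:
#   {"role": "user",
#    "content": [
#        {"type": "text", "text": "What's in this image?"},
#        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]}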
def bot_response(history, conversation_state):
    """Generate bot response based on conversation state."""
    if not conversation_state:
        return history, conversation_state

    # Create the payload
    payload = {
        "input_data": {
            "input_string": conversation_state
        }
    }

    # Log the payload for debugging (without base64 data)
    debug_payload = json.loads(json.dumps(payload))
    for item in debug_payload["input_data"]["input_string"]:
        if "content" in item and isinstance(item["content"], list):
            for content_item in item["content"]:
                if "image_url" in content_item:
                    parts = content_item["image_url"]["url"].split(",")
                    if len(parts) > 1:
                        content_item["image_url"]["url"] = parts[0] + ",[BASE64_DATA_REMOVED]"
                if "audio_url" in content_item:
                    parts = content_item["audio_url"]["url"].split(",")
                    if len(parts) > 1:
                        content_item["audio_url"]["url"] = parts[0] + ",[BASE64_DATA_REMOVED]"
    logger.info(f"Sending payload: {json.dumps(debug_payload, indent=2)}")

    # Call Azure ML endpoint
    response = call_aml_endpoint(payload, url, api_key)

    # Extract text response from the Azure ML endpoint response
    try:
        if isinstance(response, dict):
            if "result" in response:
                result = response["result"]
            elif "output" in response:
                # Depending on your API's response format
                if isinstance(response["output"], list) and len(response["output"]) > 0:
                    result = response["output"][0]
                else:
                    result = str(response["output"])
            elif "error" in response:
                result = f"Error: {response['error']}"
            else:
                # Just return the whole response as a string if we can't parse it
                result = f"Received response: {json.dumps(response)}"
        else:
            result = str(response)
    except Exception as e:
        result = f"Error processing response: {str(e)}"

    # Add bot response to history
    if result == "None":
        result = "The current implementation does not support text + audio + image inputs in the same conversation. Please hit the Clear conversation button."
    history.append({"role": "assistant", "content": result})

    # Add to conversation state
    conversation_state.append({
        "role": "assistant",
        "content": [{"type": "text", "text": result}]
    })
    print(f"DEBUG: history after response: {history}")

    return history, conversation_state
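# Response-shape note: bot_response() accepts any of the dict shapes handled
# above, for example
#   {"result": "Hello!"}    -> "Hello!"
#   {"output": ["Hello!"]}  -> "Hello!"
#   {"error": "<message>"}  -> "Error: <message>"
# Anything else is stringified; what actually comes back depends on the scoring
# script of the Azure ML deployment.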
# Create Gradio demo
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    title = gr.Markdown("# Azure ML Multimodal Chatbot Demo")
    description = gr.Markdown("""
    This demo allows you to interact with a multimodal AI model through Azure ML.
    You can type messages, upload images, or record audio to communicate with the AI.
    """)

    # Store the conversation state with base64 data
    conversation_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                type="messages",
                avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/d/d3/Phi-integrated-information-symbol.png"),
                height=600
            )

            with gr.Row():
                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_count="multiple",
                    placeholder="Enter a message or upload files (images, audio)...",
                    show_label=False,
                    sources=["microphone", "upload"],
                )

            with gr.Row():
                clear_btn = gr.ClearButton([chatbot, chat_input], value="Clear conversation")
                clear_btn.click(lambda: [], None, conversation_state)  # Also clear the conversation state

            gr.HTML("<div style='text-align: right; margin-top: 5px;'><small>Powered by Azure ML</small></div>")
            # Define function to handle example submission directly
            def handle_example_submission(text, files, history, conv_state):
                """
                Process an example submission directly, including the bot response.
                This bypasses the regular chat_input.submit flow.
                """
                # Create a message object similar to what would be submitted by the user
                message = {"text": text, "files": files if files else []}

                # Use the same processing function as normal submissions
                new_history, _, new_conv_state = process_message(history, message, conv_state)

                # Then immediately trigger the bot response
                final_history, final_conv_state = bot_response(new_history, new_conv_state)

                # The input box is never disabled in this path, so there is nothing to
                # re-enable here (the normal submit flow re-enables it via enable_input).

                # Return everything needed
                return final_history, final_conv_state
        with gr.Column(scale=1):
            gr.Markdown("### Examples")

            with gr.Tab("Text Only"):
                # For text examples, just submit them directly
                def run_text_example(example_text, history, conv_state):
                    # Process the example directly
                    return handle_example_submission(example_text, [], history, conv_state)

                text_examples = gr.Examples(
                    examples=[
                        ["Tell me about Microsoft Azure cloud services."],
                        ["What can you help me with today?"],
                        ["Explain the difference between AI and machine learning."],
                    ],
                    inputs=[gr.Textbox(visible=False)],
                    outputs=[chatbot, conversation_state],
                    fn=lambda text, h=chatbot, c=conversation_state: run_text_example(text, h, c),
                    label="Text Examples (Click to run the example)"
                )
with gr.Tab("Text & Audio"): | |
# Function to handle loading both text and audio from URL and sending directly | |
def run_audio_example(example_text, example_audio_url, history, conv_state): | |
try: | |
# Download and process the audio from URL | |
print(f"Downloading audio from: {example_audio_url}") | |
response = requests.get(example_audio_url) | |
response.raise_for_status() | |
# Save to a temporary file | |
file_extension = os.path.splitext(example_audio_url)[1].lower() | |
if not file_extension: | |
file_extension = '.wav' # Default to .wav if no extension | |
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) | |
temp_file.write(response.content) | |
temp_file.close() | |
print(f"Saved audio to temporary file: {temp_file.name}") | |
# Process the example directly | |
return handle_example_submission(example_text, [temp_file.name], history, conv_state) | |
except Exception as e: | |
print(f"Error processing audio example: {e}") | |
# If an error occurs, just add the text to history | |
history.append({"role": "user", "content": f"{example_text} (Error loading audio: {e})"}) | |
return history, conv_state | |
audio_examples = gr.Examples( | |
examples=[ | |
["Transcribe this audio clip", "https://diamondfan.github.io/audio_files/english.weekend.plan.wav"], | |
["What language is being spoken in this recording?", "https://www2.cs.uic.edu/~i101/SoundFiles/BabyElephantWalk60.wav"], | |
], | |
inputs=[ | |
gr.Textbox(visible=False), | |
gr.Textbox(visible=False) | |
], | |
outputs=[chatbot, conversation_state], | |
fn=lambda text, url, h=chatbot, c=conversation_state: run_audio_example(text, url, h, c), | |
label="Audio Examples (Click to run the example)" | |
) | |
with gr.Tab("Text & Image"): | |
# Function to handle loading both text and image from URL and sending directly | |
def run_image_example(example_text, example_image_url, history, conv_state): | |
try: | |
# Download the image from URL | |
print(f"Downloading image from: {example_image_url}") | |
response = requests.get(example_image_url) | |
response.raise_for_status() | |
# Save to a temporary file | |
file_extension = os.path.splitext(example_image_url)[1].lower() | |
if not file_extension: | |
file_extension = '.jpg' # Default to .jpg if no extension | |
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) | |
temp_file.write(response.content) | |
temp_file.close() | |
print(f"Saved image to temporary file: {temp_file.name}") | |
# Process the example directly | |
return handle_example_submission(example_text, [temp_file.name], history, conv_state) | |
except Exception as e: | |
print(f"Error processing image example: {e}") | |
# If an error occurs, just add the text to history | |
history.append({"role": "user", "content": f"{example_text} (Error loading image: {e})"}) | |
return history, conv_state | |
image_examples = gr.Examples( | |
examples=[ | |
["What's in this image?", "https://storage.googleapis.com/demo-image/dog.jpg"], | |
["Describe this chart", "https://matplotlib.org/stable/_images/sphx_glr_bar_stacked_001.png"], | |
], | |
inputs=[ | |
gr.Textbox(visible=False), | |
gr.Textbox(visible=False) | |
], | |
outputs=[chatbot, conversation_state], | |
fn=lambda text, url, h=chatbot, c=conversation_state: run_image_example(text, url, h, c), | |
label="Image Examples (Click to run the example)" | |
) | |
gr.Markdown("### Instructions") | |
gr.Markdown(""" | |
- Type a question or statement | |
- Upload images or audio files | |
- You can combine text with media files | |
- The model can analyze images and transcribe audio | |
- For best results with images, use JPG or PNG files | |
- For audio, use WAV, MP3, or FLAC files | |
""") | |
gr.Markdown("### Capabilities") | |
gr.Markdown(""" | |
This chatbot can: | |
- Answer questions and provide explanations | |
- Describe and analyze images | |
- Transcribe and analyze audio content | |
- Process multiple inputs in the same message | |
- Maintain context throughout the conversation | |
""") | |
with gr.Accordion("Debug Info", open=False): | |
debug_output = gr.JSON( | |
label="Last API Request", | |
value={} | |
) | |
def update_debug(conversation_state): | |
"""Update debug output with the last payload that would be sent.""" | |
if not conversation_state: | |
return {} | |
# Create a payload from the conversation | |
payload = { | |
"input_data": { | |
"input_string": conversation_state | |
} | |
} | |
# Remove base64 data to avoid cluttering the UI | |
sanitized_payload = json.loads(json.dumps(payload)) | |
for item in sanitized_payload["input_data"]["input_string"]: | |
if "content" in item and isinstance(item["content"], list): | |
for content_item in item["content"]: | |
if "image_url" in content_item: | |
parts = content_item["image_url"]["url"].split(",") | |
if len(parts) > 1: | |
content_item["image_url"]["url"] = parts[0] + ",[BASE64_DATA_REMOVED]" | |
if "audio_url" in content_item: | |
parts = content_item["audio_url"]["url"].split(",") | |
if len(parts) > 1: | |
content_item["audio_url"]["url"] = parts[0] + ",[BASE64_DATA_REMOVED]" | |
return sanitized_payload | |
    def enable_input():
        """Re-enable the input box after the bot responds."""
        return gr.MultimodalTextbox(interactive=True)

    # Set up event handlers
    msg_submit = chat_input.submit(
        process_message, [chatbot, chat_input, conversation_state], [chatbot, chat_input, conversation_state], queue=False
    )
    msg_response = msg_submit.then(
        bot_response, [chatbot, conversation_state], [chatbot, conversation_state], api_name="bot_response"
    )
    msg_response.then(enable_input, None, chat_input)
    # btn_response.then(enable_input, None, chat_input)

    # Update debug info
    # msg_response.then(update_debug, conversation_state, debug_output)
    # btn_response.then(update_debug, conversation_state, debug_output)

demo.launch(share=True, debug=True)