# Captured from a Hugging Face Space page (status banner at capture time: "Sleeping").
import streamlit as st | |
import json | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
from io import StringIO | |
import time | |
# Pipeline tags recognised when the task has to be inferred from a model's tags.
_KNOWN_PIPELINE_TAGS = (
    "text-classification", "token-classification", "question-answering",
    "summarization", "translation", "text-generation", "fill-mask",
    "sentence-similarity", "image-classification", "object-detection",
    "image-segmentation", "text-to-image", "image-to-text",
)

# Pipelines that take one free-text input (some with extra generation sliders).
_TEXT_PIPELINES = (
    "text-classification", "token-classification", "fill-mask", "text-generation", "summarization",
)
# Pipelines that take a single image, uploaded or by URL.
_IMAGE_PIPELINES = ("image-classification", "object-detection", "image-segmentation", "image-to-text")
# Pipelines that take a single audio clip, uploaded or by URL.
_AUDIO_PIPELINES = ("audio-classification", "automatic-speech-recognition")

# Canned outputs of the simulated translation pipeline, keyed by target language.
_MOCK_TRANSLATIONS = {
    "French": "Bonjour, comment allez-vous?",
    "German": "Hallo, wie geht es dir?",
    "Spanish": "Hola, ¿cómo estás?",
    "Chinese": "你好,你好吗?",
}

_DEFAULT_IMAGE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/distilbert-base-uncased-finetuned-sst-2-english-architecture.png"


def model_inference_dashboard(model_info):
    """Create a dashboard for testing model inference directly in the app.

    Picks an input form based on the model's pipeline type, runs a *simulated*
    inference (no model or API is called; a canned response is generated),
    renders a pipeline-specific visualisation of that response, and shows
    Python / JavaScript snippets for calling the hosted Inference API.

    Args:
        model_info: Object describing the model. The ``pipeline_tag``, ``tags``
            and ``modelId`` / ``id`` attributes are read when present. A falsy
            value shows an error and nothing else is rendered.

    Returns:
        None. All output is written to the Streamlit page.
    """
    if not model_info:
        st.error("Model information not found")
        return

    st.subheader("🧠 Model Inference Dashboard")
    pipeline_tag = _detect_pipeline_tag(model_info)
    st.info(f"This dashboard allows you to test your model's inference capabilities. Model pipeline: **{pipeline_tag}**")

    input_data = _collect_input(pipeline_tag)

    if st.button("Run Inference", use_container_width=True):
        if _has_input(input_data):
            with st.spinner("Running inference..."):
                # In a real implementation, this would call the HF Inference API.
                # For demo purposes, simulate a response.
                time.sleep(2)
                result = _simulate_inference(pipeline_tag, input_data)
            _render_results(pipeline_tag, input_data, result)
        else:
            st.warning("Please provide input data for inference")

    _render_api_examples(model_info)


def _escape(text):
    """Return *text* HTML-escaped so it can be embedded in ``unsafe_allow_html`` markup."""
    from html import escape

    return escape(str(text))


def _detect_pipeline_tag(model_info):
    """Return the model's pipeline tag, inferred from its tags, defaulting to text-classification."""
    pipeline_tag = getattr(model_info, "pipeline_tag", None)
    if pipeline_tag:
        return pipeline_tag
    # ``tags`` may be absent or None; treat both as "no tags". First recognised tag wins.
    for tag in getattr(model_info, "tags", None) or []:
        if tag in _KNOWN_PIPELINE_TAGS:
            return tag
    return "text-classification"  # Default fallback


def _collect_input(pipeline_tag):
    """Render the input widgets for *pipeline_tag* and return the collected input dict.

    Returns None when the pipeline needs an uploaded file or URL that the user
    has not supplied yet.
    """
    if pipeline_tag in _TEXT_PIPELINES:
        st.markdown("### Text Input")
        input_text = st.text_area("Enter text for inference", value="This model is amazing!", height=150)
        if pipeline_tag == "text-generation":
            col1, col2 = st.columns(2)
            with col1:
                max_length = st.slider("Max Length", min_value=10, max_value=500, value=100)
            with col2:
                temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
            return {"text": input_text, "max_length": max_length, "temperature": temperature}
        if pipeline_tag == "summarization":
            max_length = st.slider("Max Summary Length", min_value=10, max_value=200, value=50)
            return {"text": input_text, "max_length": max_length}
        return {"text": input_text}

    if pipeline_tag == "question-answering":
        st.markdown("### Question & Context")
        question = st.text_input("Question", value="What is this model about?")
        context = st.text_area(
            "Context",
            value="This model is a transformer-based language model designed for natural language understanding tasks.",
            height=150,
        )
        return {"question": question, "context": context}

    if pipeline_tag == "translation":
        st.markdown("### Translation")
        source_lang = st.selectbox("Source Language", ["English", "French", "German", "Spanish", "Chinese"])
        target_lang = st.selectbox("Target Language", ["French", "English", "German", "Spanish", "Chinese"])
        translation_text = st.text_area("Text to translate", value="Hello, how are you?", height=150)
        return {"text": translation_text, "source_language": source_lang, "target_language": target_lang}

    if pipeline_tag in _IMAGE_PIPELINES:
        st.markdown("### Image Input")
        upload_method = st.radio("Select input method", ["Upload Image", "Image URL"])
        if upload_method == "Upload Image":
            uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
            if uploaded_file is None:
                return None
            st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
            return {"image": uploaded_file}
        image_url = st.text_input("Image URL", value=_DEFAULT_IMAGE_URL)
        if not image_url:
            return None
        st.image(image_url, caption="Image from URL", use_column_width=True)
        return {"image_url": image_url}

    if pipeline_tag in _AUDIO_PIPELINES:
        st.markdown("### Audio Input")
        upload_method = st.radio("Select input method", ["Upload Audio", "Audio URL"])
        if upload_method == "Upload Audio":
            uploaded_file = st.file_uploader("Upload an audio file", type=["mp3", "wav", "ogg"])
            if uploaded_file is None:
                return None
            st.audio(uploaded_file)
            return {"audio": uploaded_file}
        audio_url = st.text_input("Audio URL")
        if not audio_url:
            return None
        st.audio(audio_url)
        return {"audio_url": audio_url}

    # Pipelines without a dedicated form (e.g. sentence-similarity, text-to-image)
    # still get a plain text box, so "Run Inference" is usable for them.
    st.markdown("### Text Input")
    input_text = st.text_area("Enter text for inference", value="This model is amazing!", height=150)
    return {"text": input_text}


def _has_input(input_data):
    """Return True when *input_data* exists and none of its text fields are blank."""
    if not input_data:
        return False
    # A dict such as {"text": ""} is truthy, so blank strings are checked explicitly.
    return all(value.strip() for value in input_data.values() if isinstance(value, str))


def _simulate_inference(pipeline_tag, input_data):
    """Return a canned response shaped like the HF Inference API output for *pipeline_tag*."""
    if pipeline_tag == "text-classification":
        return [
            {"label": "POSITIVE", "score": 0.9231},
            {"label": "NEGATIVE", "score": 0.0769},
        ]
    if pipeline_tag == "token-classification":
        return [
            {"entity": "B-PER", "word": "This", "score": 0.2, "index": 0, "start": 0, "end": 4},
            {"entity": "O", "word": "model", "score": 0.95, "index": 1, "start": 5, "end": 10},
            {"entity": "O", "word": "is", "score": 0.99, "index": 2, "start": 11, "end": 13},
            {"entity": "B-MISC", "word": "amazing", "score": 0.85, "index": 3, "start": 14, "end": 21},
        ]
    if pipeline_tag == "text-generation":
        return {
            "generated_text": input_data["text"] + " It provides state-of-the-art performance on a wide range of natural language processing tasks, including sentiment analysis, named entity recognition, and question answering. The model was trained on a diverse corpus of text data, allowing it to generate coherent and contextually relevant responses."
        }
    if pipeline_tag == "summarization":
        return {"summary_text": "This model provides excellent performance."}
    if pipeline_tag == "question-answering":
        context = input_data["context"]
        answer = "a transformer-based language model"
        start = context.find(answer)
        if start == -1:
            # The canned answer is not in the user's context: answer with the first
            # non-empty sentence so the reported offsets still point at real text.
            answer = next((part.strip() for part in context.split(".") if part.strip()), "")
            start = context.find(answer)
        return {"answer": answer, "start": start, "end": start + len(answer), "score": 0.953}
    if pipeline_tag == "translation":
        translation = _MOCK_TRANSLATIONS.get(input_data["target_language"], "Hello, how are you?")
        return {"translation_text": translation}
    if pipeline_tag == "image-classification":
        return [
            {"label": "diagram", "score": 0.9712},
            {"label": "architecture", "score": 0.0231},
            {"label": "document", "score": 0.0057},
        ]
    if pipeline_tag == "object-detection":
        return [
            {"label": "box", "score": 0.9712, "box": {"xmin": 10, "ymin": 20, "xmax": 100, "ymax": 80}},
            {"label": "text", "score": 0.8923, "box": {"xmin": 120, "ymin": 30, "xmax": 250, "ymax": 60}},
        ]
    return {"result": "Sample response for " + pipeline_tag}


def _render_results(pipeline_tag, input_data, result):
    """Display *result* with a visualisation suited to *pipeline_tag*, plus a JSON download."""
    st.markdown("### Inference Results")

    if pipeline_tag == "text-classification":
        fig = px.bar(
            pd.DataFrame(result),
            x="label",
            y="score",
            color="score",
            color_continuous_scale=px.colors.sequential.Viridis,
            title="Classification Results",
        )
        st.plotly_chart(fig, use_container_width=True)
        st.json(result)
    elif pipeline_tag == "token-classification":
        _render_token_classification(result)
        st.json(result)
    elif pipeline_tag in ("text-generation", "summarization", "translation"):
        _render_text_output(input_data, result)
    elif pipeline_tag == "question-answering":
        _render_question_answering(input_data, result)
    elif pipeline_tag == "image-classification":
        fig = px.bar(
            pd.DataFrame(result),
            x="score",
            y="label",
            orientation="h",
            color="score",
            color_continuous_scale=px.colors.sequential.Viridis,
            title="Image Classification Results",
        )
        fig.update_layout(yaxis={"categoryorder": "total ascending"})
        st.plotly_chart(fig, use_container_width=True)
        st.json(result)
    else:
        # Generic display for other types
        st.json(result)

    st.download_button(
        label="Download Results",
        # ensure_ascii=False keeps non-Latin output (e.g. the Chinese translation) readable.
        data=json.dumps(result, indent=2, ensure_ascii=False),
        file_name="inference_results.json",
        mime="application/json",
    )


def _render_token_classification(result):
    """Highlight each token by entity type and show a colour legend."""
    st.markdown("#### Named Entities")

    # Entity types with the B-/I- prefix stripped, in first-seen order.
    entity_types = []
    for item in result:
        entity = item["entity"]
        if entity.startswith(("B-", "I-")) and entity[2:] not in entity_types:
            entity_types.append(entity[2:])

    # Cycle the palette so more entity types than colours still each get a colour.
    palette = px.colors.qualitative.Plotly
    entity_colors = {etype: palette[i % len(palette)] for i, etype in enumerate(entity_types)}

    pieces = []
    for item in result:
        word = _escape(item["word"])
        entity = item["entity"]
        if entity == "O":
            pieces.append(f"{word} ")
            continue
        entity_type = entity[2:] if entity.startswith(("B-", "I-")) else entity
        color = entity_colors.get(entity_type, "#CCCCCC")
        title = _escape(f"{entity} ({item['score']:.2f})")
        pieces.append(
            f'<span style="background-color: {color}; padding: 2px; border-radius: 3px;" title="{title}">{word}</span> '
        )
    st.markdown(f'<div style="line-height: 2.5;">{"".join(pieces)}</div>', unsafe_allow_html=True)

    st.markdown("#### Entity Legend")
    legend_html = "".join(
        f'<span style="background-color: {color}; padding: 2px 8px; margin-right: 10px; border-radius: 3px;">{_escape(etype)}</span>'
        for etype, color in entity_colors.items()
    )
    st.markdown(f"<div>{legend_html}</div>", unsafe_allow_html=True)


def _render_text_output(input_data, result):
    """Show generated, summarised or translated text with simple length statistics."""
    response_key = next(
        (key for key in ("generated_text", "summary_text", "translation_text") if key in result),
        "translation_text",
    )
    output_text = result[response_key]

    st.markdown("#### Output Text")
    st.markdown(
        f'<div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px;">{_escape(output_text)}</div>',
        unsafe_allow_html=True,
    )

    st.markdown("#### Text Statistics")
    input_length = len(input_data.get("text", ""))
    output_length = len(output_text)
    col1, col2, col3 = st.columns(3)
    with col1:
        # The unit belongs in the value: st.metric's third argument is a *delta*
        # and would render "characters" as a change indicator.
        st.metric("Input Length", f"{input_length} characters")
    with col2:
        st.metric("Output Length", f"{output_length} characters")
    with col3:
        change_pct = (output_length - input_length) / input_length * 100 if input_length > 0 else 0
        st.metric("Length Change", f"{change_pct:.1f}%", f"{output_length - input_length} chars")


def _render_question_answering(input_data, result):
    """Show the answer, highlight it inside the context, and display the confidence."""
    st.markdown("#### Answer")
    st.markdown(
        f'<div style="background-color: #e6f3ff; padding: 10px; border-radius: 5px; font-weight: bold;">{_escape(result["answer"])}</div>',
        unsafe_allow_html=True,
    )

    if "context" in input_data:
        st.markdown("#### Answer in Context")
        context = input_data["context"]
        start, end = result["start"], result["end"]
        highlighted_context = (
            _escape(context[:start])
            + f'<span style="background-color: #ffeb3b; font-weight: bold;">{_escape(context[start:end])}</span>'
            + _escape(context[end:])
        )
        st.markdown(
            f'<div style="background-color: #f0f2f6; padding: 15px; border-radius: 10px; line-height: 1.5;">{highlighted_context}</div>',
            unsafe_allow_html=True,
        )

    st.markdown("#### Confidence")
    st.progress(result["score"])
    st.text(f"Confidence Score: {result['score']:.4f}")


def _render_api_examples(model_info):
    """Show Python and JavaScript snippets for calling the hosted Inference API for this model."""
    # Accept objects that expose the repository name as either ``modelId`` or ``id``.
    model_id = getattr(model_info, "modelId", None) or getattr(model_info, "id", None) or "YOUR_MODEL_ID"
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"

    with st.expander("API Integration"):
        st.markdown("### Use this model in your application")

        st.markdown("#### Python")
        python_code = "\n".join([
            "import requests",
            "",
            f'API_URL = "{api_url}"',
            'headers = {"Authorization": "Bearer YOUR_API_KEY"}',
            "",
            "def query(payload):",
            "    response = requests.post(API_URL, headers=headers, json=payload)",
            "    return response.json()",
            "",
            "# Example usage",
            "output = query({",
            '    "inputs": "This model is amazing!"',
            "})",
            "print(output)",
        ])
        # st.code renders the snippet verbatim (no Markdown parsing of the code).
        st.code(python_code, language="python")

        st.markdown("#### JavaScript")
        js_code = "\n".join([
            "async function query(data) {",
            "    const response = await fetch(",
            f'        "{api_url}",',
            "        {",
            '            headers: { Authorization: "Bearer YOUR_API_KEY" },',
            '            method: "POST",',
            "            body: JSON.stringify(data),",
            "        }",
            "    );",
            "    const result = await response.json();",
            "    return result;",
            "}",
            "",
            "// Example usage",
            'query({"inputs": "This model is amazing!"}).then((response) => {',
            "    console.log(JSON.stringify(response));",
            "});",
        ])
        st.code(js_code, language="javascript")