|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
|
|
|
|
# Hugging Face model id for the 0.5B-parameter Qwen2.5 base checkpoint.
model_name = "Qwen/Qwen2.5-0.5B"

try:
    # Download (or load from local cache) the tokenizer and model weights.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",           # use the dtype stored in the checkpoint
        device_map="auto",            # let accelerate place weights on GPU/CPU
        attn_implementation="eager"   # explicit eager attention implementation
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    # Log the failure so it is visible in the Space logs, then re-raise to
    # abort startup — the app is useless without the model.
    print(f"Error loading model: {e}")
    raise
|
|
|
|
|
def generate_text(prompt, max_length=100):
    """Generate a continuation of ``prompt`` with the loaded Qwen model.

    Parameters
    ----------
    prompt : str
        Text to continue.
    max_length : int | float, default 100
        Total token budget (prompt + generated tokens). Gradio sliders
        deliver floats, so the value is coerced to ``int`` before use.

    Returns
    -------
    str
        The decoded generation (prompt included, special tokens stripped),
        or an ``"Error: ..."`` string if tokenization/generation fails —
        returning text keeps the UI/API responsive instead of crashing.
    """
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            # Pass the full tokenizer output so attention_mask travels with
            # input_ids — the original passed input_ids only, which triggers
            # a transformers warning and can alter generation results.
            **inputs,
            max_length=int(max_length),  # slider values arrive as floats
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Boundary handler: surface the failure as text in the UI/API.
        return f"Error: {str(e)}"
|
|
|
|
|
def get_api_info():
    """Return the human-readable API usage text shown in the UI textbox."""
    base_url = "https://<your-space-name>.hf.space"
    info_lines = [
        "Welcome to Qwen2.5-0.5B API!",
        f"API Base URL: {base_url} (Replace '<your-space-name>' with your actual Space name)",
        "Endpoints:",
        f"- GET {base_url}/api/health_check (Check API status)",
        f"- POST {base_url}/api/generate (Generate text)",
        "To use the generate API, send a POST request with JSON:",
        '{"0": "your prompt", "1": 150}',
    ]
    return "\n".join(info_lines)
|
|
|
|
|
def health_check():
    """Liveness probe for the /health_check endpoint."""
    status_message = "Qwen2.5-0.5B API is running!"
    return status_message
|
|
|
|
|
# Rich Blocks UI: header, API usage info, and an interactive generate form.
with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
    gr.Markdown("# Qwen2.5-0.5B Text Generator")
    gr.Markdown("Enter a prompt below or use the API!")

    # Static, read-only description of the HTTP API.
    gr.Markdown("### API Information")
    api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)

    # Prompt + length controls laid out side by side.
    gr.Markdown("### Generate Text")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")

    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text", interactive=False)

    # Wire the button to the generator; the slider supplies max_length.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_input],
        outputs=output_text
    )
|
|
|
|
|
# API-oriented wrapper exposing generate_text as the /generate endpoint.
interface = gr.Interface(
    fn=generate_text,
    inputs=["text", "number"],
    outputs="text",
    title="Qwen2.5-0.5B API",
    api_name="/generate"
).queue()

# Minimal liveness endpoint (/health_check), no inputs.
health_interface = gr.Interface(
    fn=health_check,
    inputs=None,
    outputs="text",
    api_name="/health_check"
)

# BUG FIX: the original did `demo = gr.TabbedInterface(...)`, rebinding `demo`
# and silently discarding the Blocks UI built above — that UI was constructed
# but never served. Bind the tabbed app to a distinct name and mount the
# Blocks UI as its first tab so everything built is actually reachable.
app = gr.TabbedInterface(
    [demo, interface, health_interface],
    ["Demo", "Generate Text", "Health Check"],
)

# Bind on all interfaces / the standard HF Spaces port.
app.launch(server_name="0.0.0.0", server_port=7860)