|
import gradio as gr |
|
from fastapi import FastAPI, HTTPException, Request |
|
from pydantic import BaseModel |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
import uvicorn |
|
from fastapi.responses import HTMLResponse |
|
|
|
|
|
app = FastAPI()




# Eagerly load the model and tokenizer once at import time so every request
# reuses the same in-memory weights instead of reloading per call.
model_name = "Qwen/Qwen2.5-0.5B"

try:

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = AutoModelForCausalLM.from_pretrained(

        model_name,

        torch_dtype="auto",  # use the checkpoint's native dtype

        device_map="auto",  # place weights on GPU when available, else CPU

        attn_implementation="eager"  # avoid requiring flash-attn / SDPA kernels

    )

    print("Model and tokenizer loaded successfully!")

except Exception as e:

    print(f"Error loading model: {e}")

    raise  # re-raise so the process fails fast instead of serving a broken model
|
|
|
|
|
class TextInput(BaseModel):
    """Request body schema for POST /generate."""

    prompt: str

    # Forwarded to generate_text() and ultimately model.generate(max_length=...),
    # i.e. a total token budget (prompt + continuation), not new tokens only.
    max_length: int = 100
|
|
|
|
|
def generate_text(prompt, max_length=100):
    """Generate a text continuation for *prompt* with the loaded Qwen model.

    Args:
        prompt: Input text to continue.
        max_length: Total token budget (prompt + generated tokens) passed
            through to ``model.generate`` — not the number of new tokens.

    Returns:
        The decoded sequence (prompt plus continuation) as a string.

    Raises:
        RuntimeError: If tokenization or generation fails (subclass of
            Exception, so existing callers catching Exception still work).
    """
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # inference_mode: no autograd graph is built — less memory, faster.
        with torch.inference_mode():
            outputs = model.generate(
                inputs["input_ids"],
                # Pass the attention mask explicitly: transformers warns when it
                # is omitted and results can be wrong if pad == eos token.
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                # Silence the "Setting pad_token_id to eos_token_id" warning.
                pad_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Chain the original exception so the real traceback is preserved.
        raise RuntimeError(f"Error: {e}") from e
|
|
|
|
|
@app.post("/generate")
async def generate_text_api(input: TextInput):
    """POST /generate: run text generation and return it as JSON.

    Any failure inside generation surfaces as an HTTP 500 whose detail is
    the underlying error message.
    """
    try:
        generated = generate_text(input.prompt, input.max_length)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return {"generated_text": generated}
|
|
|
|
|
@app.get("/")
async def root():
    """Health-check endpoint confirming the service is up."""
    status_message = "Qwen2.5-0.5B API is running!"
    return {"message": status_message}
|
|
|
|
|
@app.get("/api_link")
async def get_api_link(request: Request):
    """Return the service's base URL plus a map of every exposed endpoint.

    The base URL is rebuilt from the incoming request so the links are
    correct regardless of host, scheme, or port.
    """
    url = request.url
    base_url = f"{url.scheme}://{url.hostname}"
    if url.port:
        base_url = f"{base_url}:{url.port}"

    # Endpoint name -> path; joined with the base URL below.
    paths = {
        "health_check": "/",
        "generate_text": "/generate",
        "api_link": "/api_link",
        "interface": "/interface",
    }
    return {
        "api_url": base_url,
        "endpoints": {name: f"{base_url}{path}" for name, path in paths.items()},
    }
|
|
|
|
|
def create_gradio_interface():
    """Build and return the Gradio Blocks UI wired to generate_text."""
    with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
        gr.Markdown("# Qwen2.5-0.5B Text Generator")
        gr.Markdown("Enter a prompt and get generated text!")

        # Prompt and length control side by side.
        with gr.Row():
            prompt_box = gr.Textbox(label="Prompt", placeholder="Type something...")
            length_slider = gr.Slider(50, 500, value=100, step=10, label="Max Length")

        run_button = gr.Button("Generate")
        result_box = gr.Textbox(label="Generated Text", interactive=False)

        # Clicking the button feeds (prompt, max_length) straight into generate_text.
        run_button.click(
            fn=generate_text,
            inputs=[prompt_box, length_slider],
            outputs=result_box,
        )
    return demo
|
|
|
|
|
@app.get("/interface", response_class=HTMLResponse)
async def gradio_interface(request: Request):
    """Serve the Gradio UI as an HTML page.

    NOTE(review): ``Blocks.render()`` is intended for composing Blocks inside
    another Blocks context and does not return a standalone HTML document in
    current Gradio releases, so this endpoint likely does not render a working
    UI — ``gr.mount_gradio_app(app, demo, path="/interface")`` at startup is
    the supported FastAPI integration. Also, a fresh Blocks is constructed on
    every request. Confirm against the installed Gradio version.
    """
    gradio_app = create_gradio_interface()
    return HTMLResponse(content=gradio_app.render())
|
|
|
|
|
if __name__ == "__main__":
    # Launch the ASGI server on all interfaces; 7860 is the Hugging Face
    # Spaces convention.
    serve_host, serve_port = "0.0.0.0", 7860
    uvicorn.run(app, host=serve_host, port=serve_port)