rapacious commited on
Commit
ac716aa
·
verified ·
1 Parent(s): 9a76509

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -71
app.py CHANGED
@@ -1,13 +1,6 @@
1
  import gradio as gr
2
- from fastapi import FastAPI, HTTPException, Request
3
- from pydantic import BaseModel
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
6
- import uvicorn
7
- from fastapi.responses import HTMLResponse
8
-
9
- # Khởi tạo FastAPI
10
- app = FastAPI()
11
 
12
  # Tải model và tokenizer khi ứng dụng khởi động
13
  model_name = "Qwen/Qwen2.5-0.5B"
@@ -24,12 +17,7 @@ except Exception as e:
24
  print(f"Error loading model: {e}")
25
  raise
26
 
27
- # Định nghĩa request body cho API
28
- class TextInput(BaseModel):
29
- prompt: str
30
- max_length: int = 100
31
-
32
- # Hàm sinh văn bản (dùng chung cho API và Gradio)
33
  def generate_text(prompt, max_length=100):
34
  try:
35
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -44,67 +32,71 @@ def generate_text(prompt, max_length=100):
44
  )
45
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
46
  except Exception as e:
47
- raise Exception(f"Error: {str(e)}")
48
-
49
- # API endpoint để sinh văn bản
50
- @app.post("/generate")
51
- async def generate_text_api(input: TextInput):
52
- try:
53
- result = generate_text(input.prompt, input.max_length)
54
- return {"generated_text": result}
55
- except Exception as e:
56
- raise HTTPException(status_code=500, detail=str(e))
57
 
58
- # API endpoint kiểm tra sức khỏe
59
- @app.get("/")
60
- async def root():
61
- return {"message": "Qwen2.5-0.5B API is running!"}
 
 
 
 
 
 
 
 
 
 
62
 
63
- # API endpoint hiển thị URL
64
- @app.get("/api_link")
65
- async def get_api_link(request: Request):
66
- scheme = request.url.scheme
67
- host = request.url.hostname
68
- if request.url.port:
69
- base_url = f"{scheme}://{host}:{request.url.port}"
70
- else:
71
- base_url = f"{scheme}://{host}"
72
- return {
73
- "api_url": base_url,
74
- "endpoints": {
75
- "health_check": f"{base_url}/",
76
- "generate_text": f"{base_url}/generate",
77
- "api_link": f"{base_url}/api_link",
78
- "interface": f"{base_url}/interface"
79
- }
80
- }
81
 
82
  # Tạo giao diện Gradio
83
- def create_gradio_interface():
84
- with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
85
- gr.Markdown("# Qwen2.5-0.5B Text Generator")
86
- gr.Markdown("Enter a prompt and get generated text!")
87
-
88
- with gr.Row():
89
- prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
90
- max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
91
-
92
- generate_button = gr.Button("Generate")
93
- output_text = gr.Textbox(label="Generated Text", interactive=False)
94
-
95
- generate_button.click(
96
- fn=generate_text,
97
- inputs=[prompt_input, max_length_input],
98
- outputs=output_text
99
- )
100
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- # Thêm endpoint để hiển thị giao diện Gradio
103
- @app.get("/interface", response_class=HTMLResponse)
104
- async def gradio_interface(request: Request):
105
- gradio_app = create_gradio_interface()
106
- return HTMLResponse(content=gradio_app.render())
107
 
108
- # Chạy ứng dụng nếu không trên Hugging Face Spaces
109
- if __name__ == "__main__":
110
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import gradio as gr
 
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
 
 
 
 
4
 
5
  # Tải model và tokenizer khi ứng dụng khởi động
6
  model_name = "Qwen/Qwen2.5-0.5B"
 
17
  print(f"Error loading model: {e}")
18
  raise
19
 
20
+ # Hàm sinh văn bản (dùng cho cả UI và API)
 
 
 
 
 
21
  def generate_text(prompt, max_length=100):
22
  try:
23
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
32
  )
33
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
34
  except Exception as e:
35
+ return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
36
 
37
+ # Hàm hiển thị thông tin API
38
+ def get_api_info():
39
+ # Trên Hugging Face Spaces, API URL sẽ dựa trên tên Space
40
+ # Khi chạy local, ta giả định port 7860
41
+ base_url = "http://localhost:7860" if gr.context.local else "https://<your-space-name>.hf.space"
42
+ return (
43
+ "Welcome to Qwen2.5-0.5B API!\n"
44
+ f"API Base URL: {base_url}\n"
45
+ "Endpoints:\n"
46
+ f"- GET {base_url}/api/health_check (Check API status)\n"
47
+ f"- POST {base_url}/api/generate (Generate text)\n"
48
+ "To use the generate API, send a POST request with JSON:\n"
49
+ '{"prompt": "your prompt", "max_length": 150}'
50
+ )
51
 
52
+ # Hàm kiểm tra sức khỏe (dành cho API)
53
+ def health_check():
54
+ return "Qwen2.5-0.5B API is running!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  # Tạo giao diện Gradio
57
+ with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
58
+ gr.Markdown("# Qwen2.5-0.5B Text Generator")
59
+ gr.Markdown("Enter a prompt below or use the API!")
60
+
61
+ # Hiển thị thông tin API
62
+ gr.Markdown("### API Information")
63
+ api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)
64
+
65
+ # Giao diện sinh văn bản
66
+ gr.Markdown("### Generate Text")
67
+ with gr.Row():
68
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
69
+ max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
70
+
71
+ generate_button = gr.Button("Generate")
72
+ output_text = gr.Textbox(label="Generated Text", interactive=False)
73
+
74
+ # Liên kết button với hàm generate_text
75
+ generate_button.click(
76
+ fn=generate_text,
77
+ inputs=[prompt_input, max_length_input],
78
+ outputs=output_text
79
+ )
80
+
81
+ # Định nghĩa API endpoints với Gradio
82
+ demo = gr.Interface(
83
+ fn=generate_text,
84
+ inputs=["text", "number"],
85
+ outputs="text",
86
+ title="Qwen2.5-0.5B API",
87
+ api_name="/generate" # API endpoint: /api/generate
88
+ ).queue()
89
+
90
+ # Thêm endpoint health check
91
+ health_interface = gr.Interface(
92
+ fn=health_check,
93
+ inputs=None,
94
+ outputs="text",
95
+ api_name="/health_check" # API endpoint: /api/health_check
96
+ )
97
 
98
+ # Kết hợp giao diện và API
99
+ app = gr.mount_gradio_app(demo, health_interface)
 
 
 
100
 
101
+ # Chạy ứng dụng
102
+ demo.launch(server_name="0.0.0.0", server_port=7860)