rapacious committed on
Commit
9a76509
·
verified ·
1 Parent(s): 1ec8046

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -24
app.py CHANGED
@@ -1,8 +1,10 @@
1
- import uvicorn
2
  from fastapi import FastAPI, HTTPException, Request
3
  from pydantic import BaseModel
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
 
 
6
 
7
  # Khởi tạo FastAPI
8
  app = FastAPI()
@@ -22,58 +24,87 @@ except Exception as e:
22
  print(f"Error loading model: {e}")
23
  raise
24
 
25
- # Định nghĩa request body
26
  class TextInput(BaseModel):
27
  prompt: str
28
  max_length: int = 100
29
 
30
- # API endpoint để sinh văn bản
31
- @app.post("/generate")
32
- async def generate_text(input: TextInput):
33
  try:
34
- # hóa đầu vào
35
- inputs = tokenizer(input.prompt, return_tensors="pt").to(model.device)
36
-
37
- # Sinh văn bản
38
  outputs = model.generate(
39
  inputs["input_ids"],
40
- max_length=input.max_length,
41
  num_return_sequences=1,
42
  no_repeat_ngram_size=2,
43
  do_sample=True,
44
  top_k=50,
45
  top_p=0.95
46
  )
47
-
48
- # Giải kết quả
49
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
50
- return {"generated_text": generated_text}
 
 
 
 
 
 
51
  except Exception as e:
52
  raise HTTPException(status_code=500, detail=str(e))
53
 
54
- # Endpoint kiểm tra sức khỏe
55
  @app.get("/")
56
  async def root():
57
  return {"message": "Qwen2.5-0.5B API is running!"}
58
 
59
- # Endpoint hiển thị API URL đầy đủ
60
  @app.get("/api_link")
61
  async def get_api_link(request: Request):
62
- # Lấy host từ request
63
- host = request.client.host
64
- # Lấy port từ server (nếu chạy local thì mặc định là 7860)
65
- port = request.url.port if request.url.port else 7860
66
- # Tạo URL đầy đủ
67
- base_url = f"http://{host}:{port}"
68
  return {
69
  "api_url": base_url,
70
  "endpoints": {
71
  "health_check": f"{base_url}/",
72
  "generate_text": f"{base_url}/generate",
73
- "api_link": f"{base_url}/api_link"
 
74
  }
75
  }
76
 
77
- # Chạy server khi file được gọi trực tiếp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  if __name__ == "__main__":
79
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import gradio as gr
2
  from fastapi import FastAPI, HTTPException, Request
3
  from pydantic import BaseModel
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
6
+ import uvicorn
7
+ from fastapi.responses import HTMLResponse
8
 
9
  # Initialize the FastAPI application object that the route decorators below attach to
10
  app = FastAPI()
 
24
  print(f"Error loading model: {e}")
25
  raise
26
 
27
+ # Request body schema for the text-generation API (validated by pydantic)
28
  class TextInput(BaseModel):
29
  prompt: str  # text handed to the tokenizer as the generation prompt
30
  max_length: int = 100  # forwarded to model.generate as max_length; defaults to 100
31
 
32
# Text-generation helper shared by the REST API and the Gradio UI
def generate_text(prompt, max_length=100):
    """Generate a sampled continuation of ``prompt`` with the loaded model.

    Args:
        prompt: Text to tokenize and feed to the model.
        max_length: Forwarded to ``model.generate`` as ``max_length``. Any
            value accepted by ``int()`` is allowed; a float (which a UI
            slider may deliver) is truncated toward zero instead of being
            passed through unchanged.

    Returns:
        The first generated sequence, decoded with special tokens skipped.

    Raises:
        ValueError: If ``max_length`` cannot be converted to a positive
            integer. Raised before the tokenizer or model is touched.
        RuntimeError: If tokenization, generation or decoding fails. The
            message keeps the ``"Error: ..."`` format and the original
            exception is chained as ``__cause__``.
    """
    # Validate up front so bad input gets a clear message instead of an
    # opaque failure from deep inside the generation call.
    try:
        token_budget = int(max_length)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(
            f"max_length must be a positive integer, got {max_length!r}"
        ) from exc
    if token_budget <= 0:
        raise ValueError(
            f"max_length must be a positive integer, got {max_length!r}"
        )

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            # Unpack the whole encoding so that whatever the tokenizer
            # produced alongside input_ids (e.g. an attention mask) reaches
            # generate() too, rather than input_ids alone.
            **inputs,
            max_length=token_budget,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # RuntimeError is still an Exception, so existing
        # `except Exception` callers keep working; `from e` preserves the
        # original traceback for debugging.
        raise RuntimeError(f"Error: {e}") from e
48
+
49
# API endpoint for text generation
@app.post("/generate")
async def generate_text_api(input: TextInput):
    """Handle ``POST /generate``: run text generation for the request body.

    Args:
        input: Validated request body carrying ``prompt`` and ``max_length``.

    Returns:
        A dict of the form ``{"generated_text": <str>}``.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            generation fails for any reason.
    """
    # Imported here so the handler is self-contained; FastAPI is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        # generate_text is synchronous and blocks for the whole generation.
        # Calling it directly inside this coroutine would stall the event
        # loop (and with it every other request, including the health
        # check), so it is executed on a worker thread instead.
        result = await run_in_threadpool(
            generate_text, input.prompt, input.max_length
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"generated_text": result}
57
 
58
# Health-check API endpoint
@app.get("/")
async def root():
    """Answer ``GET /`` with a fixed message showing the service is up."""
    status = {"message": "Qwen2.5-0.5B API is running!"}
    return status
62
 
63
# API endpoint that reports the service's own URLs
@app.get("/api_link")
async def get_api_link(request: Request):
    """Return the base URL of this service and the URL of each endpoint.

    The base URL is derived from the URL of the incoming request, so it
    reflects the scheme, host and port the client actually used.

    Args:
        request: The incoming request; only ``request.url`` is read.

    Returns:
        A dict with ``"api_url"`` (the base URL) and ``"endpoints"``, a
        mapping of endpoint name to absolute URL.
    """
    url = request.url
    # netloc is "host" or "host:port" exactly as it appears in the URL. Unlike
    # hostname, it keeps the square brackets around IPv6 literals, so the
    # rebuilt URL stays valid for hosts such as "[::1]:7860". The port is
    # present only when the request URL carries one, which matches the
    # previous explicit port check.
    base_url = f"{url.scheme}://{url.netloc}"

    endpoint_paths = {
        "health_check": "/",
        "generate_text": "/generate",
        "api_link": "/api_link",
        "interface": "/interface",
    }
    return {
        "api_url": base_url,
        "endpoints": {
            name: f"{base_url}{path}" for name, path in endpoint_paths.items()
        },
    }
81
 
82
# Build the Gradio interface
def create_gradio_interface():
    """Assemble and return the Gradio Blocks UI for the text generator.

    The page contains two Markdown headers, a row with the prompt box and
    the max-length slider, a "Generate" button and a read-only output box.
    Clicking the button calls ``generate_text`` with the prompt and slider
    values and writes the result to the output box.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    # NOTE(review): the nesting below is reconstructed from a flattened
    # listing; the button and output box are assumed to sit below the row.
    ui = gr.Blocks(title="Qwen2.5-0.5B Text Generator")
    with ui:
        gr.Markdown(value="# Qwen2.5-0.5B Text Generator")
        gr.Markdown(value="Enter a prompt and get generated text!")

        with gr.Row():
            prompt_box = gr.Textbox(
                label="Prompt",
                placeholder="Type something...",
            )
            length_slider = gr.Slider(
                minimum=50,
                maximum=500,
                value=100,
                step=10,
                label="Max Length",
            )

        run_button = gr.Button(value="Generate")
        result_box = gr.Textbox(label="Generated Text", interactive=False)

        # Wire the button: (prompt, max length) -> generate_text -> output box.
        click_inputs = [prompt_box, length_slider]
        run_button.click(fn=generate_text, inputs=click_inputs, outputs=result_box)
    return ui
101
+
102
# Serve the Gradio UI under /interface.
#
# Blocks.render() places a layout into an enclosing Blocks context; it does
# not produce an HTML page, so returning its result through HTMLResponse
# cannot serve the UI. Gradio's supported way to expose a Blocks app from an
# existing FastAPI application is to mount it as a sub-application. Mounting
# once at import time also avoids rebuilding the whole interface on every
# request. The UI is reachable at the same path as before ("/interface").
app = gr.mount_gradio_app(app, create_gradio_interface(), path="/interface")
107
+
108
+ # Start a uvicorn server on 0.0.0.0:7860 when this file is executed directly (original note: for running outside Hugging Face Spaces)
109
  if __name__ == "__main__":
110
  uvicorn.run(app, host="0.0.0.0", port=7860)