rapacious commited on
Commit
9a46859
·
verified ·
1 Parent(s): 8a73b48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -12
app.py CHANGED
@@ -6,6 +6,9 @@ import torch
6
  model_name = "Qwen/Qwen2.5-0.5B"
7
  try:
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
  torch_dtype="auto",
@@ -18,21 +21,29 @@ except Exception as e:
18
  raise
19
 
20
  # Hàm sinh văn bản (dùng cho cả UI và API)
21
- def generate_text(prompt, max_length=100):
22
  try:
23
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
24
  outputs = model.generate(
25
- inputs["input_ids"],
 
26
  max_length=max_length,
27
  num_return_sequences=1,
28
  no_repeat_ngram_size=2,
29
  do_sample=True,
30
  top_k=50,
31
- top_p=0.95
 
32
  )
33
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
34
  except Exception as e:
35
- return f"Error: {str(e)}"
 
 
36
 
37
  # Hàm hiển thị thông tin API
38
  def get_api_info():
@@ -56,6 +67,9 @@ with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
56
  gr.Markdown("# Qwen2.5-0.5B Text Generator")
57
  gr.Markdown("Enter a prompt below or use the API!")
58
 
 
 
 
59
  # Hiển thị thông tin API
60
  gr.Markdown("### API Information")
61
  api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)
@@ -67,29 +81,29 @@ with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
67
  max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
68
 
69
  generate_button = gr.Button("Generate")
70
- output_text = gr.Textbox(label="Generated Text", interactive=False)
71
 
72
  # Liên kết button với hàm generate_text
73
  generate_button.click(
74
  fn=generate_text,
75
- inputs=[prompt_input, max_length_input],
76
- outputs=output_text
77
  )
78
 
79
  # Định nghĩa API endpoints với Gradio
80
  interface = gr.Interface(
81
- fn=generate_text,
82
  inputs=["text", "number"],
83
  outputs="text",
84
  title="Qwen2.5-0.5B API",
85
- api_name="/generate" # API endpoint: /api/generate
86
  ).queue()
87
 
88
  health_interface = gr.Interface(
89
  fn=health_check,
90
  inputs=None,
91
  outputs="text",
92
- api_name="/health_check" # API endpoint: /api/health_check
93
  )
94
 
95
  # Gắn các interface vào demo
 
6
  model_name = "Qwen/Qwen2.5-0.5B"
7
  try:
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ # Đặt pad_token_id nếu chưa có
10
+ if tokenizer.pad_token_id is None:
11
+ tokenizer.pad_token_id = tokenizer.eos_token_id
12
  model = AutoModelForCausalLM.from_pretrained(
13
  model_name,
14
  torch_dtype="auto",
 
21
  raise
22
 
23
  # Hàm sinh văn bản (dùng cho cả UI và API)
24
+ def generate_text(prompt, max_length, state):
25
  try:
26
+ # hóa đầu vào với attention_mask
27
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
28
  outputs = model.generate(
29
+ input_ids=inputs["input_ids"],
30
+ attention_mask=inputs["attention_mask"],
31
  max_length=max_length,
32
  num_return_sequences=1,
33
  no_repeat_ngram_size=2,
34
  do_sample=True,
35
  top_k=50,
36
+ top_p=0.95,
37
+ pad_token_id=tokenizer.pad_token_id
38
  )
39
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
40
+ # Cập nhật state với kết quả mới
41
+ state.append(generated_text)
42
+ return state, generated_text # Trả về state và output để hiển thị
43
  except Exception as e:
44
+ error_msg = f"Error: {str(e)}"
45
+ state.append(error_msg)
46
+ return state, error_msg
47
 
48
  # Hàm hiển thị thông tin API
49
  def get_api_info():
 
67
  gr.Markdown("# Qwen2.5-0.5B Text Generator")
68
  gr.Markdown("Enter a prompt below or use the API!")
69
 
70
+ # State để lưu trữ lịch sử kết quả
71
+ state = gr.State(value=[]) # Khởi tạo state là danh sách rỗng
72
+
73
  # Hiển thị thông tin API
74
  gr.Markdown("### API Information")
75
  api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)
 
81
  max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
82
 
83
  generate_button = gr.Button("Generate")
84
+ output_text = gr.Textbox(label="Generated Text History", interactive=False, lines=10)
85
 
86
  # Liên kết button với hàm generate_text
87
  generate_button.click(
88
  fn=generate_text,
89
+ inputs=[prompt_input, max_length_input, state],
90
+ outputs=[state, output_text] # Cập nhật cả state và output_text
91
  )
92
 
93
  # Định nghĩa API endpoints với Gradio
94
  interface = gr.Interface(
95
+ fn=lambda prompt, max_length: generate_text(prompt, max_length, [])[1], # Chỉ lấy output, không dùng state cho API
96
  inputs=["text", "number"],
97
  outputs="text",
98
  title="Qwen2.5-0.5B API",
99
+ api_name="/generate"
100
  ).queue()
101
 
102
  health_interface = gr.Interface(
103
  fn=health_check,
104
  inputs=None,
105
  outputs="text",
106
+ api_name="/health_check"
107
  )
108
 
109
  # Gắn các interface vào demo