fujie170 committed
Commit c4f35db · 1 Parent(s): d15aade
Files changed (2):
  1. app.py +60 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -19,6 +19,52 @@ pipe = pipe.to(device)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 768  # Reduce the maximum size to speed up generation
+# Prompt optimization: call the Gemini Pro API
+import requests
+
+def optimize_prompt(en_text):
+    api_key = "AIzaSyB8Qu7XLzR6vnmnBN19z2cAXVRrJYjr2KY"
+    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
+    headers = {
+        "Content-Type": "application/json",
+        "X-goog-api-key": api_key
+    }
+    data = {
+        "contents": [{
+            "parts": [{
+                "text": f"Please optimize the following English prompt into an English prompt suitable for AI text-to-image generation; it should be concise, clearly descriptive, and emphasize visual details: {en_text}"
+            }]
+        }]
+    }
+    print(f"[LOG] Prompt optimization request: {en_text}")
+    try:
+        resp = requests.post(url, headers=headers, json=data, timeout=10)
+        print(f"[LOG] Gemini API response status: {resp.status_code}")
+        resp.raise_for_status()
+        result = resp.json()
+        print(f"[LOG] Gemini API response body: {result}")
+        opt_text = result["candidates"][0]["content"]["parts"][0]["text"]
+        print(f"[LOG] Optimized prompt: {opt_text}")
+        return opt_text.strip()
+    except Exception as e:
+        print(f"[ERROR] Gemini optimization failed: {e}")
+        return en_text
+# Add automatic translation of Chinese prompts
+from transformers import MarianMTModel, MarianTokenizer
+
+def translate_prompt(text):
+    # Check whether the text contains Chinese characters
+    if any('\u4e00' <= ch <= '\u9fff' for ch in text):
+        print(f"[LOG] Chinese prompt detected: {text}")
+        model_name = 'Helsinki-NLP/opus-mt-zh-en'
+        tokenizer = MarianTokenizer.from_pretrained(model_name)
+        model = MarianMTModel.from_pretrained(model_name)
+        translated = model.generate(**tokenizer(text, return_tensors="pt", padding=True))
+        en_text = tokenizer.decode(translated[0], skip_special_tokens=True)
+        print(f"[LOG] Translated English prompt: {en_text}")
+        return en_text
+    print(f"[LOG] Non-Chinese prompt, used as-is: {text}")
+    return text
 
 # Prompt templates for the "Panda Burning Incense" theme
 PANDA_INCENSE_PROMPTS = [
@@ -45,13 +91,25 @@ def infer(
     num_inference_steps,
     progress=gr.Progress(track_tqdm=True),
 ):
+    # Automatically translate Chinese prompts into English
+    print(f"[LOG] User input prompt: {prompt}")
+    prompt_en = translate_prompt(prompt)
+    print(f"[LOG] English prompt: {prompt_en}")
+    prompt_opt = optimize_prompt(prompt_en)
+    print(f"[LOG] Final prompt used for generation: {prompt_opt}")
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
+        print(f"[LOG] Randomly generated seed: {seed}")
+    else:
+        print(f"[LOG] Using user-specified seed: {seed}")
 
     generator = torch.Generator().manual_seed(seed)
+    print(f"[LOG] Generator initialized")
 
+    print(f"[LOG] Starting image generation, parameters: guidance_scale={guidance_scale}, steps={num_inference_steps}, width={width}, height={height}")
     image = pipe(
-        prompt=prompt,
+        prompt=prompt_opt,
         negative_prompt=negative_prompt,
         guidance_scale=guidance_scale,
         num_inference_steps=num_inference_steps,
@@ -59,7 +117,7 @@ def infer(
         height=height,
         generator=generator,
     ).images[0]
-
+    print(f"[LOG] Image generation finished")
     return image, seed
 
 
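A note on the new helpers, not part of this commit: the Gemini API key is committed in plain text inside app.py, and translate_prompt reloads the MarianMT model on every call. A minimal sketch of an alternative, assuming the key lives in an environment variable / Space secret named GEMINI_API_KEY (a name of my own choosing) and the translation model is cached after the first load:

import os
from functools import lru_cache

from transformers import MarianMTModel, MarianTokenizer

# Hypothetical secret name; on Hugging Face Spaces, repository secrets are
# exposed to the running app as environment variables.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")

@lru_cache(maxsize=1)
def get_zh_en_translator():
    # Load the Helsinki-NLP zh->en model once and reuse it across requests.
    model_name = "Helsinki-NLP/opus-mt-zh-en"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model

def translate_prompt(text):
    # Same check as the commit: translate only if Chinese characters are present.
    if any("\u4e00" <= ch <= "\u9fff" for ch in text):
        tokenizer, model = get_zh_en_translator()
        batch = tokenizer(text, return_tensors="pt", padding=True)
        translated = model.generate(**batch)
        return tokenizer.decode(translated[0], skip_special_tokens=True)
    return text

Caching the tokenizer and model keeps the per-request cost to a single generate() call instead of a full model download/load each time.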
requirements.txt CHANGED
@@ -4,4 +4,5 @@ invisible_watermark
 torch
 transformers
 xformers
-gradio
+gradio
+requests