- app.py +60 -2
- requirements.txt +2 -1
app.py
CHANGED
@@ -19,6 +19,52 @@ pipe = pipe.to(device)
|
|
19 |
|
20 |
MAX_SEED = np.iinfo(np.int32).max
|
21 |
MAX_IMAGE_SIZE = 768 # 减小最大尺寸以提高生成速度
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# 熊猫烧香相关的提示词模板
|
24 |
PANDA_INCENSE_PROMPTS = [
|
@@ -45,13 +91,25 @@ def infer(
|
|
45 |
num_inference_steps,
|
46 |
progress=gr.Progress(track_tqdm=True),
|
47 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
if randomize_seed:
|
49 |
seed = random.randint(0, MAX_SEED)
|
|
|
|
|
|
|
50 |
|
51 |
generator = torch.Generator().manual_seed(seed)
|
|
|
52 |
|
|
|
53 |
image = pipe(
|
54 |
-
prompt=
|
55 |
negative_prompt=negative_prompt,
|
56 |
guidance_scale=guidance_scale,
|
57 |
num_inference_steps=num_inference_steps,
|
@@ -59,7 +117,7 @@ def infer(
|
|
59 |
height=height,
|
60 |
generator=generator,
|
61 |
).images[0]
|
62 |
-
|
63 |
return image, seed
|
64 |
|
65 |
|
|
|
19 |
|
20 |
MAX_SEED = np.iinfo(np.int32).max
|
21 |
MAX_IMAGE_SIZE = 768 # 减小最大尺寸以提高生成速度
|
22 |
+
# Prompt optimization: call the Gemini API to rewrite the prompt for text-to-image use.
import os


def optimize_prompt(en_text):
    """Ask Gemini to rewrite *en_text* into a concise, detail-rich English
    text-to-image prompt.

    Returns the optimized prompt, or *en_text* unchanged when the API key is
    not configured or the request fails — image generation is never blocked
    by the optimizer.
    """
    # SECURITY: the key was previously hard-coded in source (a leaked secret).
    # Read it from the environment instead; skip optimization when absent.
    api_key = os.environ.get("GEMINI_API_KEY", "")
    if not api_key:
        print("[LOG] GEMINI_API_KEY 未设置,跳过提示词优化")
        return en_text
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
    headers = {
        "Content-Type": "application/json",
        "X-goog-api-key": api_key
    }
    data = {
        "contents": [{
            "parts": [{
                "text": f"请将以下英文提示词优化为适合AI文生图生成的英文提示词,要求简洁、描述清晰、突出画面细节:{en_text}"
            }]
        }]
    }
    print(f"[LOG] 优化提示词请求: {en_text}")
    try:
        # Lazy import: the module stays importable even if requests is missing.
        import requests
        resp = requests.post(url, headers=headers, json=data, timeout=10)
        print(f"[LOG] Gemini API响应状态: {resp.status_code}")
        resp.raise_for_status()
        result = resp.json()
        print(f"[LOG] Gemini API响应内容: {result}")
        opt_text = result["candidates"][0]["content"]["parts"][0]["text"]
        print(f"[LOG] 优化后提示词: {opt_text}")
        return opt_text.strip()
    except Exception as e:
        # Best-effort: any network/API/parse failure falls back to the input.
        print(f"[ERROR] Gemini优化失败: {e}")
        return en_text
52 |
+
# Auto-translate Chinese prompts to English before they reach the diffusion pipeline.
from functools import lru_cache


@lru_cache(maxsize=1)
def _load_zh_en_translator():
    """Load the MarianMT zh->en tokenizer/model pair exactly once.

    The original code re-instantiated (and potentially re-downloaded) the
    model on every call — this caches the pair for the process lifetime.
    `transformers` is imported lazily so the heavy dependency is only pulled
    in when a Chinese prompt is actually seen.
    """
    from transformers import MarianMTModel, MarianTokenizer
    model_name = 'Helsinki-NLP/opus-mt-zh-en'
    return MarianTokenizer.from_pretrained(model_name), MarianMTModel.from_pretrained(model_name)


def translate_prompt(text):
    """Translate *text* to English when it contains CJK characters;
    otherwise return it unchanged.
    """
    # CJK Unified Ideographs range marks the prompt as Chinese.
    if any('\u4e00' <= ch <= '\u9fff' for ch in text):
        print(f"[LOG] 检测到中文提示词: {text}")
        tokenizer, model = _load_zh_en_translator()
        translated = model.generate(**tokenizer(text, return_tensors="pt", padding=True))
        en_text = tokenizer.decode(translated[0], skip_special_tokens=True)
        print(f"[LOG] 翻译后的英文提示词: {en_text}")
        return en_text
    print(f"[LOG] 非中文提示词,直接使用: {text}")
    return text
68 |
|
69 |
# 熊猫烧香相关的提示词模板
|
70 |
PANDA_INCENSE_PROMPTS = [
|
|
|
91 |
num_inference_steps,
|
92 |
progress=gr.Progress(track_tqdm=True),
|
93 |
):
|
94 |
+
# 自动翻译中文提示词为英文
|
95 |
+
print(f"[LOG] 用户输入提示词: {prompt}")
|
96 |
+
prompt_en = translate_prompt(prompt)
|
97 |
+
print(f"[LOG] 英文提示词: {prompt_en}")
|
98 |
+
prompt_opt = optimize_prompt(prompt_en)
|
99 |
+
print(f"[LOG] 最终用于生成的提示词: {prompt_opt}")
|
100 |
+
|
101 |
if randomize_seed:
|
102 |
seed = random.randint(0, MAX_SEED)
|
103 |
+
print(f"[LOG] 随机生成种子: {seed}")
|
104 |
+
else:
|
105 |
+
print(f"[LOG] 使用用户指定种子: {seed}")
|
106 |
|
107 |
generator = torch.Generator().manual_seed(seed)
|
108 |
+
print(f"[LOG] 生成器初始化完成")
|
109 |
|
110 |
+
print(f"[LOG] 开始生成图片,参数: guidance_scale={guidance_scale}, steps={num_inference_steps}, width={width}, height={height}")
|
111 |
image = pipe(
|
112 |
+
prompt=prompt_opt,
|
113 |
negative_prompt=negative_prompt,
|
114 |
guidance_scale=guidance_scale,
|
115 |
num_inference_steps=num_inference_steps,
|
|
|
117 |
height=height,
|
118 |
generator=generator,
|
119 |
).images[0]
|
120 |
+
print(f"[LOG] 图片生成完成")
|
121 |
return image, seed
|
122 |
|
123 |
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ invisible_watermark
|
|
4 |
torch
|
5 |
transformers
|
6 |
xformers
|
7 |
-
gradio
|
|
|
|
4 |
torch
|
5 |
transformers
|
6 |
xformers
|
7 |
+
gradio
|
8 |
+
requests
|