K00B404 commited on
Commit
91fe340
1 Parent(s): fd24536

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -5
app.py CHANGED
@@ -15,7 +15,7 @@ API_URL_DEV = "https://api-inference.huggingface.co/models/black-forest-labs/FLU
15
  API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
16
  timeout = 100
17
 
18
- def enhance_prompt(prompt):
19
  client = Client("K00B404/Mistral-Nemo-custom")
20
  result = client.predict(
21
  system_prompt="You are a image generation prompt enhancer and must respond only with the enhanced version of the users input prompt",
@@ -23,7 +23,26 @@ def enhance_prompt(prompt):
23
  api_name="/predict"
24
  )
25
  return result
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def query(prompt, is_negative=False, steps=30, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, huggingface_api_key=None, use_dev=False, enhance_prompt_option=False):
28
  # Determine which API URL to use
29
  api_url = API_URL_DEV if use_dev else API_URL
@@ -34,12 +53,13 @@ def query(prompt, is_negative=False, steps=30, cfg_scale=7, sampler="DPM++ 2M Ka
34
  if is_api_call:
35
  # Use the environment variable for the API key in GUI mode
36
  API_TOKEN = os.getenv("HF_READ_TOKEN")
37
- headers = {"Authorization": f"Bearer {API_TOKEN}"}
38
  else:
39
  # Validate the API key if it's an API call
40
  if huggingface_api_key == "":
41
  raise gr.Error("API key is required for API calls.")
42
- headers = {"Authorization": f"Bearer {huggingface_api_key}"}
 
 
43
 
44
  if prompt == "" or prompt is None:
45
  return None, None, None
@@ -53,7 +73,10 @@ def query(prompt, is_negative=False, steps=30, cfg_scale=7, sampler="DPM++ 2M Ka
53
  if enhance_prompt_option:
54
  prompt = enhance_prompt(prompt)
55
  print(f'\033[1mGeneration {key} enhanced prompt:\033[0m {prompt}')
56
-
 
 
 
57
  final_prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
58
  print(f'\033[1mGeneration {key}:\033[0m {final_prompt}')
59
 
 
15
  API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
16
  timeout = 100
17
 
18
+ def enhance_prompt(prompt, style="photorealistic"):
19
  client = Client("K00B404/Mistral-Nemo-custom")
20
  result = client.predict(
21
  system_prompt="You are a image generation prompt enhancer and must respond only with the enhanced version of the users input prompt",
 
23
  api_name="/predict"
24
  )
25
  return result
26
+
27
+ def mistral_nemo_call(prompt, model="mistralai/Mistral-Nemo-Instruct-2407"):
28
+
29
+ client = InferenceClient(api_key=API_TOKEN)
30
+ system_prompt=f"""
31
+ You are a image generation prompt enhancer specialized in the {style} style.
32
+ You must respond only with the enhanced version of the users input prompt
33
+ Remember, image generation models can be stimulated by refering to camera 'effect' in the prompt like :4k ,award winning, super details, 35mm lens, hd
34
+ """,
35
+
36
+ response = ""
37
+ for message in client.chat_completion(
38
+ model="mistralai/Mistral-Nemo-Instruct-2407",
39
+ messages=[{"role": "system", "content": system_prompt},{"role": "user", "content": prompt}],
40
+ max_tokens=500,
41
+ stream=True,
42
+ ):
43
+ response += message.choices[0].delta.content
44
+ return response
45
+
46
  def query(prompt, is_negative=False, steps=30, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, huggingface_api_key=None, use_dev=False, enhance_prompt_option=False):
47
  # Determine which API URL to use
48
  api_url = API_URL_DEV if use_dev else API_URL
 
53
  if is_api_call:
54
  # Use the environment variable for the API key in GUI mode
55
  API_TOKEN = os.getenv("HF_READ_TOKEN")
 
56
  else:
57
  # Validate the API key if it's an API call
58
  if huggingface_api_key == "":
59
  raise gr.Error("API key is required for API calls.")
60
+ API_TOKEN = huggingface_api_key
61
+
62
+ headers = {"Authorization": f"Bearer {API_TOKEN}"}
63
 
64
  if prompt == "" or prompt is None:
65
  return None, None, None
 
73
  if enhance_prompt_option:
74
  prompt = enhance_prompt(prompt)
75
  print(f'\033[1mGeneration {key} enhanced prompt:\033[0m {prompt}')
76
+ elif use_mistral_nemo:
77
+ prompt = mistral_nemo_call(prompt)
78
+ print(f'\033[1mGeneration {key} Mistral-Nemo prompt:\033[0m {prompt}')
79
+
80
  final_prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
81
  print(f'\033[1mGeneration {key}:\033[0m {final_prompt}')
82