ruslanmv committed
Commit e9de53d · 1 Parent(s): b9b8383

First commit

Files changed (2)
  1. flux_app/enhance.py +70 -43
  2. flux_app/enhance_v2.py +55 -0
flux_app/enhance.py CHANGED
@@ -1,55 +1,82 @@
-# flux_app/enhance.py
 import time
-from huggingface_hub import InferenceClient
-import gradio as gr
+import requests
+import json

-# Initialize the inference client with the new LLM
-client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
+    """
+    Generates an enhanced prompt using the streaming inference mechanism from a Hugging Face API endpoint.
+    This function formats the prompt with a system instruction, sends a streaming request to the API,
+    and yields the accumulated text as tokens are received.

-# Define the system prompt for enhancing user prompts
-SYSTEM_PROMPT = (
-    "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
-    "without changing the essence, only write the enhanced prompt and nothing else."
-)
+    Parameters:
+        message (str): The user's input prompt.
+        max_new_tokens (int): The maximum number of tokens to generate.
+        temperature (float): Sampling temperature.
+        top_p (float): Nucleus sampling parameter.
+        repetition_penalty (float): Penalty factor for repetition (not used in the payload but kept for API consistency).

-def format_prompt(message):
-    """
-    Format the input message using the system prompt and a timestamp to ensure uniqueness.
+    Yields:
+        str: The accumulated generated text as it streams in.
     """
+    # Define the system prompt.
+    SYSTEM_PROMPT = (
+        "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
+        "without changing the essence, only write the enhanced prompt and nothing else."
+    )
+    # Format the prompt with a timestamp for uniqueness.
     timestamp = time.time()
-    formatted = (
+    formatted_prompt = (
         f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST]"
         f"[INST] {message} {timestamp} [/INST]"
     )
-    return formatted
-
-def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
-    """
-    Generate an enhanced prompt using the new LLM.
-    This function yields intermediate results as they are generated.
-    """
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(top_p)
-    generate_kwargs = {
+
+    # Define the API endpoint and headers.
+    api_url = "https://ruslanmv-hf-llm-api.hf.space/api/v1/chat/completions"
+    headers = {"Content-Type": "application/json"}
+
+    # Build the payload for the inference request.
+    payload = {
+        "model": "mixtral-8x7b",
+        "messages": [{"role": "user", "content": formatted_prompt}],
         "temperature": temperature,
-        "max_new_tokens": int(max_new_tokens),
         "top_p": top_p,
-        "repetition_penalty": float(repetition_penalty),
-        "do_sample": True,
+        "max_tokens": max_new_tokens,
+        "use_cache": False,
+        "stream": True
     }
-    formatted_prompt = format_prompt(message)
-    stream = client.text_generation(
-        formatted_prompt,
-        **generate_kwargs,
-        stream=True,
-        details=True,
-        return_full_text=False,
-    )
-    output = ""
-    for response in stream:
-        token_text = response.token.text
-        output += token_text
-        yield output.strip('</s>')
-    return output.strip('</s>')
+
+    try:
+        response = requests.post(api_url, headers=headers, json=payload, stream=True)
+        response.raise_for_status()
+        full_output = ""
+
+        # Process the streaming response line by line.
+        for line in response.iter_lines():
+            if not line:
+                continue
+
+            decoded_line = line.decode("utf-8").strip()
+            # Remove the "data:" prefix if present.
+            if decoded_line.startswith("data:"):
+                decoded_line = decoded_line[len("data:"):].strip()
+
+            # Check if the stream is finished.
+            if decoded_line == "[DONE]":
+                break
+
+            try:
+                json_data = json.loads(decoded_line)
+                for choice in json_data.get("choices", []):
+                    delta = choice.get("delta", {})
+                    content = delta.get("content", "")
+                    full_output += content
+                    yield full_output  # Yield the accumulated text so far.
+
+                    # If the finish reason is provided, stop further streaming.
+                    if choice.get("finish_reason") == "stop":
+                        return
+            except json.JSONDecodeError:
+                # If a line is not valid JSON, skip it.
+                continue
+    except requests.exceptions.RequestException as e:
+        yield f"Error during generation: {str(e)}"
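Note: the rewritten enhance.generate above is a Python generator that yields the accumulated text after each streamed token. A minimal consumption sketch (not part of this commit, assuming flux_app/enhance.py is importable and the ruslanmv-hf-llm-api Space is reachable; the example prompt is arbitrary):

from flux_app.enhance import generate

final_text = ""
# Each yielded value is the full text accumulated so far, so the last one is the complete result.
for partial in generate("a cat sitting on a windowsill", max_new_tokens=128):
    final_text = partial
print(final_text)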
flux_app/enhance_v2.py ADDED
@@ -0,0 +1,55 @@
+# flux_app/enhance.py
+import time
+from huggingface_hub import InferenceClient
+import gradio as gr
+
+# Initialize the inference client with the new LLM
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+
+# Define the system prompt for enhancing user prompts
+SYSTEM_PROMPT = (
+    "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
+    "without changing the essence, only write the enhanced prompt and nothing else."
+)
+
+def format_prompt(message):
+    """
+    Format the input message using the system prompt and a timestamp to ensure uniqueness.
+    """
+    timestamp = time.time()
+    formatted = (
+        f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST]"
+        f"[INST] {message} {timestamp} [/INST]"
+    )
+    return formatted
+
+def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
+    """
+    Generate an enhanced prompt using the new LLM.
+    This function yields intermediate results as they are generated.
+    """
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+    generate_kwargs = {
+        "temperature": temperature,
+        "max_new_tokens": int(max_new_tokens),
+        "top_p": top_p,
+        "repetition_penalty": float(repetition_penalty),
+        "do_sample": True,
+    }
+    formatted_prompt = format_prompt(message)
+    stream = client.text_generation(
+        formatted_prompt,
+        **generate_kwargs,
+        stream=True,
+        details=True,
+        return_full_text=False,
+    )
+    output = ""
+    for response in stream:
+        token_text = response.token.text
+        output += token_text
+        yield output.strip('</s>')
+    return output.strip('</s>')
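Note: enhance_v2.py keeps the huggingface_hub InferenceClient path and imports gradio without building an interface in this file. A rough sketch of how its streaming generate could be wired into a Gradio UI (illustrative only, not part of this commit; the component layout is an assumption):

import gradio as gr
from flux_app.enhance_v2 import generate

# Gradio streams each value yielded by a generator function into the output component.
with gr.Blocks() as demo:
    prompt_in = gr.Textbox(label="Prompt")
    prompt_out = gr.Textbox(label="Enhanced prompt")
    gr.Button("Enhance").click(fn=generate, inputs=prompt_in, outputs=prompt_out)

demo.launch()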