MrDonStuff committed (verified)
Commit 20c1d60 · 1 Parent(s): 3fbfb42

Update app.py

Files changed (1):
1. app.py +26 -27
app.py CHANGED
@@ -1,21 +1,15 @@
-from flask import Flask, render_template, request, jsonify
+from flask import Flask, request, jsonify
 from huggingface_hub import InferenceClient
-import gradio as gr
 
 app = Flask(__name__)
+app.config["DEBUG"] = True  # Enable for debugging
 
+# Load model client
 client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
-
+# Function for text generation with enhanced prompt formatting
 def generate(
-    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0
 ):
     temperature = float(temperature)
     if temperature < 1e-2:
@@ -31,29 +25,34 @@ def generate(
         seed=42,
     )
 
-    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+    # Enhanced prompt formatting for better context
+    formatted_prompt = f"{system_prompt}\n{''.join(f'{user_prompt} ||| {bot_response}\n' for user_prompt, bot_response in history)}\n{prompt}"
+
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
 
     for response in stream:
         output += response.token.text
-        yield output
     return output
 
-@app.route('/generate', methods=['POST'])
-def generate_response():
-    data = request.get_json()
-    prompt = data['prompt']
-    history = data['history']
-    system_prompt = data['system_prompt']
-    temperature = data['temperature']
-    max_new_tokens = data['max_new_tokens']
-    top_p = data['top_p']
-    repetition_penalty = data['repetition_penalty']
 
-    result = list(generate(prompt, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty))
+@app.route("/generate", methods=["POST"])
+def generate_text():
+    data = request.json
+    prompt = data.get("prompt")
+    history = data.get("history", [])
+    system_prompt = data.get("system_prompt")
+    temperature = data.get("temperature", 0.9)
+    max_new_tokens = data.get("max_new_tokens", 256)
+    top_p = data.get("top_p", 0.95)
+    repetition_penalty = data.get("repetition_penalty", 1.0)
+
+    response = generate(
+        prompt, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty
+    )
+
+    return jsonify({"response": response})
 
-    return jsonify({'result': result})
 
-if __name__ == '__main__':
-    app.run(port=7860)
+if __name__ == "__main__":
+    app.run(host="0.0.0.0", port=7860)
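
For illustration, the snippet below (not part of the commit) reproduces the layout that the new inline formatted_prompt produces, using hypothetical conversation values. It avoids nesting one f-string inside another f-string's expression, since the backslash in the committed one-liner is only accepted there from Python 3.12 onward.

# Illustration only: rebuilds the new prompt layout with hypothetical values.
system_prompt = "You are a helpful assistant."
history = [("Hello", "Hi! How can I help?")]
prompt = "What is the capital of France?"

# Equivalent to the inline f-string in app.py, written without putting
# an f-string containing \n inside another f-string's expression.
formatted_prompt = (
    f"{system_prompt}\n"
    + "".join(f"{user_turn} ||| {bot_turn}\n" for user_turn, bot_turn in history)
    + f"\n{prompt}"
)
print(formatted_prompt)
# You are a helpful assistant.
# Hello ||| Hi! How can I help?
#
# What is the capital of France?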
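
A minimal smoke test for the new POST /generate route might look like the following sketch; it is not part of the commit and assumes app.py is running locally on port 7860 with the third-party requests package installed.

import requests  # third-party HTTP client, assumed installed

# Hypothetical request body; each key mirrors what generate_text() reads.
payload = {
    "prompt": "What is the capital of France?",
    "history": [["Hello", "Hi! How can I help?"]],
    "system_prompt": "You are a helpful assistant.",
    "temperature": 0.9,
    "max_new_tokens": 256,
    "top_p": 0.95,
    "repetition_penalty": 1.0,
}

# request.json on the Flask side expects a JSON body; `json=` sends one
# and sets the Content-Type header accordingly.
resp = requests.post("http://localhost:7860/generate", json=payload)
resp.raise_for_status()
print(resp.json()["response"])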