gospacedev committed on
Commit 1b34aa5 · 1 Parent(s): 9d744bd

create formatted chat history

Files changed (1)
  1. app.py +21 -10
app.py CHANGED
@@ -8,10 +8,16 @@ from huggingface_hub import InferenceClient
 
 
 ASR_MODEL_NAME = "openai/whisper-small"
-NLP_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
-system_prompt = """"<s> [INST] You are Friday a helpful and conversational assistant. [/INST]"""
+LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
 
-client = InferenceClient(NLP_MODEL_NAME)
+
+system_prompt = """"<s>[INST] You are Friday, a helpful and conversational AI assistant and You respond with one to two sentences. [/INST] Hello there! I'm friday how can I help you?</s>"""
+
+chat_history = system_prompt + """"""
+
+formatted_history = """"""
+
+client = InferenceClient(LLM_MODEL_NAME)
 
 device = 0 if torch.cuda.is_available() else "cpu"
 
@@ -22,7 +28,7 @@ pipe = pipeline(
 )
 
 
-def generate(prompt, temperature=0.1, max_new_tokens=64, top_p=0.95, repetition_penalty=1.0):
+def generate(user_prompt, temperature=0.1, max_new_tokens=128, top_p=0.95, repetition_penalty=1.0):
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
@@ -37,10 +43,10 @@ def generate(prompt, temperature=0.1, max_new_tokens=64, top_p=0.95, repetition_
         seed=42,
     )
 
-    formatted_prompt = system_prompt + f""" {prompt} </s>"""
+    chat_history += f""" <s>[INST] {user_prompt} [/INST] """
 
     output = client.text_generation(
-        formatted_prompt, **generate_kwargs, stream=False, details=False, return_full_text=False)
+        chat_history, **generate_kwargs, stream=False, details=False, return_full_text=False)
 
     print(output)
     return output
@@ -54,13 +60,18 @@ def transcribe(audio):
 
     inputs = pipe({"sampling_rate": sr, "raw": y})["text"]
 
-    print("User transcription: ", inputs)
+    formatted_history += f"""Human: {inputs}\n"""
+
+    llm_response = generate(inputs)
+
+    chat_history += f""" {llm_response}</s>"""
+
+    formatted_history += f"""Friday: {llm_response}\n"""
 
-    response = generate(inputs)
-    audio_response = gTTS(response)
+    audio_response = gTTS(llm_response)
     audio_response.save("response.mp3")
 
-    print(audio_response)
+    print(formatted_history)
 
     return "response.mp3"
 
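For readers unfamiliar with the prompt format, here is a minimal sketch of how the accumulated chat_history evolves under the Mistral-7B-Instruct chat template used in the diff. fake_reply is a hypothetical stand-in for the client.text_generation() call, and the quadruple-quoted literals from the commit are simplified to plain strings; none of this code is part of the commit itself.

# Sketch: multi-turn prompt accumulation in the Mistral instruct format.
# fake_reply is a hypothetical stand-in for client.text_generation().

system_prompt = ("<s>[INST] You are Friday, a helpful and conversational AI "
                 "assistant. [/INST] Hello there! I'm Friday, how can I help you?</s>")

chat_history = system_prompt  # full prompt sent to the model each turn
formatted_history = ""        # human-readable transcript for logging

def fake_reply(prompt: str) -> str:
    # Hypothetical stand-in for the hosted Mistral-7B-Instruct model.
    return "Happy to help!"

for user_prompt in ["What can you do?", "Thank you."]:
    # Each user turn opens a new <s>[INST] ... [/INST] block.
    chat_history += f" <s>[INST] {user_prompt} [/INST] "
    formatted_history += f"Human: {user_prompt}\n"

    llm_response = fake_reply(chat_history)

    # The assistant's answer closes the turn with </s>, so the next
    # request carries the whole conversation as context.
    chat_history += f" {llm_response}</s>"
    formatted_history += f"Friday: {llm_response}\n"

print(formatted_history)

Because chat_history grows with every exchange, long sessions will eventually exceed the model's context window; the separate formatted_history string exists purely so the printed log reads as alternating Human:/Friday: lines.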