amir22010 committed (verified)
Commit 092b591 · Parent(s): d27a194

Update app.py

Files changed (1): app.py (+41, -47)
app.py CHANGED
@@ -5,24 +5,34 @@ from groq import Groq
 import numpy as np
 import wave
 import uuid
+from nemoguardrails import LLMRails, RailsConfig
+from GoogleTTS import GoogleTTS
 
 #tts
-import torchaudio
+#import torchaudio
 #from speechbrain.inference.TTS import FastSpeech2
-from speechbrain.inference.TTS import Tacotron2
-from speechbrain.inference.vocoders import HIFIGAN
+# from speechbrain.inference.TTS import Tacotron2
+# from speechbrain.inference.vocoders import HIFIGAN
 
 #fastspeech2 = FastSpeech2.from_hparams(source="speechbrain/tts-fastspeech2-ljspeech", savedir="pretrained_models/tts-fastspeech2-ljspeech")
-tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir="tmpdir_tts")
-hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir="pretrained_models/tts-hifigan-ljspeech")
+# tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir="tmpdir_tts")
+# hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir="pretrained_models/tts-hifigan-ljspeech")
+
+#google tts
+tts = GoogleTTS()
 
 def text_to_speech(text):
-    mel_output, mel_length, alignment = tacotron2.encode_text(text)
+    # mel_output, mel_length, alignment = tacotron2.encode_text(text)
     # Running Vocoder (spectrogram-to-waveform)
-    waveforms = hifi_gan.decode_batch(mel_output)
+    # waveforms = hifi_gan.decode_batch(mel_output)
     # Save the waverform
     outfile = f"{os.path.join(os.getcwd(), str(uuid.uuid4()))}.wav"
-    torchaudio.save(outfile, waveforms.squeeze(1), 22050)
+    # torchaudio.save(outfile, waveforms.squeeze(1), 22050)
+    if len(text) > 5000:
+        text_str = text[0:4999]
+    else:
+        text_str = text
+    ret = tts.tts(text_str, outfile)
     return outfile
 
 def combine_audio_files(audio_files):
@@ -55,7 +65,7 @@ llm = Llama.from_pretrained(
 )
 
 #guardrail model
-guard_llm = "llama-3.1-8b-instant"
+guard_llm = "llama-3.2-11b-text-preview"
 
 #marketing prompt
 marketing_email_prompt = """Below is a product and description, please write a marketing email for this product.
@@ -69,48 +79,32 @@ marketing_email_prompt = """Below is a product and description, please write a m
 ### Marketing Email:
 {}"""
 
-#gaurdrails prompt
-guardrail_prompt = """You're given a list of moderation categories as below:
-
-- illegal: Illegal activity.
-- child abuse: child sexual abuse material or any content that exploits or harms children.
-- hate violence harassment: Generation of hateful, harassing, or violent content: content that expresses, incites, or promotes hate based on identity, content that intends to harass, threaten, or bully an individual, content that promotes or glorifies violence or celebrates the suffering or humiliation of others.
-- malware: Generation of malware: content that attempts to generate code that is designed to disrupt, damage, or gain unauthorized access to a computer system.
-- physical harm: activity that has high risk of physical harm, including: weapons development, military and warfare, management or operation of critical infrastructure in energy, transportation, and water, content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders.
-- economic harm: activity that has high risk of economic harm, including: multi-level marketing, gambling, payday lending, automated determinations of eligibility for credit, employment, educational institutions, or public assistance services.
-- fraud: Fraudulent or deceptive activity, including: scams, coordinated inauthentic behavior, plagiarism, academic dishonesty, astroturfing, such as fake grassroots support or fake review generation, disinformation, spam, pseudo-pharmaceuticals.
-- adult: Adult content, adult industries, and dating apps, including: content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness), erotic chat, pornography.
-- political: Political campaigning or lobbying, by: generating high volumes of campaign materials, generating campaign materials personalized to or targeted at specific demographics, building conversational or interactive systems such as chatbots that provide information about campaigns or engage in political advocacy or lobbying, building products for political campaigning or lobbying purposes.
-- privacy: Activity that violates people's privacy, including: tracking or monitoring an individual without their consent, facial recognition of private individuals, classifying individuals based on protected characteristics, using biometrics for identification or assessment, unlawful collection or disclosure of personal identifiable information or educational, financial, or other protected records.
-- unqualified law: Engaging in the unauthorized practice of law, or offering tailored legal advice without a qualified person reviewing the information.
-- unqualified financial: Offering tailored financial advice without a qualified person reviewing the information.
-- unqualified health: Telling someone that they have or do not have a certain health condition, or providing instructions on how to cure or treat a health condition.
-
-Please classify the following user prompt into one of these categories, and answer with that single word only.
-
-If the user prompt does not fall within these categories, is safe and does not need to be moderated, please answer "not moderated".
-
-user prompt: {}
-"""
-
 async def greet(product,description):
     user_reques = marketing_email_prompt.format(
         product, # product
         description, # description
         "", # output - leave this blank for generation!
     )
-    messages = [
-        {
-            "role": "system",
-            "content": "Your role is to assess whether the user prompt is moderate or not.",
-        },
-        {"role": "user", "content": guardrail_prompt.format(user_reques)},
-    ]
-    response = client.chat.completions.create(model=guard_llm, messages=messages, temperature=0)
-    if response.choices[0].message.content != "not moderated":
-        a_list = ["Sorry, I can't proceed for generating marketing email. Your content needs to be moderated first. Thank you!"]
-        processed_audio = combine_audio_files([text_to_speech(a_list[0])])
-        yield processed_audio, a_list[0]
+    messages=[
+        {"role": "user", "content": user_reques},
+    ]
+    # messages = [
+    #     {
+    #         "role": "system",
+    #         "content": "Your role is to assess the user prompt.",
+    #     },
+    #     {"role": "user", "content": guardrail_prompt.format(user_reques)},
+    # ]
+    #nemo guard
+    config = RailsConfig.from_path("nemo")
+    app = LLMRails(config=config, llm=client)
+    options = {"output_vars": ["triggered_input_rail", "triggered_output_rail"]}
+    output = app.generate(messages=messages, options=options)
+    warning_message = output.output_data["triggered_input_rail"] or output.output_data["triggered_output_rail"]
+    if warning_message:
+        gr.Warning(f"Guardrail triggered: {warning_message}")
+        chat = [output.response[0]['content']]
+        yield chat[0]
    else:
        output = llm.create_chat_completion(
            messages=[
@@ -132,8 +126,8 @@ async def greet(product,description):
             #audio_list.append([text_to_speech(delta.get('content', ''))])
             #processed_audio = combine_audio_files(audio_list)
             partial_message = partial_message + delta.get('content', '')
-            yield gr.Audio(), partial_message
+            yield partial_message
 
 audio = gr.Audio()
-demo = gr.Interface(fn=greet, inputs=["text","text"], concurrency_limit=10, outputs=[audio,"text"])
+demo = gr.Interface(fn=greet, inputs=["text","text"], concurrency_limit=10, outputs=["text"])
 demo.launch()
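
Note on the new guardrail path: RailsConfig.from_path("nemo") expects a rails configuration folder that is not part of this diff, so the exact rails definitions are unknown. The sketch below is a hypothetical, minimal equivalent built with RailsConfig.from_content (a standard NeMo Guardrails helper). It enables the library's built-in self-check input/output rails, which are the kind of rails that populate the triggered_input_rail / triggered_output_rail context variables greet() reads back through options={"output_vars": [...]}.

# Hypothetical sketch only: the actual contents of the "nemo" folder loaded by
# RailsConfig.from_path("nemo") are not included in this commit and may differ.
from nemoguardrails import RailsConfig

yaml_content = """
rails:
  input:
    flows:
      - self check input
  output:
    flows:
      - self check output

prompts:
  - task: self_check_input
    content: |
      Your task is to check if the user message below complies with the policy
      for the marketing email assistant (no illegal, hateful, or otherwise
      unsafe requests).

      User message: "{{ user_input }}"

      Should the user message be blocked (Yes or No)?
      Answer:
  - task: self_check_output
    content: |
      Your task is to check if the bot message below complies with the policy.

      Bot message: "{{ bot_response }}"

      Should the bot message be blocked (Yes or No)?
      Answer:
"""

# No models section is declared here because app.py supplies the LLM directly via
# LLMRails(config=config, llm=client); an inline config keeps the sketch self-contained.
config = RailsConfig.from_content(yaml_content=yaml_content)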
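
A related behavioral note: greet() now yields only the growing partial_message string, and the Interface outputs were narrowed to a single text component, so Gradio streams the partial text into the textbox rather than pairing it with an audio output. A minimal, self-contained sketch of that streaming pattern (illustrative names only, not code from this repo):

import time
import gradio as gr

def stream_demo(prompt):
    # Generator-based handler: each yielded value replaces the textbox content,
    # mirroring how greet() yields successively longer partial_message strings.
    partial = ""
    for word in f"Echoing: {prompt}".split():
        partial += word + " "
        time.sleep(0.1)  # stand-in for tokens arriving from the streamed LLM response
        yield partial

demo = gr.Interface(fn=stream_demo, inputs=["text"], outputs=["text"])
demo.launch()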