versae committed
Commit 129af76 · 1 Parent(s): 8dfd601

Create duplex.py

Files changed (1):
  1. duplex.py +182 -0
duplex.py ADDED
@@ -0,0 +1,182 @@
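# Poor Man's Duplex: a bilingual voice-chat demo. Each turn transcribes the
# user's microphone audio (wav2vec2 + LM-boosted CTC decoding), generates a
# reply with a hosted language model (BERTIN GPT-J-6B for Spanish, GPT-J-6B
# for English), and speaks it back through a text-to-speech model.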
import os
import string

import gradio as gr
import requests
from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM, pipeline

# Configuration, read from the environment.
DEBUG = os.environ.get("DEBUG", "false").lower()[0] in "ty1"  # accepts true/yes/1
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))

HEADER = """
# Poor Man's Duplex
""".strip()

FOOTER = """
<div align=center>
<img src="https://visitor-badge.glitch.me/badge?page_id=spaces/bertin-project/bertin-gpt-j-6B"/>
</div>
""".strip()

# Spanish pipeline: ASR (wav2vec2 with LM-boosted decoding), TTS, and generation.
asr_model_name_es = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
model_instance_es = AutoModelForCTC.from_pretrained(asr_model_name_es)
processor_es = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_es)
asr_es = pipeline(
    "automatic-speech-recognition",
    model=model_instance_es,
    tokenizer=processor_es.tokenizer,
    feature_extractor=processor_es.feature_extractor,
    decoder=processor_es.decoder,
)
tts_model_name = "facebook/tts_transformer-es-css10"
speak_es = gr.Interface.load(f"huggingface/{tts_model_name}")


def transcribe_es(input_file):
    return asr_es(input_file, chunk_length_s=5, stride_length_s=1)["text"]


def generate_es(text, **kwargs):
    # Query the hosted BERTIN GPT-J-6B Space. Positional payload:
    # text, max_length=100, top_k=100, top_p=50, temperature=0.95, do_sample=True, do_clean=True
    api_uri = "https://hf.space/embed/bertin-project/bertin-gpt-j-6B/+/api/predict/"
    response = requests.post(api_uri, json={"data": [text, 100, 100, 50, 0.95, True, True]})
    if response.ok:
        if DEBUG:
            print(response.json())
        return response.json()["data"][0]
    return ""

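# Hypothetical usage sketch (assumes the hosted Space is reachable):
#   generate_es("ENTREVISTADOR: ¿Cómo ve el futuro de la inteligencia artificial?")
# would return BERTIN's continuation as a string, or "" if the request fails.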
# English pipeline: same structure as the Spanish one, with GPT-J-6B for generation.
asr_model_name_en = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
model_instance_en = AutoModelForCTC.from_pretrained(asr_model_name_en)
processor_en = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_en)
asr_en = pipeline(
    "automatic-speech-recognition",
    model=model_instance_en,
    tokenizer=processor_en.tokenizer,
    feature_extractor=processor_en.feature_extractor,
    decoder=processor_en.decoder,
)
tts_model_name = "facebook/fastspeech2-en-200_speaker-cv4"
speak_en = gr.Interface.load(f"huggingface/{tts_model_name}")


def transcribe_en(input_file):
    return asr_en(input_file, chunk_length_s=5, stride_length_s=1)["text"]


generate_iface = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B")


def generate_en(text, **kwargs):
    response = generate_iface(text)
    if DEBUG:
        print(response)
    return response or ""

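# Note: gr.Interface.load("huggingface/...") returns a callable proxy, so
# generate_iface(text) queries the hosted model instead of running it locally.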

def select_lang(lang):
    # Pick the generation, transcription, and TTS callables for the chosen language.
    if lang.lower() == "spanish":
        return generate_es, transcribe_es, speak_es
    else:
        return generate_en, transcribe_en, speak_en


def select_lang_vars(lang):
    # Agent/user names and the seed conversation for the chosen language.
    if lang.lower() == "spanish":
        AGENT = "BERTIN"
        USER = "ENTREVISTADOR"
        # "The following conversation is an excerpt from an interview with
        # {AGENT} held in Madrid for Radio Televisión Española" (plus greetings).
        CONTEXT = """La siguiente conversación es un extracto de una entrevista a {AGENT} celebrada en Madrid para Radio Televisión Española:

{USER}: Bienvenido, {AGENT}. Un placer tenerlo hoy con nosotros.
{AGENT}: Gracias. El placer es mío."""
    else:
        AGENT = "ELEUTHER"
        USER = "INTERVIEWER"
        CONTEXT = """The next conversation is an excerpt from an interview with {AGENT} that appeared in the New York Times:

{USER}: Welcome, {AGENT}. It is a pleasure to have you here today.
{AGENT}: Thanks. The pleasure is mine."""

    return AGENT, USER, CONTEXT

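# For reference: select_lang_vars("English") returns ("ELEUTHER", "INTERVIEWER",
# CONTEXT), and the {USER}/{AGENT} placeholders are filled later with str.format
# inside chat_with_gpt.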

def chat_with_gpt(lang, agent, user, context, audio_in, history):
    # One full turn: transcribe the user's audio, build a prompt from the
    # context plus chat history, generate a reply, and synthesize it to speech.
    generate, transcribe, speak = select_lang(lang)
    AGENT, USER, _ = select_lang_vars(lang)
    user_message = transcribe(audio_in)
    generation_kwargs = {
        "max_length": 25,
        # "top_k": top_k,
        # "top_p": top_p,
        # "temperature": temperature,
        # "do_sample": do_sample,
        # "do_clean": do_clean,
        # "num_return_sequences": 1,
        # "return_full_text": False,
    }
    # Capitalize only the first word of the transcription.
    words = user_message.split(" ", 1)
    message = words[0].capitalize() + (" " + words[1] if len(words) > 1 else "")
    history = history or []
    context = context.format(USER=user or USER, AGENT=agent or AGENT).strip()
    if context and context[-1] not in ".:":
        context += "."
    context_length = len(context.split())
    # Drop the oldest turns until the rendered history fits in the word budget
    # left over after the context and the generation length.
    history_take = 0
    history_context = "\n".join(
        f"{user}: {past_message.capitalize()}.\n{agent}: {past_response}."
        for past_message, past_response in history[history_take:]
    )
    while len(history_context.split()) > MAX_LENGTH - (generation_kwargs["max_length"] + context_length):
        history_take += 1
        history_context = "\n".join(
            f"{user}: {past_message.capitalize()}.\n{agent}: {past_response}."
            for past_message, past_response in history[history_take:]
        )
        if history_take >= MAX_LENGTH:
            break
    if history_context:
        context += "\n" + history_context
    # Retry a few times until the model produces a non-empty reply that is
    # not just an echo of the user's message.
    for _ in range(5):
        response = generate(f"{context}\n\n{user}: {message}.\n", **generation_kwargs)
        if DEBUG:
            print("\n-----" + response + "-----\n")
        response = response.split("\n")[-1]
        if agent in response and response.split(agent)[-1]:
            response = response.split(agent)[-1]
        if user in response and response.split(user)[-1]:
            response = response.split(user)[-1]
        if response and response[0] in string.punctuation:
            response = response[1:].strip()
        if response.strip().startswith(f"{user}: {message}"):
            response = response.strip().split(f"{user}: {message}")[-1]
        if response.replace(".", "").strip() and message.replace(".", "").strip() != response.replace(".", "").strip():
            break
    if DEBUG:
        print()
        print("CONTEXT:")
        print(context)
        print()
        print("MESSAGE")
        print(message)
        print()
        print("RESPONSE:")
        print(response)
    if not response.strip():
        response = "Lo siento, no puedo hablar ahora" if lang.lower() == "spanish" else "Sorry, can't talk right now"
    history.append((user_message, response))
    return history, history, speak(response)

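# The return triple above matches the outputs wired below:
# (chatbot, history, audio_out); speak(response) returns the audio produced
# by the loaded TTS Space for the generated reply.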

with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    lang = gr.Radio(label="Language", choices=["English", "Spanish"], value="English", type="value")
    AGENT, USER, CONTEXT = select_lang_vars("English")
    context = gr.Textbox(label="Context", lines=5, value=CONTEXT)
    with gr.Row():
        audio_in = gr.Audio(label="User", source="microphone", type="filepath")
        audio_out = gr.Audio(label="Agent", interactive=False)
    with gr.Row():
        user = gr.Textbox(label="User", value=USER)
        agent = gr.Textbox(label="Agent", value=AGENT)
    # Switching language swaps the agent/user names and the seed context.
    lang.change(select_lang_vars, inputs=[lang], outputs=[agent, user, context])
    history = gr.Variable(value=[])
    chatbot = gr.Variable()  # placeholder for a hidden gr.Chatbot
    # Each new recording triggers a full chat turn.
    audio_in.change(chat_with_gpt, inputs=[lang, agent, user, context, audio_in, history], outputs=[chatbot, history, audio_out])
    gr.Markdown(FOOTER)

demo.launch()
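# To run locally (assuming the ASR checkpoints download and the remote
# generation/TTS Spaces respond):
#   DEBUG=true python duplex.py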