soojinchoi_test / utils.py
shangrilar's picture
Update utils.py
a8dea68 verified
import os
import sys
import time
import urllib.request
import json
import random
import requests
from voice import voice_dict
from dotenv import load_dotenv
# Read `credentials.env` via python-dotenv, then pull each API credential from
# the process environment. `os.getenv` returns None for any variable that is
# not set, so a missing credential only surfaces when the API call is made.
load_dotenv('credentials.env')
# Bearer token sent to api.openai.com by get_story.
OPENAPI_KEY = os.getenv('OPENAPI_KEY')
# API gateway key pair sent by get_voice (CLOVA Voice TTS endpoint).
CLOVA_VOICE_Client_ID = os.getenv('CLOVA_VOICE_Client_ID')
CLOVA_VOICE_Client_Secret = os.getenv('CLOVA_VOICE_Client_Secret')
# API gateway key pair sent by translate_text (Papago NMT endpoint).
PAPAGO_Translate_Client_ID = os.getenv('PAPAGO_Translate_Client_ID')
PAPAGO_Translate_Client_Secret = os.getenv('PAPAGO_Translate_Client_Secret')
# Value sent as the "pat" parameter to the Mubert API by get_mubert_music.
mubert_pat = os.getenv('mubert_pat')
# API gateway key pair sent by get_summary (text-summary endpoint).
SUMMARY_Client_ID = os.getenv('SUMMARY_Client_ID')
SUMMARY_Client_Secret = os.getenv('SUMMARY_Client_Secret')
import time
import os
import subprocess
from tempfile import NamedTemporaryFile
import torch
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
# Load the pretrained `facebook/musicgen-melody` checkpoint once, at import
# time; get_musicgen_music reuses this module-level `model`.
# NOTE(review): this comment previously described the checkpoint as the
# "small" model, but the name requested here is `facebook/musicgen-melody` —
# confirm which checkpoint is actually intended.
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Sampling with top_k=250; `duration=30` matches the 30-second initial
# `music_length` that get_musicgen_music assumes for each generate() call.
model.set_generation_params(
use_sampling=True,
top_k=250,
duration=30
)
def get_voice(input_text:str, gender:str="female", age_group:str="youth", speed:int=1, pitch:int=1, alpha:int=-1, filename="voice.mp3"):
    """Synthesize `input_text` through the CLOVA Voice TTS endpoint and save the audio.

    A speaker is chosen at random from ``voice_dict[gender][age_group]``.

    Args:
        input_text: Text to be spoken.
        gender: "female" or "male".
        age_group: "child", "teenager", "youth" or "middle_aged".
        speed: Sent unchanged as the ``speed`` form field.
        pitch: Sent unchanged as the ``pitch`` form field.
        alpha: Sent unchanged as the ``alpha`` form field.
        filename: Path the response bytes are written to on HTTP 200.

    Returns:
        `filename`. The path is returned even when the request fails; in that
        case nothing is written and the error is printed instead.
    """
    speaker = random.choice(voice_dict[gender][age_group])
    data = {"speaker":speaker, "text":input_text, 'speed':speed, 'pitch':pitch, 'alpha':alpha}
    url = "https://naveropenapi.apigw.ntruss.com/tts-premium/v1/tts"
    headers = {
        "X-NCP-APIGW-API-KEY-ID": CLOVA_VOICE_Client_ID,
        "X-NCP-APIGW-API-KEY": CLOVA_VOICE_Client_Secret,
    }
    response = requests.post(url, headers=headers, data=data)
    if response.status_code == 200:
        print("TTS mp3 μ €μž₯")
        with open(filename, 'wb') as f:
            f.write(response.content)
    else:
        print("Error Code: " + str(response.status_code))
        # An error body is not guaranteed to be JSON (e.g. a gateway error
        # page); fall back to the raw text so reporting the failure cannot
        # itself raise and hide the real HTTP error.
        try:
            error_message = str(response.json())
        except ValueError:
            error_message = response.text
        print("Error Message: " + error_message)
    return filename
def translate_text(text:str):
    """Translate `text` from Korean ("ko") to English ("en") via the Papago NMT endpoint.

    Args:
        text: Source text; it is percent-encoded before being placed in the
            form body.

    Returns:
        The ``message.result.translatedText`` field of the JSON response, or
        the string ``"Error Code: <status>"`` when the server answers with an
        HTTP error status. Other failures (e.g. ``urllib.error.URLError``)
        propagate to the caller.
    """
    encText = urllib.parse.quote(text)
    data = f"source=ko&target=en&text={encText}"
    url = "https://naveropenapi.apigw.ntruss.com/nmt/v1/translation"
    request = urllib.request.Request(url)
    request.add_header("X-NCP-APIGW-API-KEY-ID", PAPAGO_Translate_Client_ID)
    request.add_header("X-NCP-APIGW-API-KEY", PAPAGO_Translate_Client_Secret)
    try:
        # The context manager closes the HTTP response (and its socket)
        # deterministically instead of relying on garbage collection.
        with urllib.request.urlopen(request, data=data.encode("utf-8")) as response:
            response_body = response.read()
    except urllib.error.HTTPError as e:
        error_code = e.code
        # HTTPError wraps the error response; close it as well.
        e.close()
        return f"Error Code: {error_code}"
    return json.loads(response_body.decode('utf-8'))['message']['result']['translatedText']
# -
def get_summary(input_text:str, summary_count:int = 5):
    """Summarize Korean text through the text-summary endpoint.

    The input is cut to its first 2000 characters and stripped before it is
    sent.

    Args:
        input_text: Text to summarize.
        summary_count: Sent as the ``summaryCount`` option.

    Returns:
        * On HTTP 200: the ``summary`` field with newlines replaced by spaces.
        * On HTTP 400 with error code ``E100``: the truncated, stripped input
          text unchanged (presumably the "text cannot be summarized" case —
          confirm against the API documentation).
        * On any other failure: ``None``, after printing the status and body.
    """
    # Slicing is a no-op for inputs that are already within the limit.
    input_text = input_text[:2000].strip()
    data = {
        "document": {
            "content": input_text
        },
        "option": {
            "language": "ko",
            "model": "general",
            "tone": "0",
            "summaryCount": summary_count
        }
    }
    url = "https://naveropenapi.apigw.ntruss.com/text-summary/v1/summarize"
    headers = {
        "X-NCP-APIGW-API-KEY-ID": SUMMARY_Client_ID,
        "X-NCP-APIGW-API-KEY": SUMMARY_Client_Secret,
        "Content-Type": "application/json"
    }
    response = requests.post(url, headers=headers, data=json.dumps(data))
    if response.status_code == 200:
        return ' '.join(response.json()['summary'].split('\n'))

    # Error bodies are not guaranteed to be JSON or to carry an "error"
    # object, so parse defensively instead of indexing blindly.
    try:
        error_body = response.json()
    except ValueError:
        error_body = None

    if response.status_code == 400 and isinstance(error_body, dict):
        error = error_body.get('error')
        if isinstance(error, dict) and error.get('errorCode') == 'E100':
            return input_text

    print("Error Code: " + str(response.status_code))
    print("Error Message: " + (str(error_body) if error_body is not None else response.text))
    return None
def get_mubert_music(text, duration=300):
    """Generate a music track for `text` with the Mubert API and download it.

    The text is summarized (get_summary), translated to English
    (translate_text), cut to 200 characters and sent as the prompt of a
    ``TTMRecordTrack`` request. The function then polls ``TrackStatus`` until
    the first task reports ``Done`` and downloads the track.

    Args:
        text: Source text describing the desired music.
        duration: Sent as the ``duration`` parameter of the track request.

    Returns:
        ``"mubert_music.mp3"`` once the file has been written, or ``None``
        when the track request is rejected, generation does not finish within
        the polling budget, or the download fails.
    """
    local_filename = "mubert_music.mp3"
    status_poll_interval = 2      # seconds between TrackStatus polls
    max_status_polls = 300        # give up after ~10 minutes of polling
    download_retry_interval = 1   # seconds between download attempts
    max_download_attempts = 60    # give up after ~1 minute of retries

    print('original text length: ', len(text))
    summary = get_summary(text, 3)
    if summary is None:
        # get_summary returns None when the summary API fails; use the
        # original text rather than crashing on len(None) below.
        summary = text
    print('summary text length: ', len(summary))
    translated_text = translate_text(summary)
    print('translated_text length: ', len(translated_text))
    translated_text = translated_text[:200]

    r = requests.post('https://api-b2b.mubert.com/v2/TTMRecordTrack',
                      json={
                          "method":"TTMRecordTrack",
                          "params":
                          {
                              "text":translated_text,
                              "pat":mubert_pat,
                              "mode":"track",
                              "duration":duration,
                              "bitrate":128
                          }
                      })
    rdata = json.loads(r.text)
    if rdata['status'] != 1:
        # Without a successful request there is no download link to use.
        print("Error Message: " + str(rdata))
        return None
    url = rdata['data']['tasks'][0]['download_link']

    # Wait for generation to finish, with an upper bound so a task that never
    # reaches 'Done' cannot hang the caller forever.
    for _ in range(max_status_polls):
        r = requests.post('https://api-b2b.mubert.com/v2/TrackStatus',
                          json={
                              "method":"TrackStatus",
                              "params":
                              {
                                  "pat":mubert_pat
                              }
                          })
        if r.json()['data']['tasks'][0]['task_status_text'] == 'Done':
            break
        time.sleep(status_poll_interval)
    else:
        print("Error Message: track generation did not finish in time")
        return None

    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
    }
    # Retry the download a bounded number of times; the last response is kept
    # so the status checks below report why it failed.
    response = None
    for _ in range(max_download_attempts):
        response = requests.get(url, stream=True, headers=headers)
        if response.status_code == 200:
            break
        response.close()
        time.sleep(download_retry_interval)

    if response.status_code == 404:
        print("파일이 μ‘΄μž¬ν•˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€.")
        return
    elif response.status_code != 200:
        print(f"파일 λ‹€μš΄λ‘œλ“œμ— μ‹€νŒ¨ν•˜μ˜€μŠ΅λ‹ˆλ‹€. μ—λŸ¬ μ½”λ“œ: {response.status_code}")
        return

    try:
        with open(local_filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
    finally:
        # Release the streamed connection back to the pool.
        response.close()
    print(f"{local_filename} 파일이 μ €μž₯λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
    return local_filename
def get_musicgen_music(text, duration=300):
    """Generate music for `text` with the module-level MusicGen `model`.

    The text is summarized (get_summary), translated to English
    (translate_text) and cut to 200 characters to form the prompt. An initial
    clip is generated and then extended with `generate_continuation`, each
    time re-using the last 5 seconds as the prompt, until at least `duration`
    seconds exist; the result is trimmed to `duration` seconds and written to
    ``musicgen_output.wav``.

    Args:
        text: Source text describing the desired music.
        duration: Target length in seconds.

    Returns:
        The output file name, ``'musicgen_output.wav'``.
    """
    file_name = 'musicgen_output.wav'
    print('original text length: ', len(text))
    summary = get_summary(text, 3)
    if summary is None:
        # get_summary returns None when the summary API fails; use the
        # original text rather than crashing on len(None) below.
        summary = text
    print('summary text length: ', len(summary))
    translated_text = translate_text(summary)
    print('translated_text length: ', len(translated_text))
    translated_text = translated_text[:200]
    print(translated_text)

    start = time.time()
    overlap = 5          # seconds of existing audio fed back as the prompt
    music_length = 30    # matches the `duration=30` generation parameter
    target_length = duration
    desc = [translated_text]
    sample_rate = model.sample_rate
    print(sample_rate)
    # Negative sample offset marking where the trailing `overlap` seconds begin.
    tail_start = int(-overlap * sample_rate)

    output = model.generate(descriptions=desc, progress=True)
    while music_length < target_length:
        last_sec = output[:, :, tail_start:]
        cont = model.generate_continuation(last_sec, sample_rate, descriptions=desc, progress=True)
        # The overlapped tail is dropped before appending `cont` — this
        # assumes `cont` begins with the prompt audio (TODO confirm against
        # the audiocraft documentation).
        output = torch.cat([output[:, :, :tail_start], cont], 2)
        music_length = output.shape[2] / sample_rate

    if music_length > target_length:
        output = output[:, :, :int(target_length * sample_rate)]

    # Keep only the first item of the batch, on CPU, as float.
    output = output.detach().cpu().float()[0]
    audio_write(
        file_name, output, sample_rate, strategy="loudness",
        loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
    print(f'Elapsed time: {time.time() - start}')
    return file_name
# def get_story(first_sentence:str, history, num_sentences:int):
# response = requests.post("https://api.openai.com/v1/chat/completions",
# headers={"Content-Type": "application/json", "Authorization": f"Bearer {OPENAPI_KEY}"},
# data=json.dumps({
# "model": "gpt-3.5-turbo",
# "messages": [{"role": "system", "content": "You are a helpful assistant."},
# {"role": "user", "content": f"""I will provide the first sentence of the novel, and please write {num_sentences} sentences continuing the story in a first-person protagonist's perspective in Korean. Don't number the sentences.
# \n\nStory: {first_sentence}"""}]
# }))
# print(response.json())
# return response.json()['choices'][0]['message']['content']
def get_story(first_sentence:str, num_sentences:int, chatbot=None, history=None):
    """Stream a Korean story continuation from the OpenAI chat completions API.

    This is a generator: after every received content fragment it yields the
    conversation so far, so a UI can render the text as it arrives.

    Args:
        first_sentence: Opening sentence of the story; appended to `history`.
        num_sentences: Number of sentences requested in the prompt.
        chatbot: Unused; accepted for interface compatibility.
        history: Flat list of alternating user/assistant strings. It is
            mutated in place. When omitted, a new list is created for this
            call (a shared ``[]`` default would leak history between calls).

    Yields:
        ``(chat, history, response)`` where `chat` is the list of
        ``(history[i], history[i + 1])`` pairs, `history` is the updated list
        and `response` is the streaming ``requests`` response object.
        Nothing is yielded when the API answers with a non-200 status; the
        status and body are printed instead.
    """
    if history is None:
        history = []
    history.append(first_sentence)
    # make a POST request to the API endpoint using the requests.post method, passing in stream=True
    response = requests.post("https://api.openai.com/v1/chat/completions",
                             headers={"Content-Type": "application/json", "Authorization": f"Bearer {OPENAPI_KEY}"},
                             stream=True,
                             data=json.dumps({
                                 "stream": True,
                                 "model": "gpt-3.5-turbo",
                                 "messages": [{"role": "system", "content": "You are a helpful assistant."},
                                              {"role": "user", "content": f"""I will provide the first sentence of the novel, and please write {num_sentences} sentences continuing the story in a first-person protagonist's perspective in Korean. Don't number the sentences.
\n\nFirst sentence: {first_sentence}"""}]
                             }))
    if response.status_code != 200:
        # An error body is plain JSON, not an event stream; report it instead
        # of trying to parse it line by line.
        print("Error Code: " + str(response.status_code))
        print("Error Message: " + response.text)
        return

    data_prefix = "data: "
    token_counter = 0
    partial_words = ""
    for raw_line in response.iter_lines():
        # response data is in bytes; only "data: ..." lines carry a payload
        line = raw_line.decode()
        if not line.startswith(data_prefix):
            continue
        payload = line[len(data_prefix):]
        if payload.strip() == "[DONE]":
            break
        choices = json.loads(payload).get('choices')
        if not choices:
            continue
        # Role-only and empty deltas (such as the first chunk) add no text.
        content = choices[0].get('delta', {}).get('content')
        if not content:
            continue
        partial_words = partial_words + content
        if token_counter == 0:
            history.append(" " + partial_words)
        else:
            history[-1] = partial_words
        chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]  # convert to tuples of list
        token_counter += 1
        yield chat, history, response
def get_voice_filename(text, gender, age, speed, pitch, alpha):
    """Map the UI's gender/age labels to get_voice arguments and synthesize `text`.

    `gender` equal to the male label selects the "male" voices; any other
    value selects "female". `age` must equal one of the four age labels.

    Returns:
        Whatever get_voice returns (it is always asked to write "voice.mp3"),
        or None when `age` matches none of the known labels.
    """
    voice_gender = "male" if gender == '남성' else "female"
    # (UI label, age_group key) pairs, checked in order.
    age_labels = (
        ("어린이", "child"),
        ("μ²­μ†Œλ…„", "teenager"),
        ("μ²­λ…„", "youth"),
        ("쀑년", "middle_aged"),
    )
    for label, age_group in age_labels:
        if age == label:
            return get_voice(text, gender=voice_gender, age_group=age_group, speed=speed, pitch=pitch, alpha=alpha, filename="voice.mp3")
    return None