JusTalk / app.py
buletomato25
upload
093dde1
raw
history blame
11 kB
from flask import Flask, request, jsonify, render_template, send_from_directory,redirect, make_response, Response, session, url_for
import base64
from pydub import AudioSegment # 変換用にpydubをインポート
import os
import shutil
import numpy as np
import string
import random
from datetime import datetime, timedelta
from pyannote.audio import Model, Inference
from pydub import AudioSegment
from flask_sqlalchemy import SQLAlchemy
from dotenv import load_dotenv
from google.oauth2 import id_token
from google_auth_oauthlib.flow import Flow
from google.auth.transport import requests as google_requests
from new_record import record_bp
# Hugging Face のトークン取得(環境変数 HF に設定)
#hf_token = os.environ.get("HF")
load_dotenv()
hf_token = os.getenv("HF")
if hf_token is None:
raise ValueError("HUGGINGFACE_HUB_TOKEN が設定されていません。")
# キャッシュディレクトリの作成(書き込み可能な /tmp を利用)
cache_dir = "/tmp/hf_cache"
os.makedirs(cache_dir, exist_ok=True)
# pyannote モデルの読み込み
model = Model.from_pretrained("pyannote/embedding", use_auth_token=hf_token, cache_dir=cache_dir)
inference = Inference(model)
app = Flask(__name__)
app.config['SECRET_KEY'] = os.urandom(24)
# Google OAuth 2.0の設定
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
GOOGLE_CLIENT_ID = "228160683186-6u7986qsfhcv3kd9iqtv08iphpl4gdk2.apps.googleusercontent.com"
GOOGLE_CLIENT_SECRET = "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"
#HFにpushするときは下記のコメントアウトを外してください
#REDIRECT_URI = "https://huggingface.co/spaces/Justtalk/JusTalk/callback"
#ローカルの時はこちら
REDIRECT_URI = "http://127.0.0.1:7860/callback"
flow = Flow.from_client_secrets_file(
'client_secret.json',
scopes=["openid", "https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"],
redirect_uri=REDIRECT_URI
)
def cosine_similarity(vec1, vec2):
vec1 = vec1 / np.linalg.norm(vec1)
vec2 = vec2 / np.linalg.norm(vec2)
return np.dot(vec1, vec2)
def segment_audio(path, target_path='/tmp/setup_voice', seg_duration=1.0):
"""
音声を指定秒数ごとに分割する。
target_path に分割したファイルを保存し、元の音声の総長(ミリ秒)を返す。
"""
os.makedirs(target_path, exist_ok=True)
base_sound = AudioSegment.from_file(path)
duration_ms = len(base_sound)
seg_duration_ms = int(seg_duration * 1000)
for i, start in enumerate(range(0, duration_ms, seg_duration_ms)):
end = min(start + seg_duration_ms, duration_ms)
segment = base_sound[start:end]
segment.export(os.path.join(target_path, f'{i}.wav'), format="wav")
return target_path, duration_ms
def calculate_similarity(path1, path2):
embedding1 = inference(path1)
embedding2 = inference(path2)
return float(cosine_similarity(embedding1.data.flatten(), embedding2.data.flatten()))
def process_audio(reference_path, input_path, output_folder='/tmp/data/matched_segments', seg_duration=1.0, threshold=0.5):
"""
入力音声ファイルを seg_duration 秒ごとに分割し、各セグメントと参照音声の類似度を計算。
類似度が threshold を超えたセグメントを output_folder にコピーし、マッチした時間(ms)と
マッチしなかった時間(ms)を返す。
"""
os.makedirs(output_folder, exist_ok=True)
segmented_path, total_duration_ms = segment_audio(input_path, seg_duration=seg_duration)
matched_time_ms = 0
for file in sorted(os.listdir(segmented_path)):
segment_file = os.path.join(segmented_path, file)
similarity = calculate_similarity(segment_file, reference_path)
if similarity > threshold:
shutil.copy(segment_file, output_folder)
matched_time_ms += len(AudioSegment.from_file(segment_file))
unmatched_time_ms = total_duration_ms - matched_time_ms
return matched_time_ms, unmatched_time_ms
def generate_random_string(length):
letters = string.ascii_letters + string.digits
return ''.join(random.choice(letters) for i in range(length))
def generate_filename(random_length):
random_string = generate_random_string(random_length)
current_time = datetime.now().strftime("%Y%m%d%H%M%S")
filename = f"{current_time}_{random_string}.wav"
return filename
app.register_blueprint(record_bp)
# トップページ(テンプレート: index.html)
@app.route('/')
def top():
return redirect('index')
# ログイン画面(テンプレート: login.html)
@app.route('/login')
def login():
authorization_url, state = flow.authorization_url()
session['state'] = state
return redirect(authorization_url)
# ログイン後画面
@app.route('/callback')
def callback():
flow.fetch_token(authorization_response=request.url)
# `session.get('state')` を使用し、エラーを防ぐ
session_state = session.get('state')
request_state = request.args.get('state')
if session_state is None or session_state != request_state:
print(f"State mismatch error: session_state={session_state}, request_state={request_state}")
return 'State mismatch error', 400
credentials = flow.credentials
request_session = google_requests.Request()
id_info = id_token.verify_oauth2_token(
credentials.id_token, request_session, GOOGLE_CLIENT_ID
)
session['google_id'] = id_info.get("sub")
session['email'] = id_info.get("email")
session['name'] = id_info.get("name")
return redirect(url_for('new_person'))
# フィードバック画面(テンプレート: feedback.html)
@app.route('/feedback', methods=['GET', 'POST'])
def feedback():
#ログイン問題解決しだい戻す
"""
if 'google_id' not in session:
return redirect(url_for('login'))
user_info = {
'name': session.get('name'),
'email': session.get('email')
}
"""
return render_template('feedback.html')
# 会話詳細画面(テンプレート: talkDetail.html)
@app.route('/talk_detail', methods=['GET', 'POST'])
def talk_detail():
"""
if 'google_id' not in session:
return redirect(url_for('login'))
user_info = {
'name': session.get('name'),
'email': session.get('email')
}
"""
return render_template('talkDetail.html')
# インデックス画面(テンプレート: index.html)
@app.route('/index', methods=['GET', 'POST'])
def index():
"""
if 'google_id' not in session:
return redirect(url_for('login'))
user_info = {
'name': session.get('name'),
'email': session.get('email')
}
"""
return render_template('index.html')
# 登録画面(テンプレート: new_person.html)
@app.route('/new_person', methods=['GET', 'POST'])
def new_person():
if 'google_id' not in session:
return redirect(url_for('login'))
user_info = {
'name': session.get('name'),
'email': session.get('email')
}
return render_template('new_person.html', user=user_info)
@app.before_request
def before_request():
# リクエストのたびにセッションの寿命を更新する
session.permanent = True
app.permanent_session_lifetime = timedelta(minutes=15)
session.modified = True
# 音声アップロード&解析エンドポイント
@app.route('/upload_audio', methods=['POST'])
def upload_audio():
try:
data = request.get_json()
if not data or 'audio_data' not in data:
return jsonify({"error": "音声データがありません"}), 400
# Base64デコードして音声バイナリを取得
audio_binary = base64.b64decode(data['audio_data'])
audio_dir = "/tmp/data"
os.makedirs(audio_dir, exist_ok=True)
# 固定ファイル名(必要に応じて generate_filename() で一意のファイル名に変更可能)
audio_path = os.path.join(audio_dir, "recorded_audio.wav")
with open(audio_path, 'wb') as f:
f.write(audio_binary)
# 参照音声ファイルのパスを指定(sample.wav を正しい場所に配置すること)
reference_audio = os.path.abspath('/tmp/data/base_audio/recorded_base_audio.wav')
if not os.path.exists(reference_audio):
return jsonify({"error": "参照音声ファイルが見つかりません", "details": reference_audio}), 500
# 音声解析:参照音声とアップロードされた音声との類似度をセグメント毎に計算
# threshold の値は調整可能です(例: 0.1)
matched_time, unmatched_time = process.process_audio(reference_audio, audio_path, threshold=0.05)
total_time = matched_time + unmatched_time
rate = (matched_time / total_time) * 100 if total_time > 0 else 0
return jsonify({"rate": rate}), 200
except Exception as e:
print("Error in /upload_audio:", str(e))
return jsonify({"error": "サーバーエラー", "details": str(e)}), 500
@app.route('/upload_base_audio', methods=['POST'])
def upload_base_audio():
try:
data = request.get_json()
if not data or 'audio_data' not in data:
return jsonify({"error": "音声データがありません"}), 400
# Base64デコードして音声バイナリを取得
audio_binary = base64.b64decode(data['audio_data'])
# 保存先ディレクトリの作成
audio_dir = "/tmp/data/base_audio"
os.makedirs(audio_dir, exist_ok=True)
# 一時ファイルに保存(実際の形式は WebM などと仮定)
temp_audio_path = os.path.join(audio_dir, "temp_audio")
with open(temp_audio_path, 'wb') as f:
f.write(audio_binary)
# pydub を使って一時ファイルを WAV に変換
# ※ここでは WebM 形式と仮定していますが、実際の形式に合わせて format の指定を変更してください
try:
audio = AudioSegment.from_file(temp_audio_path, format="webm")
except Exception as e:
# 形式が不明な場合は自動判別させる(ただし変換できない場合もあり)
audio = AudioSegment.from_file(temp_audio_path)
wav_audio_path = os.path.join(audio_dir, "recorded_base_audio.wav")
audio.export(wav_audio_path, format="wav")
# 一時ファイルを削除
os.remove(temp_audio_path)
return jsonify({"state": "Registration Success!"}), 200
except Exception as e:
print("Error in /upload_base_audio:", str(e))
return jsonify({"error": "サーバーエラー", "details": str(e)}), 500
if __name__ == '__main__':
port = int(os.environ.get("PORT", 7860))
app.run(debug=True, host="0.0.0.0", port=port)