Spaces:
Running
Running
from flask import Flask, request, jsonify, render_template, send_from_directory,redirect, make_response, Response, session, url_for | |
import base64 | |
from pydub import AudioSegment # 変換用にpydubをインポート | |
import os | |
import shutil | |
import numpy as np | |
import string | |
import random | |
from datetime import datetime, timedelta | |
from pyannote.audio import Model, Inference | |
from pydub import AudioSegment | |
from flask_sqlalchemy import SQLAlchemy | |
from dotenv import load_dotenv | |
from google.oauth2 import id_token | |
from google_auth_oauthlib.flow import Flow | |
from google.auth.transport import requests as google_requests | |
from new_record import record_bp | |
# Hugging Face のトークン取得(環境変数 HF に設定) | |
#hf_token = os.environ.get("HF") | |
load_dotenv() | |
hf_token = os.getenv("HF") | |
if hf_token is None: | |
raise ValueError("HUGGINGFACE_HUB_TOKEN が設定されていません。") | |
# キャッシュディレクトリの作成(書き込み可能な /tmp を利用) | |
cache_dir = "/tmp/hf_cache" | |
os.makedirs(cache_dir, exist_ok=True) | |
# pyannote モデルの読み込み | |
model = Model.from_pretrained("pyannote/embedding", use_auth_token=hf_token, cache_dir=cache_dir) | |
inference = Inference(model) | |
app = Flask(__name__) | |
app.config['SECRET_KEY'] = os.urandom(24) | |
# Google OAuth 2.0の設定 | |
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" | |
GOOGLE_CLIENT_ID = "228160683186-6u7986qsfhcv3kd9iqtv08iphpl4gdk2.apps.googleusercontent.com" | |
GOOGLE_CLIENT_SECRET = "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" | |
#HFにpushするときは下記のコメントアウトを外してください | |
#REDIRECT_URI = "https://huggingface.co/spaces/Justtalk/JusTalk/callback" | |
#ローカルの時はこちら | |
REDIRECT_URI = "http://127.0.0.1:7860/callback" | |
flow = Flow.from_client_secrets_file( | |
'client_secret.json', | |
scopes=["openid", "https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.email"], | |
redirect_uri=REDIRECT_URI | |
) | |
def cosine_similarity(vec1, vec2): | |
vec1 = vec1 / np.linalg.norm(vec1) | |
vec2 = vec2 / np.linalg.norm(vec2) | |
return np.dot(vec1, vec2) | |
def segment_audio(path, target_path='/tmp/setup_voice', seg_duration=1.0): | |
""" | |
音声を指定秒数ごとに分割する。 | |
target_path に分割したファイルを保存し、元の音声の総長(ミリ秒)を返す。 | |
""" | |
os.makedirs(target_path, exist_ok=True) | |
base_sound = AudioSegment.from_file(path) | |
duration_ms = len(base_sound) | |
seg_duration_ms = int(seg_duration * 1000) | |
for i, start in enumerate(range(0, duration_ms, seg_duration_ms)): | |
end = min(start + seg_duration_ms, duration_ms) | |
segment = base_sound[start:end] | |
segment.export(os.path.join(target_path, f'{i}.wav'), format="wav") | |
return target_path, duration_ms | |
def calculate_similarity(path1, path2): | |
embedding1 = inference(path1) | |
embedding2 = inference(path2) | |
return float(cosine_similarity(embedding1.data.flatten(), embedding2.data.flatten())) | |
def process_audio(reference_path, input_path, output_folder='/tmp/data/matched_segments', seg_duration=1.0, threshold=0.5): | |
""" | |
入力音声ファイルを seg_duration 秒ごとに分割し、各セグメントと参照音声の類似度を計算。 | |
類似度が threshold を超えたセグメントを output_folder にコピーし、マッチした時間(ms)と | |
マッチしなかった時間(ms)を返す。 | |
""" | |
os.makedirs(output_folder, exist_ok=True) | |
segmented_path, total_duration_ms = segment_audio(input_path, seg_duration=seg_duration) | |
matched_time_ms = 0 | |
for file in sorted(os.listdir(segmented_path)): | |
segment_file = os.path.join(segmented_path, file) | |
similarity = calculate_similarity(segment_file, reference_path) | |
if similarity > threshold: | |
shutil.copy(segment_file, output_folder) | |
matched_time_ms += len(AudioSegment.from_file(segment_file)) | |
unmatched_time_ms = total_duration_ms - matched_time_ms | |
return matched_time_ms, unmatched_time_ms | |
def generate_random_string(length): | |
letters = string.ascii_letters + string.digits | |
return ''.join(random.choice(letters) for i in range(length)) | |
def generate_filename(random_length): | |
random_string = generate_random_string(random_length) | |
current_time = datetime.now().strftime("%Y%m%d%H%M%S") | |
filename = f"{current_time}_{random_string}.wav" | |
return filename | |
app.register_blueprint(record_bp) | |
# トップページ(テンプレート: index.html) | |
def top(): | |
return redirect('index') | |
# ログイン画面(テンプレート: login.html) | |
def login(): | |
authorization_url, state = flow.authorization_url() | |
session['state'] = state | |
return redirect(authorization_url) | |
# ログイン後画面 | |
def callback(): | |
flow.fetch_token(authorization_response=request.url) | |
# `session.get('state')` を使用し、エラーを防ぐ | |
session_state = session.get('state') | |
request_state = request.args.get('state') | |
if session_state is None or session_state != request_state: | |
print(f"State mismatch error: session_state={session_state}, request_state={request_state}") | |
return 'State mismatch error', 400 | |
credentials = flow.credentials | |
request_session = google_requests.Request() | |
id_info = id_token.verify_oauth2_token( | |
credentials.id_token, request_session, GOOGLE_CLIENT_ID | |
) | |
session['google_id'] = id_info.get("sub") | |
session['email'] = id_info.get("email") | |
session['name'] = id_info.get("name") | |
return redirect(url_for('new_person')) | |
# フィードバック画面(テンプレート: feedback.html) | |
def feedback(): | |
#ログイン問題解決しだい戻す | |
""" | |
if 'google_id' not in session: | |
return redirect(url_for('login')) | |
user_info = { | |
'name': session.get('name'), | |
'email': session.get('email') | |
} | |
""" | |
return render_template('feedback.html') | |
# 会話詳細画面(テンプレート: talkDetail.html) | |
def talk_detail(): | |
""" | |
if 'google_id' not in session: | |
return redirect(url_for('login')) | |
user_info = { | |
'name': session.get('name'), | |
'email': session.get('email') | |
} | |
""" | |
return render_template('talkDetail.html') | |
# インデックス画面(テンプレート: index.html) | |
def index(): | |
""" | |
if 'google_id' not in session: | |
return redirect(url_for('login')) | |
user_info = { | |
'name': session.get('name'), | |
'email': session.get('email') | |
} | |
""" | |
return render_template('index.html') | |
# 登録画面(テンプレート: new_person.html) | |
def new_person(): | |
if 'google_id' not in session: | |
return redirect(url_for('login')) | |
user_info = { | |
'name': session.get('name'), | |
'email': session.get('email') | |
} | |
return render_template('new_person.html', user=user_info) | |
def before_request(): | |
# リクエストのたびにセッションの寿命を更新する | |
session.permanent = True | |
app.permanent_session_lifetime = timedelta(minutes=15) | |
session.modified = True | |
# 音声アップロード&解析エンドポイント | |
def upload_audio(): | |
try: | |
data = request.get_json() | |
if not data or 'audio_data' not in data: | |
return jsonify({"error": "音声データがありません"}), 400 | |
# Base64デコードして音声バイナリを取得 | |
audio_binary = base64.b64decode(data['audio_data']) | |
audio_dir = "/tmp/data" | |
os.makedirs(audio_dir, exist_ok=True) | |
# 固定ファイル名(必要に応じて generate_filename() で一意のファイル名に変更可能) | |
audio_path = os.path.join(audio_dir, "recorded_audio.wav") | |
with open(audio_path, 'wb') as f: | |
f.write(audio_binary) | |
# 参照音声ファイルのパスを指定(sample.wav を正しい場所に配置すること) | |
reference_audio = os.path.abspath('/tmp/data/base_audio/recorded_base_audio.wav') | |
if not os.path.exists(reference_audio): | |
return jsonify({"error": "参照音声ファイルが見つかりません", "details": reference_audio}), 500 | |
# 音声解析:参照音声とアップロードされた音声との類似度をセグメント毎に計算 | |
# threshold の値は調整可能です(例: 0.1) | |
matched_time, unmatched_time = process.process_audio(reference_audio, audio_path, threshold=0.05) | |
total_time = matched_time + unmatched_time | |
rate = (matched_time / total_time) * 100 if total_time > 0 else 0 | |
return jsonify({"rate": rate}), 200 | |
except Exception as e: | |
print("Error in /upload_audio:", str(e)) | |
return jsonify({"error": "サーバーエラー", "details": str(e)}), 500 | |
def upload_base_audio(): | |
try: | |
data = request.get_json() | |
if not data or 'audio_data' not in data: | |
return jsonify({"error": "音声データがありません"}), 400 | |
# Base64デコードして音声バイナリを取得 | |
audio_binary = base64.b64decode(data['audio_data']) | |
# 保存先ディレクトリの作成 | |
audio_dir = "/tmp/data/base_audio" | |
os.makedirs(audio_dir, exist_ok=True) | |
# 一時ファイルに保存(実際の形式は WebM などと仮定) | |
temp_audio_path = os.path.join(audio_dir, "temp_audio") | |
with open(temp_audio_path, 'wb') as f: | |
f.write(audio_binary) | |
# pydub を使って一時ファイルを WAV に変換 | |
# ※ここでは WebM 形式と仮定していますが、実際の形式に合わせて format の指定を変更してください | |
try: | |
audio = AudioSegment.from_file(temp_audio_path, format="webm") | |
except Exception as e: | |
# 形式が不明な場合は自動判別させる(ただし変換できない場合もあり) | |
audio = AudioSegment.from_file(temp_audio_path) | |
wav_audio_path = os.path.join(audio_dir, "recorded_base_audio.wav") | |
audio.export(wav_audio_path, format="wav") | |
# 一時ファイルを削除 | |
os.remove(temp_audio_path) | |
return jsonify({"state": "Registration Success!"}), 200 | |
except Exception as e: | |
print("Error in /upload_base_audio:", str(e)) | |
return jsonify({"error": "サーバーエラー", "details": str(e)}), 500 | |
if __name__ == '__main__': | |
port = int(os.environ.get("PORT", 7860)) | |
app.run(debug=True, host="0.0.0.0", port=port) | |