JusTalk / app.py
buletomato25
docker conflict
1dfab95
raw
history blame
10.7 kB
from flask import Flask, request, jsonify, render_template, send_from_directory,redirect, make_response, Response, session
import base64
from pydub import AudioSegment # 変換用にpydubをインポート
import os
import shutil
import numpy as np
import string
import random
from datetime import datetime, timedelta
from pyannote.audio import Model, Inference
from pydub import AudioSegment
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required
from database import db
from users import Users
from werkzeug.security import generate_password_hash, check_password_hash
# Hugging Face のトークン取得(環境変数 HF に設定)
#hf_token = os.environ.get("HF")
hf_token = "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"
if hf_token is None:
raise ValueError("HUGGINGFACE_HUB_TOKEN が設定されていません。")
# キャッシュディレクトリの作成(書き込み可能な /tmp を利用)
cache_dir = "/tmp/hf_cache"
os.makedirs(cache_dir, exist_ok=True)
# pyannote モデルの読み込み
model = Model.from_pretrained("pyannote/embedding", use_auth_token=hf_token, cache_dir=cache_dir)
inference = Inference(model)
app = Flask(__name__)
app.config['SECRET_KEY'] = os.urandom(24)
# データベース設定
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///site.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# db を Flask アプリに紐づける
db.init_app(app)
# Flask-Login の設定
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = "login"
@login_manager.user_loader
def load_user(user_id):
return Users.query.get(int(user_id))
def cosine_similarity(vec1, vec2):
vec1 = vec1 / np.linalg.norm(vec1)
vec2 = vec2 / np.linalg.norm(vec2)
return np.dot(vec1, vec2)
def segment_audio(path, target_path='/tmp/setup_voice', seg_duration=1.0):
"""
音声を指定秒数ごとに分割する。
target_path に分割したファイルを保存し、元の音声の総長(ミリ秒)を返す。
"""
os.makedirs(target_path, exist_ok=True)
base_sound = AudioSegment.from_file(path)
duration_ms = len(base_sound)
seg_duration_ms = int(seg_duration * 1000)
for i, start in enumerate(range(0, duration_ms, seg_duration_ms)):
end = min(start + seg_duration_ms, duration_ms)
segment = base_sound[start:end]
segment.export(os.path.join(target_path, f'{i}.wav'), format="wav")
return target_path, duration_ms
def calculate_similarity(path1, path2):
embedding1 = inference(path1)
embedding2 = inference(path2)
return float(cosine_similarity(embedding1.data.flatten(), embedding2.data.flatten()))
def process_audio(reference_path, input_path, output_folder='/tmp/data/matched_segments', seg_duration=1.0, threshold=0.5):
"""
入力音声ファイルを seg_duration 秒ごとに分割し、各セグメントと参照音声の類似度を計算。
類似度が threshold を超えたセグメントを output_folder にコピーし、マッチした時間(ms)と
マッチしなかった時間(ms)を返す。
"""
os.makedirs(output_folder, exist_ok=True)
segmented_path, total_duration_ms = segment_audio(input_path, seg_duration=seg_duration)
matched_time_ms = 0
for file in sorted(os.listdir(segmented_path)):
segment_file = os.path.join(segmented_path, file)
similarity = calculate_similarity(segment_file, reference_path)
if similarity > threshold:
shutil.copy(segment_file, output_folder)
matched_time_ms += len(AudioSegment.from_file(segment_file))
unmatched_time_ms = total_duration_ms - matched_time_ms
return matched_time_ms, unmatched_time_ms
def generate_random_string(length):
letters = string.ascii_letters + string.digits
return ''.join(random.choice(letters) for i in range(length))
def generate_filename(random_length):
random_string = generate_random_string(random_length)
current_time = datetime.now().strftime("%Y%m%d%H%M%S")
filename = f"{current_time}_{random_string}.wav"
return filename
load_dotenv() # .env ファイルを読み込む
HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HUGGINGFACE_HUB_TOKEN:
raise ValueError("HUGGINGFACE_HUB_TOKEN が設定されていません。")
# トップページ(テンプレート: index.html)
@app.route('/')
def top():
return redirect('/login')
# ログイン後画面
@app.route('/after')
@login_required # ログインしているユーザのみアクセス許可
def index():
users = Users.query.order_by(Users.id).all()
return render_template('index.html', users=users)
@app.route('/index', methods=['GET', 'POST'])
def index_page():
users = Users.query.order_by(Users.id).all()
return render_template('index.html', users=users)
# フィードバック画面(テンプレート: feedback.html)
@app.route('/feedback', methods=['GET', 'POST'])
def feedback():
users = Users.query.order_by(Users.id).all()
return render_template('feedback.html', users=users)
# 会話詳細画面(テンプレート: talkDetail.html)
@app.route('/talk_detail', methods=['GET', 'POST'])
def talk_detail():
users = Users.query.order_by(Users.id).all()
return render_template('talkDetail.html', users=users)
# ログイン画面(テンプレート: login.html)
@app.route('/login', methods=['GET', 'POST'])
def login():
if request.method == "POST":
username = request.form.get('username')
password = request.form.get('password')
# Userテーブルからusernameに一致するユーザを取得
user = Users.query.filter_by(username=username).first()
if check_password_hash(user.password, password):
login_user(user)
return redirect('after')
else:
# 入力したユーザー名のパスワードが間違っている場合
# return "<p>パスワードが間違っています。</p>"
return Response(status=404, response="ページが見つかりません。")
else:
return render_template('login.html')
# 登録画面(テンプレート: new_person.html)
@app.route('/new_person', methods=['GET', 'POST'])
def new_person():
if request.method == "POST":
username = request.form.get('username')
password = request.form.get('password')
# Userのインスタンスを作成
user = Users(username=username, password=generate_password_hash(password, method='sha256'), email=username)
db.session.add(user)
db.session.commit()
return redirect('login')
else:
return render_template('new_person.html')
@app.before_request
def before_request():
# リクエストのたびにセッションの寿命を更新する
session.permanent = True
app.permanent_session_lifetime = timedelta(minutes=15)
session.modified = True
# 音声アップロード&解析エンドポイント
@app.route('/upload_audio', methods=['POST'])
def upload_audio():
try:
data = request.get_json()
if not data or 'audio_data' not in data:
return jsonify({"error": "音声データがありません"}), 400
# Base64デコードして音声バイナリを取得
audio_binary = base64.b64decode(data['audio_data'])
audio_dir = "/tmp/data"
os.makedirs(audio_dir, exist_ok=True)
# 固定ファイル名(必要に応じて generate_filename() で一意のファイル名に変更可能)
audio_path = os.path.join(audio_dir, "recorded_audio.wav")
with open(audio_path, 'wb') as f:
f.write(audio_binary)
# 参照音声ファイルのパスを指定(sample.wav を正しい場所に配置すること)
reference_audio = os.path.abspath('/tmp/data/base_audio/recorded_base_audio.wav')
if not os.path.exists(reference_audio):
return jsonify({"error": "参照音声ファイルが見つかりません", "details": reference_audio}), 500
# 音声解析:参照音声とアップロードされた音声との類似度をセグメント毎に計算
# threshold の値は調整可能です(例: 0.1)
matched_time, unmatched_time = process.process_audio(reference_audio, audio_path, threshold=0.05)
total_time = matched_time + unmatched_time
rate = (matched_time / total_time) * 100 if total_time > 0 else 0
return jsonify({"rate": rate}), 200
except Exception as e:
print("Error in /upload_audio:", str(e))
return jsonify({"error": "サーバーエラー", "details": str(e)}), 500
@app.route('/upload_base_audio', methods=['POST'])
def upload_base_audio():
try:
data = request.get_json()
if not data or 'audio_data' not in data:
return jsonify({"error": "音声データがありません"}), 400
# Base64デコードして音声バイナリを取得
audio_binary = base64.b64decode(data['audio_data'])
# 保存先ディレクトリの作成
audio_dir = "/tmp/data/base_audio"
os.makedirs(audio_dir, exist_ok=True)
# 一時ファイルに保存(実際の形式は WebM などと仮定)
temp_audio_path = os.path.join(audio_dir, "temp_audio")
with open(temp_audio_path, 'wb') as f:
f.write(audio_binary)
# pydub を使って一時ファイルを WAV に変換
# ※ここでは WebM 形式と仮定していますが、実際の形式に合わせて format の指定を変更してください
try:
audio = AudioSegment.from_file(temp_audio_path, format="webm")
except Exception as e:
# 形式が不明な場合は自動判別させる(ただし変換できない場合もあり)
audio = AudioSegment.from_file(temp_audio_path)
wav_audio_path = os.path.join(audio_dir, "recorded_base_audio.wav")
audio.export(wav_audio_path, format="wav")
# 一時ファイルを削除
os.remove(temp_audio_path)
return jsonify({"state": "Registration Success!"}), 200
except Exception as e:
print("Error in /upload_base_audio:", str(e))
return jsonify({"error": "サーバーエラー", "details": str(e)}), 500
if __name__ == '__main__':
port = int(os.environ.get("PORT", 7860))
app.run(debug=True, host="0.0.0.0", port=port)