Nechba's picture
Update app.py
392c31a verified
raw
history blame
3.14 kB
import os
import tempfile

import ffmpeg
import whisper
from flask import Flask, request, jsonify
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Flask application object that the route decorators in this file attach to.
app = Flask(__name__)
# Speech-to-text: load the Whisper "small" checkpoint once, at import time.
whisper_model = whisper.load_model("small")
# Emotion classifier. return_all_scores=True presumably yields a score for
# every label rather than only the top one -- confirm against the installed
# transformers version, where this flag may be deprecated.
classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", return_all_scores=True)
# Named-entity recognition: tokenizer and model are loaded explicitly and
# then wrapped in a single "ner" pipeline used by the /ner endpoint.
ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer)
def convert_audio(input_path, output_path):
    """Convert an audio file to 16-bit PCM audio with FFmpeg.

    Args:
        input_path: Path of the source audio file.
        output_path: Path the converted audio is written to. An existing
            file at this path is overwritten.

    Returns:
        True if FFmpeg finished successfully, False if it reported an error.
    """
    stream = ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le')
    try:
        # capture_stderr: without it ffmpeg.Error.stderr is None and the
        # diagnostic below would itself crash with AttributeError.
        # overwrite_output: a leftover output file must not make FFmpeg
        # stop and wait for an interactive "overwrite?" answer.
        stream.run(capture_stderr=True, overwrite_output=True)
        return True
    except ffmpeg.Error as e:
        # Fall back to the exception text if no stderr was captured.
        details = e.stderr.decode(errors="replace") if e.stderr else str(e)
        print(f"FFmpeg error: {details}")
        return False
@app.route('/transcribe', methods=['POST'])
def transcribe_audio():
    """Transcribe an uploaded audio file with Whisper.

    Expects a multipart POST with the audio under the ``file`` field.

    Returns:
        200 with ``{"transcription": <text>}`` on success;
        400 with ``{"error": ...}`` if no file was sent, the filename is
        empty, or the file type is rejected by ``allowed_file``;
        500 with ``{"error": ...}`` if conversion or transcription fails.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400

    upload = request.files['file']
    if upload.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    # NOTE(review): allowed_file is not defined in the visible part of this
    # file -- confirm it exists elsewhere, otherwise this raises NameError.
    if not allowed_file(upload.filename):
        return jsonify({'error': 'Unsupported file type'}), 400

    try:
        # A private per-request directory keeps concurrent uploads from
        # overwriting each other's files and guarantees cleanup on every
        # exit path (success, conversion failure, or exception).
        with tempfile.TemporaryDirectory() as work_dir:
            temp_path = os.path.join(work_dir, "temp_audio")
            converted_path = os.path.join(work_dir, "converted_audio.wav")

            upload.save(temp_path)

            # Convert audio to a format Whisper can process
            if not convert_audio(temp_path, converted_path):
                return jsonify({'error': 'Audio conversion failed'}), 500

            # Transcribe while the converted file still exists on disk
            result = whisper_model.transcribe(converted_path)
            transcription = result["text"]

        return jsonify({'transcription': transcription})
    except Exception as e:
        # Top-level request boundary: report the failure as JSON.
        return jsonify({'error': str(e)}), 500
@app.route('/classify', methods=['POST'])
def classify():
    """Run the emotion classifier on the ``text`` field of a JSON body.

    Returns:
        200 with the classifier output serialized as JSON;
        400 with ``{"error": ...}`` if the body is not a JSON object or has
        no ``text`` field;
        500 with ``{"error": ...}`` if classification itself fails.
    """
    try:
        # silent=True gives None for a missing/invalid JSON body instead of
        # raising, so malformed requests are answered with 400, not 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400
        if 'text' not in data:
            return jsonify({"error": "Missing 'text' field"}), 400

        result = classifier(data['text'])
        return jsonify(result)
    except Exception as e:
        # Top-level request boundary: report the failure as JSON.
        return jsonify({"error": str(e)}), 500
@app.route('/ner', methods=['POST'])
def ner_endpoint():
    """Run named-entity recognition on the ``text`` field of a JSON body.

    A missing ``text`` field is treated as an empty string.

    Returns:
        200 with ``{"entities": [{"word": ..., "entity": ...}, ...]}``;
        400 with ``{"error": ...}`` if the body is not a JSON object;
        500 with ``{"error": ...}`` if the NER pipeline fails.
    """
    try:
        # silent=True gives None for a missing/invalid JSON body instead of
        # raising, so malformed requests are answered with 400, not 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        text = data.get("text", "")
        ner_results = ner_pipeline(text)

        # Keep only the token text and its entity tag from each prediction.
        words_and_entities = [
            {"word": item['word'], "entity": item['entity']}
            for item in ner_results
        ]
        return jsonify({"entities": words_and_entities})
    except Exception as e:
        # Top-level request boundary: report the failure as JSON.
        return jsonify({"error": str(e)}), 500