alec228 commited on
Commit
c23173c
·
1 Parent(s): e4fccf0

Initial commit

Browse files
README.md ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Sentiment Audio
2
+
3
+ Ce projet propose un pipeline complet d’analyse de sentiment à partir de fichiers audio francophones, structuré en quatre composantes principales :
4
+
5
+ 1. **Transcription audio**
6
+ - Modèle Wav2Vec2 (`jonatasgrosman/wav2vec2-large-xlsr-53-french`)
7
+ - Extraction de vecteurs audio puis décodage CTC
8
+
9
+ 2. **Analyse de sentiment textuel**
10
+ - Modèle BERT multilingue (`nlptown/bert-base-multilingual-uncased-sentiment`)
11
+ - Fonction `analyze_sentiment(text)` retournant un label (`négatif`, `neutre`, `positif`) et sa confiance
12
+
13
+ 3. **Interface utilisateur Gradio**
14
+ - Modes d’entrée : **enregistrement microphone** et **téléversement de fichier**
15
+ - Affichage de la transcription et du score de sentiment en temps réel
16
+ ![Interface Gradio](home.png)
17
+
18
+ 4. **API REST FastAPI**
19
+ - Endpoint `/predict` pour soumettre un fichier audio
20
+ - Retour JSON `{ "transcription": ..., "sentiment": {label: confiance} }`
21
+ - Documentation interactive Swagger UI (`/docs`)
22
+ ![Documentation API](api.png)
23
+
24
+ ---
25
+
26
+ ## Structure du projet
27
+
28
+ ```
29
+ sentiment_audio_tp/
30
+ ├── hf_model/ # exports de modèles sauvegardés via save_pretrained
31
+ ├── models/ # cache local HuggingFace (ignoré par Git)
32
+ ├── src/
33
+ │ ├── __init__.py
34
+ │ ├── transcription.py # SpeechEncoder (Wav2Vec2Model)
35
+ │ ├── sentiment.py # TextEncoder + analyze_sentiment()
36
+ │ ├── multimodal.py # Classifieur multimodal (fusion embeddings)
37
+ │ ├── inference.py # CLI (audio → transcription + sentiment)
38
+ │ ├── app.py # Interface Gradio
39
+ │ └── api.py # Serveur FastAPI
40
+ ├── requirements.txt # Dépendances du projet
41
+ ├── render.yaml # Infra as code pour Render
42
+ └── README.md # Ce document
43
+ ```
44
+
45
+ ---
46
+
47
+ ## Installation
48
+
49
+ 1. **Cloner le dépôt**
50
+ ```bash
51
+ git clone <URL_DU_REPO>
52
+ cd sentiment_audio_tp
53
+ ```
54
+
55
+ 2. **Configurer l’environnement**
56
+ ```bash
57
+ python -m venv venv
58
+ source venv/bin/activate # macOS/Linux
59
+ .\venv\Scripts\Activate.ps1 # Windows PowerShell
60
+ ```
61
+
62
+ 3. **Installer les dépendances**
63
+ ```bash
64
+ pip install --upgrade pip
65
+ pip install -r requirements.txt
66
+ ```
67
+
68
+ ---
69
+
70
+ ## Utilisation
71
+
72
+ ### CLI d’inférence
73
+
74
+ ```bash
75
+ python src/inference.py chemin/vers/audio.wav
76
+ # Affiche la transcription et le résultat de sentiment
77
+ ```
78
+
79
+ ### Interface Gradio
80
+
81
+ ```bash
82
+ python -m src.app
83
+ ```
84
+
85
+ - Rendez-vous sur `http://127.0.0.1:7861/`
86
+ - Choisissez **Enregistrement** ou **Upload**
87
+ - Obtenez la transcription et le sentiment en temps réel
88
+
89
+ ### API REST
90
+
91
+ ```bash
92
+ uvicorn src.api:app --reload --host 0.0.0.0 --port 8000
93
+ ```
94
+
95
+ - Swagger UI : `http://127.0.0.1:8000/docs`
96
+ - Tester avec `curl` ou Postman :
97
+
98
+ ```bash
99
+ curl -X POST "http://127.0.0.1:8000/predict" \
100
+ -F "file=@/chemin/vers/audio.wav"
101
+ ```
102
+
103
+ ---
104
+
105
+ ## Cas d’usage
106
+
107
+ - **Prototype rapide** d’analyse de sentiment sur des appels clients, podcasts, interviews
108
+ - **Outil de validation** pour analyses qualitatives de contenu audio
109
+ - **Proof of Concept** pour architectures multimodales
110
+ - **Service back-end** dans un chatbot vocal ou plateforme d’assistance
111
+
112
+ ---
113
+
114
+ ## Extension
115
+
116
+ - **Fine-tuning multimodal** : entraînement du classifieur fusion sur un dataset annoté
117
+ - **Support de nouveaux formats** : MP3, FLAC…
118
+ - **Tests et CI** : ajouter des tests `pytest` et pipelines CI/CD
119
+ - **Déploiement** : Docker, Kubernetes, monitoring
120
+
121
+ ---
122
+
123
+ ## Licence
124
+
125
+ Licence **MIT** — libre d’utilisation, modification et redistribution.
126
+ ```
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers>=4.30.0
2
+ torch>=2.0.0
3
+ torchaudio>=2.0.0
4
+ gradio>=3.30.0
5
+
6
+ fastapi>=0.95.2
7
+ uvicorn[standard]>=0.21.1
8
+
9
+ soundfile>=0.12.1
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (169 Bytes). View file
 
src/__pycache__/api.cpython-313.pyc ADDED
Binary file (4.5 kB). View file
 
src/__pycache__/app.cpython-313.pyc ADDED
Binary file (8.44 kB). View file
 
src/__pycache__/multimodal.cpython-313.pyc ADDED
Binary file (2.5 kB). View file
 
src/__pycache__/sentiment.cpython-313.pyc ADDED
Binary file (3.15 kB). View file
 
src/__pycache__/transcription.cpython-313.pyc ADDED
Binary file (2.34 kB). View file
 
src/api.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import os
3
+ from fastapi import FastAPI, File, UploadFile, HTTPException
4
+ from fastapi.responses import JSONResponse
5
+ import torch.nn.functional as F
6
+ import torchaudio
7
+ import torch
8
+
9
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
10
+ from src.transcription import SpeechEncoder
11
+ from src.sentiment import TextEncoder
12
+ from src.multimodal import MultimodalSentimentClassifier
13
+
14
+ app = FastAPI(
15
+ title="API Multimodale de Transcription & Sentiment",
16
+ version="1.0"
17
+ )
18
+
19
+ # Précharge des modèles
20
+ processor_ctc = Wav2Vec2Processor.from_pretrained(
21
+ "jonatasgrosman/wav2vec2-large-xlsr-53-french",
22
+ #"jonatasgrosman/wav2vec2-large-xlsr-53-french",
23
+ cache_dir="./models"
24
+ )
25
+ model_ctc = Wav2Vec2ForCTC.from_pretrained(
26
+ "jonatasgrosman/wav2vec2-large-xlsr-53-french",
27
+ #"alec228/audio-sentiment/tree/main/wav2vec2",
28
+ cache_dir="./models"
29
+ )
30
+ speech_enc = SpeechEncoder()
31
+ text_enc = TextEncoder()
32
+ model_mm = MultimodalSentimentClassifier()
33
+
34
+ def transcribe_ctc(wav_path: str) -> str:
35
+ waveform, sr = torchaudio.load(wav_path)
36
+ if sr != 16000:
37
+ waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
38
+ if waveform.size(0) > 1:
39
+ waveform = waveform.mean(dim=0, keepdim=True)
40
+ inputs = processor_ctc(
41
+ waveform.squeeze().numpy(),
42
+ sampling_rate=16000,
43
+ return_tensors="pt",
44
+ padding=True
45
+ )
46
+ with torch.no_grad():
47
+ logits = model_ctc(**inputs).logits
48
+ pred_ids = torch.argmax(logits, dim=-1)
49
+ return processor_ctc.batch_decode(pred_ids)[0].lower()
50
+
51
+ @app.post("/predict")
52
+ async def predict(file: UploadFile = File(...)):
53
+ # 1. Vérifier le type
54
+ if not file.filename.lower().endswith((".wav", ".flac", ".mp3")):
55
+ raise HTTPException(status_code=400,
56
+ detail="Seuls les fichiers audio WAV/FLAC/MP3 sont acceptés.")
57
+ # 2. Sauvegarder temporairement
58
+ suffix = os.path.splitext(file.filename)[1]
59
+ with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
60
+ content = await file.read()
61
+ tmp.write(content)
62
+ tmp_path = tmp.name
63
+
64
+ try:
65
+ # 3. Transcription
66
+ transcription = transcribe_ctc(tmp_path)
67
+
68
+ # 4. Features multimodales
69
+ audio_feat = speech_enc.extract_features(tmp_path)
70
+ text_feat = text_enc.extract_features([transcription])
71
+
72
+ # 5. Classification
73
+ logits = model_mm.classifier(torch.cat([audio_feat, text_feat], dim=1))
74
+ probs = F.softmax(logits, dim=1).squeeze().tolist()
75
+ labels = ["négatif", "neutre", "positif"]
76
+ sentiment = { labels[i]: round(probs[i], 3) for i in range(len(labels)) }
77
+
78
+ return JSONResponse({
79
+ "transcription": transcription,
80
+ "sentiment": sentiment
81
+ })
82
+
83
+ finally:
84
+ # 6. Nettoyage
85
+ os.remove(tmp_path)
src/app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from datetime import datetime
4
+
5
+ import gradio as gr
6
+ import torch
7
+ import pandas as pd
8
+ import soundfile as sf
9
+ import torchaudio
10
+
11
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
12
+ from src.transcription import SpeechEncoder
13
+ from src.sentiment import TextEncoder
14
+
15
+ # Préchargement des modèles
16
+ processor_ctc = Wav2Vec2Processor.from_pretrained(
17
+ "jonatasgrosman/wav2vec2-large-xlsr-53-french", cache_dir="./models"
18
+ #"alec228/audio-sentiment/tree/main/wav2vec2", cache_dir="./models"
19
+ )
20
+ model_ctc = Wav2Vec2ForCTC.from_pretrained(
21
+ "jonatasgrosman/wav2vec2-large-xlsr-53-french", cache_dir="./models"
22
+ #"alec228/audio-sentiment/tree/main/wav2vec2", cache_dir="./models"
23
+ )
24
+
25
+ speech_enc = SpeechEncoder()
26
+ text_enc = TextEncoder()
27
+
28
+ # Pipeline d’analyse
29
+
30
+ def analyze_audio(audio_path):
31
+ # Lecture et prétraitement
32
+ data, sr = sf.read(audio_path)
33
+ arr = data.T if data.ndim > 1 else data
34
+ wav = torch.from_numpy(arr).unsqueeze(0).float()
35
+ if sr != 16000:
36
+ wav = torchaudio.transforms.Resample(sr, 16000)(wav)
37
+ sr = 16000
38
+ if wav.size(0) > 1:
39
+ wav = wav.mean(dim=0, keepdim=True)
40
+
41
+ # Transcription
42
+ inputs = processor_ctc(wav.squeeze().numpy(), sampling_rate=sr, return_tensors="pt")
43
+ with torch.no_grad():
44
+ logits = model_ctc(**inputs).logits
45
+ pred_ids = torch.argmax(logits, dim=-1)
46
+ transcription = processor_ctc.batch_decode(pred_ids)[0].lower()
47
+
48
+ # Sentiment principal
49
+ sent_dict = TextEncoder.analyze_sentiment(transcription)
50
+ label, conf = max(sent_dict.items(), key=lambda x: x[1])
51
+ emojis = {"positif": "😊", "neutre": "😐", "négatif": "☹️"}
52
+ emoji = emojis.get(label, "")
53
+
54
+ # Segmentation par phrase
55
+ segments = [s.strip() for s in re.split(r'[.?!]', transcription) if s.strip()]
56
+ seg_results = []
57
+ for seg in segments:
58
+ sd = TextEncoder.analyze_sentiment(seg)
59
+ l, c = max(sd.items(), key=lambda x: x[1])
60
+ seg_results.append({"Segment": seg, "Sentiment": l.capitalize(), "Confiance (%)": round(c*100,1)})
61
+ seg_df = pd.DataFrame(seg_results)
62
+
63
+ # Historique entry
64
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
65
+ history_entry = {
66
+ "Horodatage": timestamp,
67
+ "Transcription": transcription,
68
+ "Sentiment": label.capitalize(),
69
+ "Confiance (%)": round(conf*100,1)
70
+ }
71
+
72
+ # Rendu
73
+ summary_html = (
74
+ f"<div style='display:flex;align-items:center;'>"
75
+ f"<span style='font-size:3rem;margin-right:10px;'>{emoji}</span>"
76
+ f"<h2 style='color:#6a0dad;'>{label.upper()}</h2>"
77
+ f"</div>"
78
+ f"<p><strong>Confiance :</strong> {conf*100:.1f}%</p>"
79
+ )
80
+ return transcription, summary_html, seg_df, history_entry
81
+
82
+ # Export CSV
83
+
84
+ def export_history_csv(history):
85
+ df = pd.DataFrame(history)
86
+ path = "history.csv"
87
+ df.to_csv(path, index=False)
88
+ return path
89
+
90
+ # Interface Chat + historique
91
+
92
+ demo = gr.Blocks(theme=gr.themes.Monochrome(primary_hue="purple"))
93
+ with demo:
94
+ gr.Markdown("# Chat & Analyse de Sentiment Audio")
95
+
96
+ gr.HTML("""
97
+ <div style="display: flex; flex-direction: column; gap: 10px; margin-bottom: 20px;">
98
+ <div style="background-color: #f3e8ff; padding: 12px 20px; border-radius: 12px; border-left: 5px solid #8e44ad;">
99
+ <strong>Étape 1 :</strong> Enregistrez votre voix ou téléversez un fichier audio (format WAV recommandé).
100
+ </div>
101
+ <div style="background-color: #e0f7fa; padding: 12px 20px; border-radius: 12px; border-left: 5px solid #0097a7;">
102
+ <strong>Étape 2 :</strong> Cliquez sur le bouton <em><b>Analyser</b></em> pour lancer la transcription et l’analyse.
103
+ </div>
104
+ <div style="background-color: #fff3e0; padding: 12px 20px; border-radius: 12px; border-left: 5px solid #fb8c00;">
105
+ <strong>Étape 3 :</strong> Visualisez les résultats : transcription, sentiment, et analyse détaillée.
106
+ </div>
107
+ <div style="background-color: #e8f5e9; padding: 12px 20px; border-radius: 12px; border-left: 5px solid #43a047;">
108
+ <strong>Étape 4 :</strong> Exportez l’historique des analyses au format CSV si besoin.
109
+ </div>
110
+ </div>
111
+
112
+ <script>
113
+ const origin = window.location.origin;
114
+ const swaggerUrl = origin + "/docs";
115
+ document.getElementById("swagger-link").innerHTML = `<a href="${swaggerUrl}" target="_blank">Voir la documentation de l’API (Swagger)</a>`;
116
+ </script>
117
+ """)
118
+
119
+ with gr.Row():
120
+ with gr.Column(scale=2):
121
+ audio_in = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio Input")
122
+ btn = gr.Button("Analyser")
123
+ export_btn = gr.Button("Exporter CSV")
124
+ with gr.Column(scale=3):
125
+ chat = gr.Chatbot(label="Historique des échanges")
126
+ transcription_out = gr.Textbox(label="Transcription", interactive=False)
127
+ summary_out = gr.HTML(label="Sentiment")
128
+ seg_out = gr.Dataframe(label="Détail par segment")
129
+ hist_out = gr.Dataframe(label="Historique")
130
+
131
+ state_chat = gr.State([]) # list of (user,bot)
132
+ state_hist = gr.State([]) # list of dict entries
133
+
134
+ def chat_callback(audio_path, chat_history, hist_state):
135
+ transcription, summary, seg_df, hist_entry = analyze_audio(audio_path)
136
+ user_msg = "[Audio reçu]"
137
+ bot_msg = f"**Transcription :** {transcription}\n**Sentiment :** {summary}"
138
+ chat_history = chat_history + [(user_msg, bot_msg)]
139
+ hist_state = hist_state + [hist_entry]
140
+ return chat_history, transcription, summary, seg_df, hist_state
141
+
142
+ btn.click(
143
+ fn=chat_callback,
144
+ inputs=[audio_in, state_chat, state_hist],
145
+ outputs=[chat, transcription_out, summary_out, seg_out, state_hist]
146
+ )
147
+ export_btn.click(
148
+ fn=export_history_csv,
149
+ inputs=[state_hist],
150
+ outputs=[gr.File(label="Télécharger CSV")]
151
+ )
152
+
153
+
154
+
155
+ if __name__ == "__main__":
156
+ demo.launch()
src/inference.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torchaudio
4
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
+
6
+ from src.multimodal import MultimodalSentimentClassifier
7
+
8
+ # 1. Transcription CTC
9
+ def transcribe(audio_path: str) -> str:
10
+ processor = Wav2Vec2Processor.from_pretrained(
11
+ "jonatasgrosman/wav2vec2-large-xlsr-53-french",
12
+ #cache_dir="./models"
13
+ )
14
+ model_ctc = Wav2Vec2ForCTC.from_pretrained(
15
+ "jonatasgrosman/wav2vec2-large-xlsr-53-french",
16
+ #cache_dir="./models"
17
+ )
18
+
19
+ waveform, sr = torchaudio.load(audio_path)
20
+ if sr != 16000:
21
+ waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
22
+ if waveform.size(0) > 1:
23
+ waveform = waveform.mean(dim=0, keepdim=True)
24
+
25
+ inputs = processor(
26
+ waveform.squeeze().numpy(),
27
+ sampling_rate=16000,
28
+ return_tensors="pt",
29
+ padding=True
30
+ )
31
+ with torch.no_grad():
32
+ logits = model_ctc(**inputs).logits
33
+ predicted_ids = torch.argmax(logits, dim=-1)
34
+ transcription = processor.batch_decode(predicted_ids)[0]
35
+ return transcription.lower()
36
+
37
+ # 2. Inférence multimodale
38
+ def infer(audio_path: str) -> dict:
39
+ # a) transcrire l’audio
40
+ text = transcribe(audio_path)
41
+
42
+ # b) charger et exécuter le modèle multimodal
43
+ model = MultimodalSentimentClassifier()
44
+ logits = model(audio_path, text) # [1, n_classes]
45
+ probs = F.softmax(logits, dim=1).squeeze().tolist()
46
+
47
+ labels = ["négatif", "neutre", "positif"]
48
+ return { labels[i]: round(probs[i], 3) for i in range(len(labels)) }
49
+
50
+ # Test rapide en ligne de commande
51
+ if __name__ == "__main__":
52
+ import sys
53
+ if len(sys.argv) != 2:
54
+ print("Usage: python src/inference.py <chemin_vers_audio.wav>")
55
+ sys.exit(1)
56
+ res = infer(sys.argv[1])
57
+ print(f"Résultat multimodal : {res}")
src/multimodal.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .transcription import SpeechEncoder
2
+ from .sentiment import TextEncoder
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ class MultimodalSentimentClassifier(nn.Module):
7
+ def __init__(
8
+ self,
9
+ wav2vec_name: str = "jonatasgrosman/wav2vec2-large-xlsr-53-french",
10
+ #wav2vec_name: str = "alec228/audio-sentiment/tree/main/wav2vec2",
11
+ bert_name: str = "nlptown/bert-base-multilingual-uncased-sentiment",
12
+ #bert_name: str = "alec228/audio-sentiment/tree/main/bert-sentiment",
13
+ #cache_dir: str = "./models",
14
+ hidden_dim: int = 256,
15
+ n_classes: int = 3
16
+ ):
17
+ super().__init__()
18
+ self.speech_encoder = SpeechEncoder(
19
+ model_name=wav2vec_name,
20
+ # cache_dir=cache_dir
21
+ )
22
+ self.text_encoder = TextEncoder(
23
+ model_name=bert_name,
24
+ # cache_dir=cache_dir
25
+ )
26
+ dim_a = self.speech_encoder.model.config.hidden_size
27
+ dim_t = self.text_encoder.model.config.hidden_size
28
+
29
+ self.classifier = nn.Sequential(
30
+ nn.Linear(dim_a + dim_t, hidden_dim),
31
+ nn.ReLU(),
32
+ nn.Dropout(0.2),
33
+ nn.Linear(hidden_dim, n_classes)
34
+ )
35
+
36
+ def forward(self, audio_path: str, text: str) -> torch.Tensor:
37
+ a_feat = self.speech_encoder.extract_features(audio_path)
38
+ t_feat = self.text_encoder.extract_features([text])
39
+ fused = torch.cat([a_feat, t_feat], dim=1)
40
+ return self.classifier(fused)
src/sentiment.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ class TextEncoder:
6
+ def __init__(
7
+ self,
8
+ model_name: str = "nlptown/bert-base-multilingual-uncased-sentiment",
9
+ #model_name: str = "alec228/audio-sentiment/tree/main/bert-sentiment",
10
+ #cache_dir: str = "./models"
11
+ ):
12
+ # Tokenizer pour prétraiter le texte
13
+ self.tokenizer = AutoTokenizer.from_pretrained(
14
+ model_name,
15
+ #cache_dir=cache_dir
16
+ )
17
+ # Modèle BERT de base (sans tête de classification)
18
+ self.model = AutoModel.from_pretrained(
19
+ model_name,
20
+ #cache_dir=cache_dir
21
+ )
22
+
23
+ def extract_features(self, texts: list[str]) -> torch.Tensor:
24
+ """
25
+ Prend en entrée une liste de chaînes et renvoie
26
+ les embeddings du token [CLS] pour chaque texte.
27
+ """
28
+ # 1. Tokenisation
29
+ inputs = self.tokenizer(
30
+ texts,
31
+ return_tensors="pt",
32
+ truncation=True,
33
+ padding=True
34
+ )
35
+ # 2. Passage dans le modèle sans calcul de gradient
36
+ with torch.no_grad():
37
+ outputs = self.model(**inputs)
38
+ # 3. Extraction de l'embedding du token [CLS]
39
+ return outputs.last_hidden_state[:, 0, :] # [batch, hidden_size]
40
+
41
+ def analyze_sentiment(text: str) -> dict:
42
+ """
43
+ Analyse le sentiment d'un texte avec un modèle déjà fine-tuned
44
+ (nlptown/bert-base-multilingual-uncased-sentiment) et renvoie
45
+ un dict {label: confidence}.
46
+ """
47
+ # Chargement du tokenizer et du modèle de classification
48
+ tokenizer = AutoTokenizer.from_pretrained(
49
+ "nlptown/bert-base-multilingual-uncased-sentiment",
50
+ #cache_dir="./models"
51
+ )
52
+ model = AutoModelForSequenceClassification.from_pretrained(
53
+ "nlptown/bert-base-multilingual-uncased-sentiment",
54
+ #cache_dir="./models"
55
+ )
56
+
57
+ # Préparation
58
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
59
+ with torch.no_grad():
60
+ outputs = model(**inputs)
61
+ logits = outputs.logits
62
+ probs = F.softmax(logits, dim=1).squeeze().tolist()
63
+
64
+ # Les classes vont de 1 à 5, on choisit la plus probable
65
+ label_idx = int(torch.argmax(torch.tensor(probs))) + 1
66
+ if label_idx <= 2:
67
+ label = "négatif"
68
+ elif label_idx == 3:
69
+ label = "neutre"
70
+ else:
71
+ label = "positif"
72
+ confidence = round(max(probs), 3)
73
+ return {label: confidence}
74
+
src/transcription.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/transcription.py
2
+
3
+ from transformers import Wav2Vec2Processor, Wav2Vec2Model
4
+ import torch
5
+ import torchaudio
6
+
7
+ class SpeechEncoder:
8
+ def __init__(
9
+ self,
10
+ model_name: str = "jonatasgrosman/wav2vec2-large-xlsr-53-french",
11
+ #model_name: str = "alec228/audio-sentiment/tree/main/wav2vec2",
12
+ cache_dir: str = "./models"
13
+ ):
14
+ # Processor pour prétraiter l'audio
15
+ self.processor = Wav2Vec2Processor.from_pretrained(
16
+ model_name, cache_dir=cache_dir
17
+ )
18
+ # Modèle de base (sans tête CTC)
19
+ self.model = Wav2Vec2Model.from_pretrained(
20
+ model_name, cache_dir=cache_dir
21
+ )
22
+
23
+ def extract_features(self, audio_path: str) -> torch.Tensor:
24
+ """
25
+ Charge un fichier audio, le resample à 16 kHz, convertit en mono,
26
+ et renvoie la représentation vectorielle moyenne sur la séquence.
27
+ """
28
+ # 1. Chargement
29
+ waveform, sample_rate = torchaudio.load(audio_path)
30
+
31
+ # 2. Resample si nécessaire
32
+ if sample_rate != 16000:
33
+ waveform = torchaudio.transforms.Resample(
34
+ orig_freq=sample_rate,
35
+ new_freq=16000
36
+ )(waveform)
37
+
38
+ # 3. Passage en mono
39
+ if waveform.size(0) > 1:
40
+ waveform = waveform.mean(dim=0, keepdim=True)
41
+
42
+ # 4. Prétraitement pour le modèle
43
+ inputs = self.processor(
44
+ waveform.squeeze().numpy(),
45
+ sampling_rate=16000,
46
+ return_tensors="pt",
47
+ padding=True
48
+ )
49
+
50
+ # 5. Extraction sans gradient
51
+ with torch.no_grad():
52
+ outputs = self.model(**inputs)
53
+
54
+ # 6. Moyenne temporelle des embeddings
55
+ return outputs.last_hidden_state.mean(dim=1) # shape: [batch, hidden_size]