ff98 committed
Commit 714ab7f · 1 Parent(s): afd9054

Features added

Files changed (4)
  1. .gitignore +3 -0
  2. app.py +106 -0
  3. classification_plot.png +0 -0
  4. requirements.txt +97 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ # .gitignore
+ venv/
+ .venv/
app.py ADDED
@@ -0,0 +1,106 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+ import matplotlib.pyplot as plt
+
+
+ def process_inputs(audio, option):
+     # Dispatch to the chosen task; return (text, image) for the two outputs
+     if option == "Translate":
+         return generate_text_from_audio(audio), None
+     elif option == "Summarize":
+         generated_text = generate_text_from_audio(audio)
+         return generate_summary_from_text(generated_text, min_length=50, max_length=150), None
+     elif option == "text-classification":
+         generated_text = generate_text_from_audio(audio)
+         return "", text_classification(generated_text)
+     elif option == "Ask a Question":
+         generated_text = generate_text_from_audio(audio)
+         return ask_ques_from_text(generated_text), None
+
+ def generate_text_from_audio(audio):
+     device = "cuda:0" if torch.cuda.is_available() else "cpu"
+     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+     model_id = "openai/whisper-small"
+
+     model = AutoModelForSpeechSeq2Seq.from_pretrained(
+         model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+     )
+     model.to(device)
+
+     processor = AutoProcessor.from_pretrained(model_id)
+
+     # gr.Audio(type="filepath") hands us a path; the pipeline loads the file itself
+     audio_data = audio
+
+     pipe = pipeline(
+         "automatic-speech-recognition",
+         model=model,
+         tokenizer=processor.tokenizer,
+         feature_extractor=processor.feature_extractor,
+         torch_dtype=torch_dtype,
+         chunk_length_s=30,
+         batch_size=16,  # batch size for inference - set based on your device
+         device=device,
+     )
+
+     # "task": "translate" makes Whisper output English; adding forced_decoder_ids on top conflicts with it in recent transformers
+     audio_text_result = pipe(audio_data, generate_kwargs={"task": "translate"})
+     return audio_text_result["text"]
+
+ def generate_summary_from_text(text, min_length, max_length):
+     summarizer = pipeline("summarization", model="Falconsai/text_summarization")
+     summary = summarizer(text, max_length=max_length, min_length=min_length, do_sample=False)
+     return summary[0]["summary_text"]
+
+ def text_classification(text):
+     classifier = pipeline(task="text-classification", model="SamLowe/roberta-base-go_emotions", top_k=None)
+     model_outputs = classifier([text])
+
+     # Extract the labels and scores from the model's output
+     labels = [output["label"] for output in model_outputs[0]]
+     scores = [output["score"] for output in model_outputs[0]]
+     sorted_data = sorted(zip(scores, labels), reverse=True)
+
+     # Extract the top 5 emotions
+     top_5_scores, top_5_labels = zip(*sorted_data[:5])
+
+     # Plot a horizontal bar chart of the top 5 emotion scores
+     plt.figure(figsize=(12, 8))
+     plt.barh(top_5_labels, top_5_scores, color="skyblue")
+     plt.title("Top 5 Sentiment Scores for Emotions")
+     plt.xlabel("Score")
+     plt.ylabel("Emotion")
+
+     # Save the chart to disk so gr.Image can display it
+     plt.savefig("classification_plot.png")
+     plt.close()
+     return "classification_plot.png"
+
+
+ def ask_ques_from_text(text):
+     model_name = "deepset/roberta-base-squad2"
+
+     # Get predictions (fall back to CPU when no GPU is available)
+     nlp = pipeline("question-answering", model=model_name, tokenizer=model_name, device=0 if torch.cuda.is_available() else -1)
+
+     QA_input = {
+         "question": "who did not recognize?",  # hardcoded demo question
+         "context": text,  # the transcript produced from the audio
+     }
+
+     res = nlp(QA_input)
+     print("Answer from pipeline:", res["answer"])
+
+     return res["answer"]
+
+ demo = gr.Interface(
+     fn=process_inputs,
+     inputs=[
+         gr.Audio(label="Upload audio", type="filepath"),  # audio input arrives as a file path
+         gr.Dropdown(choices=["Translate", "Summarize", "text-classification", "Ask a Question"], label="Choose an Option")
+     ],
+     outputs=[gr.Textbox(label="Result"), gr.Image(label="Classification Plot")],
+ )
+
+ demo.launch()
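
With the app running, the /predict endpoint can also be exercised programmatically. A minimal sketch using gradio_client (assuming the app serves on the default http://127.0.0.1:7860; sample.wav is a hypothetical local audio file you supply):

from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")
# Positional inputs mirror the gr.Audio and gr.Dropdown components above.
text, plot = client.predict(
    handle_file("sample.wav"),  # hypothetical audio file to transcribe
    "Summarize",                # any of the four dropdown choices
    api_name="/predict",
)
print(text)  # the Textbox result; plot is an image file path or None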
classification_plot.png ADDED
requirements.txt ADDED
@@ -0,0 +1,97 @@
+ accelerate==1.1.1
+ aiofiles==23.2.1
+ aiohappyeyeballs==2.4.3
+ aiohttp==3.10.10
+ aiosignal==1.3.1
+ annotated-types==0.7.0
+ anyio==4.6.2.post1
+ attrs==24.2.0
+ audioread==3.0.1
+ certifi==2024.8.30
+ cffi==1.17.1
+ charset-normalizer==3.4.0
+ click==8.1.7
+ contourpy==1.3.0
+ cycler==0.12.1
+ datasets==3.1.0
+ decorator==5.1.1
+ dill==0.3.8
+ fastapi==0.115.4
+ ffmpy==0.4.0
+ filelock==3.16.1
+ fonttools==4.54.1
+ frozenlist==1.5.0
+ fsspec==2024.9.0
+ gradio==5.5.0
+ gradio_client==1.4.2
+ h11==0.14.0
+ httpcore==1.0.6
+ httpx==0.27.2
+ huggingface-hub==0.26.2
+ idna==3.10
+ Jinja2==3.1.4
+ joblib==1.4.2
+ kiwisolver==1.4.7
+ lazy_loader==0.4
+ librosa==0.10.2.post1
+ llvmlite==0.43.0
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ msgpack==1.1.0
+ multidict==6.1.0
+ multiprocess==0.70.16
+ networkx==3.4.2
+ numba==0.60.0
+ numpy==2.0.2
+ orjson==3.10.11
+ packaging==24.1
+ pandas==2.2.3
+ pillow==11.0.0
+ platformdirs==4.3.6
+ pooch==1.8.2
+ propcache==0.2.0
+ psutil==6.1.0
+ pyarrow==18.0.0
+ pycparser==2.22
+ pydantic==2.9.2
+ pydantic_core==2.23.4
+ pydub==0.25.1
+ Pygments==2.18.0
+ pyparsing==3.2.0
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.12
+ pytz==2024.2
+ PyYAML==6.0.2
+ regex==2024.11.6
+ requests==2.32.3
+ rich==13.9.4
+ ruff==0.7.2
+ safehttpx==0.1.1
+ safetensors==0.4.5
+ scikit-learn==1.5.2
+ scipy==1.14.1
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ soundfile==0.12.1
+ soxr==0.5.0.post1
+ starlette==0.41.2
+ sympy==1.13.1
+ threadpoolctl==3.5.0
+ tokenizers==0.20.3
+ tomlkit==0.12.0
+ torch==2.5.1
+ tqdm==4.67.0
+ transformers==4.46.2
+ typer==0.12.5
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ urllib3==2.2.3
+ uvicorn==0.32.0
+ websockets==12.0
+ xxhash==3.5.0
+ yarl==1.17.1
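
To reproduce the environment locally, a typical setup might look like this (a sketch; the venv/ directory name matches the new .gitignore entries):

python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
python app.py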