FusedWhisperLlama Model
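Model ini menggabungkan (fusion) Whisper dan LLaMA dalam satu pipeline speech-to-text → LLM: audio ditranskripsikan oleh Whisper, lalu teks transkripsinya diberikan sebagai prompt ke LLaMA untuk menghasilkan respons.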
Model Description
- Model Type: FusedWhisperLlama
- Language: Indonesian & English
- Tasks: Speech Recognition, Text Generation
- Base Models:
  - Whisper: openai/whisper-small
  - LLaMA: unsloth/Llama-3.2-1B-Instruct
Usage
import torch
import torch.nn as nn
import librosa
import numpy as np
import json
import os
from typing import Dict, Any
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor, LlamaConfig, LlamaForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from pathlib import Path
def download_model_files(repo_id: str, local_dir: str):
    """Download the fused checkpoint plus its configs and tokenizers from the Hub.

    Args:
        repo_id: Hugging Face Hub repository id,
            e.g. ``"johaness14/fused-whisper-llama"``.
        local_dir: Directory the repository files are mirrored into; it (and
            its ``configs`` sub-directory) is created if missing.

    Returns:
        The local path of ``pytorch_model.bin`` as reported by
        ``hf_hub_download``.

    The checkpoint download is mandatory: any error from ``hf_hub_download``
    propagates. Config/tokenizer files are fetched best-effort. A failed
    download is printed as a warning and skipped, so a missing file only
    surfaces later, when ``StandaloneFusionInference`` tries to read it.
    """
    os.makedirs(local_dir, exist_ok=True)
    os.makedirs(os.path.join(local_dir, "configs"), exist_ok=True)

    # Download model file
    print("Downloading model file...")
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename="pytorch_model.bin",
        local_dir=local_dir,
    )

    # Download configs
    print("Downloading config files...")
    config_files = (
        "config.json",
        "configs/config_whisper.json",
        "configs/config_llama.json",
        # Whisper tokenizer files
        "configs/tokenizer_whisper/added_tokens.json",
        "configs/tokenizer_whisper/merges.txt",
        "configs/tokenizer_whisper/normalizer.json",
        "configs/tokenizer_whisper/preprocessor_config.json",
        "configs/tokenizer_whisper/special_tokens_map.json",
        "configs/tokenizer_whisper/tokenizer_config.json",
        "configs/tokenizer_whisper/vocab.json",
        # Llama tokenizer files
        "configs/tokenizer_llama/special_tokens_map.json",
        "configs/tokenizer_llama/tokenizer.json",
        "configs/tokenizer_llama/tokenizer_config.json",
    )
    for filename in config_files:
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=local_dir,
            )
        except Exception as e:  # deliberate best-effort: warn and keep going
            print(f"Warning: Could not download {filename}: {e}")
        else:
            print(f"Downloaded {filename}")

    # Return the path hf_hub_download actually wrote to instead of rebuilding it.
    return model_path
class StandaloneFusionInference:
    """Speech -> text -> LLM inference for the fused Whisper/LLaMA checkpoint.

    Whisper transcribes an audio file. The transcription is placed in a
    plain-text ``System:``/``User:`` prompt together with ``system_prompt``
    and fed to LLaMA, whose sampled continuation is returned as the response.

    NOTE(review): ``fusion_layer`` is rebuilt and its weights are restored from
    the checkpoint, but ``generate`` never routes any data through it.
    """

    # Sampling rate the audio is resampled to before the Whisper processor.
    SAMPLE_RATE = 16000
    # Hub tokenizer used when the bundled LLaMA tokenizer cannot be loaded.
    DEFAULT_LLAMA_TOKENIZER = "unsloth/Llama-3.2-1B-Instruct"

    def __init__(self, model_path: str, config_dir: str = None, device: str = None):
        """Build both models from local configs and load the fused checkpoint.

        Args:
            model_path: Path to the checkpoint file. It must hold a dict with
                ``"whisper_model"``, ``"llm_model"`` and ``"fusion_layer"``
                state dicts.
            config_dir: Directory containing ``config_whisper.json``,
                ``config_llama.json`` and the ``tokenizer_whisper`` /
                ``tokenizer_llama`` sub-directories. Defaults to a
                ``configs`` directory next to ``model_path``.
            device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``).
                Defaults to CUDA when available, otherwise CPU.
        """
        if config_dir is None:
            config_dir = os.path.join(os.path.dirname(model_path), "configs")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # Raw JSON contents; the model configs live under nested keys.
        self.whisper_config = self._read_json(os.path.join(config_dir, "config_whisper.json"))
        self.llama_config = self._read_json(os.path.join(config_dir, "config_llama.json"))
        # Read up front so a malformed config fails before the heavy weight load.
        self.system_prompt = self.whisper_config["system_prompt"]

        print("Loading Whisper model...")
        self.whisper = WhisperForConditionalGeneration(
            WhisperConfig(**self.whisper_config["whisper_config"])
        )
        self.processor = WhisperProcessor.from_pretrained(
            os.path.join(config_dir, "tokenizer_whisper")
        )

        print("Loading LLaMA model...")
        self.llm = LlamaForCausalLM(LlamaConfig(**self.llama_config["llama_config"]))
        self.tokenizer = self._load_llama_tokenizer(os.path.join(config_dir, "tokenizer_llama"))

        # Must match the architecture the checkpoint's "fusion_layer" was saved from.
        d_model = self.whisper.config.d_model
        self.fusion_layer = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model),
        )

        print("Loading model weights...")
        self._load_weights(model_path)

        for module in (self.whisper, self.llm, self.fusion_layer):
            module.eval()
        self.whisper = self.whisper.to(self.device)
        self.llm = self.llm.to(self.device)
        self.fusion_layer = self.fusion_layer.to(self.device)
        print("Model loaded successfully!")

    @staticmethod
    def _read_json(path: str) -> Any:
        """Parse the UTF-8 JSON file at ``path``."""
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    @classmethod
    def _load_llama_tokenizer(cls, tokenizer_path: str):
        """Load the bundled LLaMA tokenizer, falling back to the Hub default."""
        # NOTE(review): trust_remote_code=True allows the tokenizer files to run
        # custom code; keep it only if the source repository is trusted.
        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
        except (OSError, ValueError) as e:
            print(f"Warning: Could not load local tokenizer ({e}), using default")
            return AutoTokenizer.from_pretrained(
                cls.DEFAULT_LLAMA_TOKENIZER,
                trust_remote_code=True,
            )
        print("Loaded local LLaMA tokenizer")
        return tokenizer

    def _load_weights(self, model_path: str) -> None:
        """Restore the three sub-module state dicts from the fused checkpoint."""
        # Map to CPU: the modules are still on CPU at this point, so mapping the
        # checkpoint straight to self.device would only add a transient extra
        # copy on the accelerator before load_state_dict copies it back.
        # NOTE(review): torch.load unpickles the file; for checkpoints fetched
        # from the Hub consider weights_only=True (available in torch >= 1.13).
        weights = torch.load(model_path, map_location="cpu")
        self.whisper.load_state_dict(weights["whisper_model"])
        self.llm.load_state_dict(weights["llm_model"])
        self.fusion_layer.load_state_dict(weights["fusion_layer"])

    @staticmethod
    def _new_tokens(sequence, prompt_length: int):
        """Return the part of a generated token sequence that follows the prompt."""
        return sequence[prompt_length:]

    def _transcribe(self, audio_path: str) -> str:
        """Transcribe ``audio_path`` with Whisper (beam search, no timestamps)."""
        # Load as mono, resampled to SAMPLE_RATE, then peak-normalize.
        speech, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE, mono=True)
        speech = librosa.util.normalize(speech)

        features = self.processor(
            speech,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
        ).input_features.to(self.device)

        with torch.no_grad():
            token_ids = self.whisper.generate(
                features,
                max_length=448,
                num_beams=5,
                temperature=0.0,
                no_repeat_ngram_size=3,
                return_timestamps=False,
            )
        return self.processor.batch_decode(
            token_ids,
            skip_special_tokens=True,
            normalize=True,
        )[0].strip()

    def _respond(self, transcription: str) -> str:
        """Generate a LLaMA reply to ``transcription`` using the system prompt."""
        prompt = f"System: {self.system_prompt}\nUser: {transcription}"
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_length = encoded["input_ids"].shape[-1]

        # Pass an explicit pad id; fall back to EOS when the tokenizer defines none.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            output_ids = self.llm.generate(
                **encoded,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                pad_token_id=pad_token_id,
            )
        # A decoder-only generate() returns prompt + continuation. Drop the
        # prompt so the response holds only the model's answer.
        answer_ids = self._new_tokens(output_ids[0], prompt_length)
        return self.tokenizer.decode(answer_ids, skip_special_tokens=True).strip()

    def generate(self, audio_path: str) -> Dict[str, Any]:
        """Transcribe an audio file and answer the transcription with LLaMA.

        Args:
            audio_path: Path to an audio file readable by ``librosa.load``.

        Returns:
            ``{"transcription": str, "response": str}``. ``response`` contains
            only newly generated text; the prompt is not echoed back.
        """
        transcription = self._transcribe(audio_path)
        response = self._respond(transcription)
        return {
            "transcription": transcription,
            "response": response,
        }
if __name__ == "__main__":
    import sys

    # Download the model from the Hugging Face Hub.
    repo_id = "johaness14/fused-whisper-llama"
    local_dir = "downloaded_model"
    model_path = download_model_files(repo_id, local_dir)

    # device=None lets the class pick CUDA when available and fall back to CPU,
    # so the example also runs on CPU-only machines. Pass "cuda" or "cpu" to force one.
    inference = StandaloneFusionInference(
        model_path,
        config_dir=os.path.join(local_dir, "configs"),
        device=None,
    )

    # Audio to process: the first command-line argument, else the placeholder path.
    audio_path = sys.argv[1] if len(sys.argv) > 1 else "path/to/your/audio.wav"
    output = inference.generate(audio_path)

    print("\nTranscription:")
    print(output["transcription"])
    print("\nResponse:")
    print(output["response"])
Training Details
Model ini menggabungkan kemampuan speech recognition dari Whisper dengan kemampuan text generation dari LLaMA menggunakan fusion layer.
Training Procedure
- Speech Recognition: Menggunakan Whisper small model
- Text Generation: Menggunakan LLaMA 3.2 1B model
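- Fusion: Custom fusion layer (Linear → ReLU → LayerNorm, berdimensi `d_model` Whisper) untuk menghubungkan kedua model. Catatan: pada kode inferensi di atas, bobot fusion layer dimuat tetapi belum dipakai; hasil transkripsi diteruskan ke LLaMA sebagai teks.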
Limitations and Biases
- Model mungkin memiliki bias dari model dasar yang digunakan
- Performa bergantung pada kualitas audio input
- Panjang teks yang dapat dihasilkan terbatas (pada contoh di atas: transkripsi Whisper maksimal 448 token dan respons LLaMA maksimal 512 token baru)
- Downloads last month
- 8
Model tree for johaness14/fused-whisper-llama
Base model
openai/whisper-small