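# Gradio demo: serve the ai4bharat/Airavata model for text generation,
# loaded in 8-bit via bitsandbytes to reduce memory usage.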
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Model identifier on the Hugging Face Hub
model_name = "ai4bharat/Airavata"
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Quantization config: load the weights in 8-bit to reduce memory usage
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)
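# Alternative (not used here): 4-bit NF4 quantization trades a little accuracy
# for an even smaller footprint, e.g.
# BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4").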
# Load the model with the quantization config. device_map="auto" lets
# accelerate place the weights on the available devices at load time;
# an 8-bit quantized model cannot be moved afterwards with .to().
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
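# The resulting placement can be inspected via model.hf_device_map.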
# Define the inference function
def generate_text(prompt):
    # Tokenize the prompt and move it to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Bound the length of the generated continuation
    outputs = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Create the Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Airavata Text Generation Model",
    description="This is the AI4Bharat Airavata model for text generation in Indic languages.",
)
# Launch the interface
interface.launch()