import os
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, pipeline
# import torch
# from transformers import (
#     AutoTokenizer,
#     TextStreamer,
#     pipeline,
#     BitsAndBytesConfig,
#     AutoModelForCausalLM
# )
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
import gradio as gr
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
model_id = "meta-llama/Llama-3.2-3B-Instruct"
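
# --- RAG setup (sketch) ---
# The code below references DEFAULT_SYSTEM_PROMPT, prompt_template, and db without
# defining them anywhere in this file. The definitions here are assumptions added so
# the app runs end to end: the embedding model name, the Chroma persist directory,
# and the prompt wording are placeholders and should be adjusted to the actual Space.
DEFAULT_SYSTEM_PROMPT = (
    "You are a ROS2 expert assistant. Answer questions about ROS2, navigation, "
    "and robotics concisely, using only the provided context."
)

# Prompt used by the "stuff" RetrievalQA chain; it must expose {context} and {question}.
prompt_template = PromptTemplate(
    template=(
        DEFAULT_SYSTEM_PROMPT
        + "\n\nContext:\n{context}\n\nQuestion: {question}\n\nAnswer:"
    ),
    input_variables=["context", "question"],
)

# Embeddings and a persisted Chroma vector store (assumed to live in ./db).
embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-base",  # assumption: replace with the model used to build the store
    model_kwargs={"device": DEVICE},
)
db = Chroma(persist_directory="db", embedding_function=embeddings)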
# Remove the spaces.GPU decorator since we'll handle GPU directly
# def initialize_model():
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_use_double_quant=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.bfloat16
#     )
#     tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id,
#         token=os.environ.get("HF_TOKEN"),
#         quantization_config=bnb_config if torch.cuda.is_available() else None,
#         device_map="auto" if torch.cuda.is_available() else "cpu",
#         torch_dtype=torch.float32 if not torch.cuda.is_available() else None
#     )
#     return model, tokenizer
def initialize_model():
    model_id = "meta-llama/Llama-3.2-3B-Instruct"
    token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        device_map="auto"  # This works better with ZeroGPU
    )
    return model, tokenizer
def respond(message, history, system_message, max_tokens, temperature, top_p):
    try:
        model, tokenizer = initialize_model()
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Wrap the model in a HF text-generation pipeline so LangChain can drive it
        text_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.15,
            streamer=streamer,
        )
        llm = HuggingFacePipeline(pipeline=text_pipeline)
        # Retrieval-augmented QA over the Chroma vector store (top-2 chunks)
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=db.as_retriever(search_kwargs={"k": 2}),
            return_source_documents=False,
            chain_type_kwargs={"prompt": prompt_template}
        )
        response = qa_chain.invoke({"query": message})
        return response["result"]
    except Exception as e:
        return f"An error occurred: {str(e)}"
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System Message",
            lines=3,
            visible=False
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=500,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.1,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p"
        ),
    ],
    title="ROS2 Expert Assistant",
    description="Ask questions about ROS2, navigation, and robotics. I'll provide concise answers based on the available documentation.",
)
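
# The original snippet ends without starting the app; the standard Gradio entry
# point is added here so the Space actually serves the interface.
if __name__ == "__main__":
    demo.launch()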