import os
import spaces  # Must be imported before any CUDA-touching library (ZeroGPU requirement)
import gradio as gr
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline


# The original snippet references DEFAULT_SYSTEM_PROMPT without defining it;
# the wording below is a placeholder consistent with the app's description.
DEFAULT_SYSTEM_PROMPT = (
    "You are a ROS2 expert assistant. Answer questions about ROS2, "
    "navigation, and robotics concisely, based on the provided context."
)


# GPU initialization is deferred into a function so CUDA is only touched
# inside the @spaces.GPU scope.
def initialize_model():
    import torch
    from transformers import (
        AutoTokenizer,
        BitsAndBytesConfig,
        AutoModelForCausalLM,
    )

    model_id = "meta-llama/Llama-3.2-3B-Instruct"
    token = os.environ.get("HF_TOKEN")

    # 4-bit NF4 quantization keeps the 3B model within the Space's GPU memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        quantization_config=bnb_config,
    )
    return model, tokenizer


# Non-GPU components are initialized eagerly at import time.
embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-base",
    model_kwargs={"device": "cpu"},
)

db = Chroma(
    persist_directory="db",
    embedding_function=embeddings,
)


@spaces.GPU(duration=30)
def respond(message, history, system_message, max_tokens, temperature, top_p):
    try:
        # These were imported inside initialize_model() in the original and
        # therefore out of scope here; import them where they are used.
        from transformers import TextStreamer, pipeline

        # Initialize model components inside the GPU scope.
        model, tokenizer = initialize_model()

        # Note: TextStreamer writes tokens to stdout (the Space logs), not to
        # the Gradio UI; the UI receives the full answer in one piece below.
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        text_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.15,
            streamer=streamer,
        )

        llm = HuggingFacePipeline(pipeline=text_pipeline)

        # The original referenced an undefined prompt_template; RetrievalQA's
        # "stuff" chain expects a prompt with {context} and {question}, so one
        # is built here with the system message prepended.
        prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template=(
                f"{system_message}\n\n"
                "Context:\n{context}\n\n"
                "Question: {question}\n\n"
                "Answer:"
            ),
        )

        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=db.as_retriever(search_kwargs={"k": 2}),
            return_source_documents=False,
            chain_type_kwargs={"prompt": prompt_template},
        )

        response = qa_chain.invoke({"query": message})
        yield response["result"]

    except Exception as e:
        yield f"An error occurred: {str(e)}"


# Create the Gradio chat interface.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System Message",
            lines=3,
            visible=False,
        ),
        gr.Slider(minimum=1, maximum=2048, value=500, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="ROS2 Expert Assistant",
    description="Ask questions about ROS2, navigation, and robotics. I'll provide concise answers based on the available documentation.",
)

if __name__ == "__main__":
    demo.launch()
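

# ----------------------------------------------------------------------------
# Hedged sketch: building the Chroma index this app reads from.
# The app assumes an index already persisted under "db" but never creates one.
# The helper below is illustrative, not part of the original app: the ./docs
# source directory, the .txt glob, and the chunking parameters are all
# assumptions. Run it once from a REPL or a separate script, before launching
# the app, so retrieval has something to search.
# ----------------------------------------------------------------------------
def build_index(docs_dir="docs", persist_dir="db"):
    from langchain.document_loaders import DirectoryLoader, TextLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # Load every .txt file under docs_dir (loader choice is an assumption).
    loader = DirectoryLoader(docs_dir, glob="**/*.txt", loader_cls=TextLoader)
    documents = loader.load()

    # Chunk documents so each retrieved passage fits comfortably in the
    # "stuff" chain's prompt; sizes here are illustrative defaults.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(documents)

    # Embed with the same model the app uses at query time, then persist to
    # the same directory the app's Chroma instance points at.
    index = Chroma.from_documents(chunks, embeddings, persist_directory=persist_dir)
    index.persist()
    return index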