import os
import torch
import gradio as gr
import spaces
from huggingface_hub import InferenceClient
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, TextStreamer, pipeline, BitsAndBytesConfig, AutoModelForCausalLM

# Warn if the installed PyTorch build is not one known to work with ZeroGPU.
TORCH_VERSION = torch.__version__
SUPPORTED_TORCH_VERSIONS = ['2.0.1', '2.1.2', '2.2.2', '2.4.0']
if TORCH_VERSION.rsplit('+')[0] not in SUPPORTED_TORCH_VERSIONS:
    print(f"Warning: Current PyTorch version {TORCH_VERSION} may not be compatible with ZeroGPU. "
          f"Supported versions are: {', '.join(SUPPORTED_TORCH_VERSIONS)}")

# Model initialization: 4-bit NF4 quantization keeps the 3B model within ZeroGPU memory limits.
model_id = "meta-llama/Llama-3.2-3B-Instruct"
token = os.environ.get("HF_TOKEN")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=token,
    quantization_config=bnb_config
)

# Initialize InstructEmbeddings on CPU and open the pre-built Chroma index in ./db
embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-base",
    model_kwargs={"device": "cpu"}
)

db = Chroma(
    persist_directory="db",
    embedding_function=embeddings
)

# Set up the text-generation pipeline; the streamer echoes tokens to stdout as they are produced.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
text_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=500,
    do_sample=True,  # enable sampling so temperature/top_p below actually take effect
    temperature=0.1,
    top_p=0.95,
    repetition_penalty=1.15,
    streamer=streamer,
)

# NOTE: the original script referenced DEFAULT_SYSTEM_PROMPT and prompt_template without
# defining them; the placeholder definitions below are assumptions added to make it runnable.
DEFAULT_SYSTEM_PROMPT = (
    "You are a ROS2 expert assistant. Answer concisely using only the provided context. "
    "If the answer is not in the context, say that you don't know."
)

prompt_template = PromptTemplate(
    input_variables=["context", "question"],
    template=DEFAULT_SYSTEM_PROMPT
    + "\n\nContext:\n{context}\n\nQuestion: {question}\nAnswer:",
)

# Create the retrieval-augmented QA chain: "stuff" the top-2 retrieved chunks into the prompt.
llm = HuggingFacePipeline(pipeline=text_pipeline)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=db.as_retriever(search_kwargs={"k": 2}),
    return_source_documents=False,
    chain_type_kwargs={"prompt": prompt_template}
)

@spaces.GPU(duration=30)
def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Generation settings are fixed in the pipeline above; the extra UI inputs are accepted
    # for ChatInterface compatibility but not applied per-request.
    try:
        # Use the QA chain directly
        response = qa_chain.invoke({"query": message})
        yield response["result"]
    except Exception as e:
        yield f"An error occurred: {str(e)}"

# Create Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System Message",
            lines=3,
            visible=False
        ),
        gr.Slider(minimum=1, maximum=2048, value=500, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="ROS2 Expert Assistant",
    description="Ask questions about ROS2, navigation, and robotics. I'll provide concise answers based on the available documentation.",
)

if __name__ == "__main__":
    demo.launch()