|
import gradio as gr |
|
import spaces |
|
from huggingface_hub import InferenceClient |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
# Hub checkpoint shared by the local tokenizer/model and the inference client.
_CHECKPOINT = "Svngoku/c4ai-command-r7b-12-2024-4bit"

# Loaded once at import time; generate_response reuses these objects.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

# For more information on `huggingface_hub` Inference API support, please check
# the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
#
# NOTE(review): `client` is not referenced by generate_response, which runs the
# local `model` above; it is kept so the module attribute remains available.
client = InferenceClient(model=_CHECKPOINT)
|
|
|
def wrap_text_output(text):
    """Return *text* embedded as a user turn in a Command-R style prompt.

    The string opens a system turn and a user turn holding *text*, then ends
    with a chatbot-turn header (presumably so a model continues as the
    chatbot).

    NOTE(review): no ``<|END_OF_TURN_TOKEN|>`` markers are emitted between
    turns; confirm this matches the target model's prompt format. This helper
    is not called anywhere in the visible file.
    """
    opening = (
        "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>"
        "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>"
    )
    closing = "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
    return "".join((opening, format(text), closing))
|
|
|
|
|
|
|
def _parse_documents(documents_text):
    """Parse newline-separated ``heading: body`` entries into document dicts.

    Blank lines are skipped. Lines with no ``:`` separator are reported on
    stdout and skipped. Only the first ``:`` splits heading from body, so the
    body may itself contain colons.

    Returns:
        A list of ``{"heading": str, "body": str}`` dicts with surrounding
        whitespace stripped from both fields.
    """
    documents = []
    # `or ""` tolerates a None value from the UI instead of crashing.
    for line in (documents_text or "").splitlines():
        if not line.strip():
            continue
        heading, sep, body = line.partition(":")
        if not sep:
            print(f"Invalid document format: {line}")
            continue
        documents.append({"heading": heading.strip(), "body": body.strip()})
    return documents


@spaces.GPU
def generate_response(message, history, documents_text):
    """Generate a document-grounded chatbot reply for the Gradio ChatInterface.

    Args:
        message: The latest user message.
        history: Prior turns as role/content dicts (the interface is built
            with ``type="messages"``); ``None`` is treated as empty.
        documents_text: Raw contents of the documents textbox, one
            ``heading: body`` entry per line.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        not echoed back), with surrounding whitespace stripped.
    """
    conversation = list(history or []) + [{"role": "user", "content": message}]
    documents = _parse_documents(documents_text)

    # Render the prompt as text. The chat template already inserts the model's
    # special tokens, so they must not be added a second time when tokenizing.
    input_prompt = tokenizer.apply_chat_template(
        conversation=conversation,
        documents=documents,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(input_prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    gen_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.3,
    )

    # generate() returns prompt + continuation for this decoder-only model, so
    # slice off the prompt. Splitting the decoded text on a turn marker is
    # unreliable because skip_special_tokens=True strips such markers.
    prompt_length = input_ids.shape[-1]
    new_tokens = gen_tokens[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
""" |
|
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
|
""" |
|
# Chat UI: every user turn is sent to generate_response along with the message
# history (as role/content dicts) and the text of the extra documents box.
demo = gr.ChatInterface(
    generate_response,
    type="messages",
    title="Simple Chat with RAG",
    description="Ask a question and provide relevant documents for context",
    theme="ocean",
    additional_inputs=[
        gr.Textbox(
            placeholder="Enter documents (heading: body) separated by new lines...",
            lines=5,
        ),
    ],
)

if __name__ == "__main__":
    # Serve the app only when executed as a script, not when imported.
    demo.launch()