# webpluging/app.py
import os
import re
import torch
from threading import Thread
from typing import Iterator
from mongoengine import connect, Document, StringField, SequenceField
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 930
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
DESCRIPTION = ""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
modelA_id = "meta-llama/Llama-2-7b-chat-hf"
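    # 4-bit NF4 quantization (bitsandbytes) with bfloat16 compute keeps the 7B base model within a single GPU's memory.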
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
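    # Model A: the quantized Llama-2 chat base with the "ranamhamoud/storytell" PEFT adapter applied on top.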
base_model = AutoModelForCausalLM.from_pretrained(modelA_id, device_map="auto", quantization_config=bnb_config)
modelA = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
tokenizerA = AutoTokenizer.from_pretrained(modelA_id)
tokenizerA.pad_token = tokenizerA.eos_token
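    # Model B: the unmodified Llama-2 chat model in float16, selectable via option "B" in the UI dropdown.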
modelB_id = "meta-llama/Llama-2-7b-chat-hf"
modelB = AutoModelForCausalLM.from_pretrained(modelB_id, torch_dtype=torch.float16, device_map="auto")
tokenizerB = AutoTokenizer.from_pretrained(modelB_id)
tokenizerB.use_default_system_prompt = False
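# Prompt wrapper for model A; the storytell adapter is assumed to expect this "### Human / ### Assistant"
# format rather than the Llama-2 chat template.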
def make_prompt(entry):
return f"### Human: Don't repeat the assesments, limit to 500 words {entry} ### Assistant:"
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    model: str = "A",
    system_prompt: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Build the conversation in the chat-message format expected by apply_chat_template.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})
    if model == "A":
        # Model A: wrap only the latest message in the storytell prompt format (chat history is not used).
        model = modelA
        tokenizer = tokenizerA
        enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
        input_ids = enc.input_ids
    else:
        # Model B: apply the Llama-2 chat template to the full conversation.
        model = modelB
        tokenizer = tokenizerB
        input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
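    # Stream tokens as they are generated; skip_prompt/skip_special_tokens keep the echoed prompt and EOS out of the output.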
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
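    # Run model.generate on a worker thread so this generator can yield partial output from the streamer.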
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
# Gradio Interface Setup
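# gr.ChatInterface calls generate(message, history, *additional_inputs); the dropdown below supplies the "model" argument.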
chat_interface = gr.ChatInterface(
fn=generate,
    additional_inputs=[gr.Dropdown(["A", "B"], value="A", label="Model")],
fill_height=True,
stop_btn=None,
examples=[
["Can you explain briefly to me what is the Python programming language?"],
["Could you please provide an explanation about the concept of recursion?"],
["Could you explain what a URL is?"]
],
theme='shivi/calm_seafoam'
)
# Gradio Web Interface
with gr.Blocks(theme='shivi/calm_seafoam', fill_height=True) as demo:
# gr.Markdown(DESCRIPTION)
chat_interface.render()
gr.Markdown(LICENSE)
# Main Execution
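# queue() limits how many requests can wait at once; share=True exposes a public Gradio link.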
if __name__ == "__main__":
demo.queue(max_size=20)
demo.launch(share=True)