import os
import re
import logging
import torch
from threading import Thread
from typing import Iterator
from mongoengine import connect, Document, StringField, SequenceField
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 930
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    # Model A: Llama-2-7b-chat loaded in 4-bit (NF4) with the storytelling PEFT adapter.
    modelA_id = "meta-llama/Llama-2-7b-chat-hf"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(modelA_id, device_map="auto", quantization_config=bnb_config)
    modelA = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
    tokenizerA = AutoTokenizer.from_pretrained(modelA_id)
    tokenizerA.pad_token = tokenizerA.eos_token

    # Model B: the unmodified Llama-2-7b-chat model in float16.
    modelB_id = "meta-llama/Llama-2-7b-chat-hf"
    modelB = AutoModelForCausalLM.from_pretrained(modelB_id, torch_dtype=torch.float16, device_map="auto")
    tokenizerB = AutoTokenizer.from_pretrained(modelB_id)
    tokenizerB.use_default_system_prompt = False
    tokenizerB.pad_token = tokenizerB.eos_token
def make_prompt(entry):
    # Wrap the user message in a ### Human / ### Assistant prompt.
    return f"### Human: Don't repeat the assessments, limit to 500 words {entry} ### Assistant:"
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    model: str = "A",
    system_prompt: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # gr.ChatInterface calls this as fn(message, history, *additional_inputs),
    # so the model selection from the dropdown arrives after the chat history.
    if chat_history is None:
        logging.error("chat_history is None, initializing to empty list.")
        chat_history = []  # Initialize to an empty list if None is passed

    # Rebuild the conversation for reference; note that only the latest message
    # (wrapped by make_prompt) is actually tokenized below.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # Route the request to the selected model/tokenizer pair.
    if model == "A":
        model = modelA
        tokenizer = tokenizerA
    else:
        model = modelB
        tokenizer = tokenizerB

    enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids.to(model.device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # Stream tokens back to the UI as they are generated.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
logging.basicConfig(level=logging.DEBUG)
# Gradio Interface Setup
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Dropdown(["A", "B"], value="A", label="Model", info="A: storytelling fine-tune, B: base Llama-2 chat"),
    ],
    fill_height=True,
    stop_btn=None,
    examples=[
        ["Can you briefly explain what the Python programming language is?"],
        ["Could you please provide an explanation about the concept of recursion?"],
        ["Could you explain what a URL is?"],
    ],
    theme='shivi/calm_seafoam',
)
# Gradio Web Interface
with gr.Blocks(theme='shivi/calm_seafoam', fill_height=True) as demo:
    # gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE)
# Main Execution
if __name__ == "__main__":
demo.queue(max_size=20)
demo.launch(share=True)