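# Gradio Space: chat demo comparing two Llama-2-7b variants.
# Model "A" is meta-llama/Llama-2-7b-chat-hf with the ranamhamoud/storytell PEFT
# adapter, loaded in 4-bit NF4; model "B" is the unmodified chat model in fp16.
# Responses are streamed token by token into a gr.ChatInterface.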
import os
import re
import torch
from threading import Thread
from typing import Iterator
from mongoengine import connect, Document, StringField, SequenceField
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 930
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
DESCRIPTION = ""  # defined here so the CPU warning below can be appended
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    # Model A: Llama-2-7b-chat with the "storytell" PEFT adapter, loaded in 4-bit NF4.
    modelA_id = "meta-llama/Llama-2-7b-chat-hf"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(modelA_id, device_map="auto", quantization_config=bnb_config)
    modelA = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
    tokenizerA = AutoTokenizer.from_pretrained(modelA_id)
    tokenizerA.pad_token = tokenizerA.eos_token  # Llama tokenizers ship without a pad token

    # Model B: the unmodified chat model in fp16.
    modelB_id = "meta-llama/Llama-2-7b-chat-hf"
    modelB = AutoModelForCausalLM.from_pretrained(modelB_id, torch_dtype=torch.float16, device_map="auto")
    tokenizerB = AutoTokenizer.from_pretrained(modelB_id)
    tokenizerB.use_default_system_prompt = False
def make_prompt(entry):
    # Prompt template used for the fine-tuned model A.
    return f"### Human: Don't repeat the assessments, limit to 500 words {entry} ### Assistant:"
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    model: str = "A",
    system_prompt: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # gr.ChatInterface calls fn(message, history, *additional_inputs), so the
    # value of the model dropdown arrives as the third argument.
if model == "A":
model = modelA
tokenizer = tokenizerA
enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
input_ids = enc.input_ids.to(model.device)
else:
model = modelB
tokenizer = tokenizerB
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
for user, assistant in chat_history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Generation runs in a background thread; the streamer yields partial text as it arrives.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Gradio Interface Setup
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Dropdown(choices=["A", "B"], value="A", label="Model",
                    info="A: storytelling fine-tune (4-bit), B: base Llama-2-7b-chat."),
    ],
    fill_height=True,
    stop_btn=None,
    examples=[
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Could you please provide an explanation about the concept of recursion?"],
        ["Could you explain what a URL is?"],
    ],
    theme="shivi/calm_seafoam",
)
# Gradio Web Interface
with gr.Blocks(theme="shivi/calm_seafoam", fill_height=True) as demo:
    # gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE)
# Main Execution
if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True)
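# Quick sanity check outside Gradio (a sketch; assumes a CUDA GPU and access to
# the gated meta-llama weights):
#   for partial in generate("Tell me a short story about a lighthouse.", [], model="A"):
#       print(partial)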