Spaces:
Paused
Paused
File size: 5,442 Bytes
b75125a 297485e 841e4af b75125a 841e4af b75125a c86c2f3 297485e b75125a 4d5d8af b75125a 4522cd0 b75125a 841e4af 4d5d8af b75125a 841e4af 4d5d8af b75125a 4d5d8af c2b7b7c 4d5d8af b75125a 4d5d8af b75125a 4d5d8af b75125a 4d5d8af b75125a 4d5d8af b75125a 9905ae2 4d5d8af 469c0f9 4d5d8af 841e4af 4d5d8af b75125a 4d5d8af b75125a 841e4af f6ff388 b75125a 4d5d8af e6dd388 b75125a f317c15 b75125a 4d5d8af b75125a 4d5d8af f317c15 616dabc b75125a 4d5d8af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import re
import torch
from threading import Thread
from typing import Iterator
from mongoengine import connect, Document, StringField, SequenceField
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
# Constants
# Hard ceiling on new tokens (not referenced elsewhere in this file).
MAX_MAX_NEW_TOKENS = 2048
# Default for generate()'s `max_new_tokens` argument.
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are left-trimmed in generate();
# can be overridden with the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Markdown license notice; its only use in this file is a commented-out
# gr.Markdown(LICENSE) call in the Blocks layout.
LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
# Description text for the demo. It was previously only appended to (`+=`)
# and never assigned, so importing this module on a CPU-only host raised
# NameError. Initialise it first. (Nothing else in this file renders it.)
DESCRIPTION = ""

# GPU Check and add CPU warning
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
else:
    # Model and Tokenizer Configuration
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    # Load the base model in 4-bit NF4 with bfloat16 compute via bitsandbytes.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", quantization_config=bnb_config
    )
    # Wrap the base model with the "ranamhamoud/storytell" PEFT adapter.
    model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as the pad token (presumably because the base tokenizer
    # defines none -- confirm against the tokenizer config).
    tokenizer.pad_token = tokenizer.eos_token
# # MongoDB Connection
# PASSWORD = os.environ.get("MONGO_PASS")
# connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
# # MongoDB Document
# class Story(Document):
# message = StringField()
# content = StringField()
# story_id = SequenceField(primary_key=True)
# Utility function for prompts
def make_prompt(entry):
    """Wrap *entry* in the single-turn "### Human: ... ### Assistant:" template.

    The fixed instruction text (including its spelling) is what the fine-tuned
    adapter is prompted with, so it is kept verbatim.

    Args:
        entry: The user's message; it is formatted with ``format(entry, "")``.

    Returns:
        The full prompt string ending with the "### Assistant:" cue.
    """
    template = "### Human: Don't repeat the assesments, limit to 500 words {} ### Assistant:"
    return template.format(entry)
# f"TELL A STORY, RELATE TO COMPUTER SCIENCE, INCLUDE ASSESMENTS. MAKE IT REALISTIC AND AROUND 800 WORDS, END THE STORY WITH "THE END.": {entry}"
def process_text(text):
    """Clean raw model output for display.

    Two passes are applied, in order:

    1. Every literal ``[answer:]`` tag, together with any whitespace
       (including newlines) that follows it, becomes ``"Answer: "``. The
       answer therefore continues on the same line.
    2. Every remaining ``[...]`` span is deleted. The match is non-greedy and
       confined to one line, because ``.`` does not match newlines.

    Args:
        text: Model output to clean.

    Returns:
        The cleaned string.
    """
    # Pass 1: turn the model's "[answer:]" marker into a visible label.
    text = re.sub(r'\[answer:\]\s*', 'Answer: ', text)
    # Pass 2: drop any other bracketed content. Every "[answer:]" tag is
    # already gone, so no exclusion is needed. The former `(?<!Answer: )`
    # lookbehind was placed right after `\]`, where the preceding character
    # is always "]". It could therefore never reject a match and has been
    # removed.
    text = re.sub(r'\[.*?\]', '', text)
    return text
# Extra CSS handed to gr.Blocks(css=custom_css, ...) in the web layout.
# Global text (body/inputs/buttons/labels) is set to 24px. The chat-message,
# .gr-button and .gr-input rules set 14px.
# NOTE(review): confirm these .gr-* class selectors exist in the installed
# Gradio version; unmatched selectors are silently ignored by the browser.
custom_css = """
body, input, button, textarea, label {
font-family: Arial, sans-serif;
font-size: 24px;
}
.gr-chat-interface .gr-chat-message-container {
font-size: 14px;
}
.gr-button {
font-size: 14px;
padding: 12px 24px;
}
.gr-input {
font-size: 14px;
}
"""
# Gradio Function
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.7,
    top_k: int = 20,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's reply to *message* for ``gr.ChatInterface``.

    Only the current message, wrapped by ``make_prompt``, is sent to the
    model. ``chat_history`` is part of the ChatInterface callback signature
    but is not included in the prompt.

    Args:
        message: Latest user message.
        chat_history: Previous (user, assistant) turns; currently unused.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty forwarded to ``model.generate``.

    Yields:
        The cleaned reply generated so far. Each value is cumulative, not a
        delta.
    """
    # NOTE(review): chat_history used to be copied into an unused
    # `conversation` list and never reached the model. It is still not used;
    # wire it into the prompt if multi-turn context is wanted.
    prompt = make_prompt(message)

    # A single sequence needs no padding. Implicit truncation is not requested,
    # so the explicit left-trim below is the only trimming. That trim keeps the
    # end of the prompt, including the "### Assistant:" cue, and warns the user.
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(model.device)
    attention_mask = enc.attention_mask.to(model.device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # NOTE(review): skip_special_tokens=False keeps special tokens (e.g. the
    # end-of-sequence marker) in the streamed text, and process_text does not
    # remove them. Confirm this is intended.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = dict(
        input_ids=input_ids,
        # Pass the mask explicitly: pad_token == eos_token, so generate()
        # cannot infer it reliably.
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Generation runs in a background thread while this generator consumes
    # the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    raw_chunks = []
    for chunk in streamer:
        raw_chunks.append(chunk)
        # Clean the accumulated text, not the individual chunk. A tag such as
        # "[answer:]" or "[...]", or the whitespace following it, can be split
        # across streamed chunks, and per-chunk regexes would miss it.
        yield process_text("".join(raw_chunks))
    worker.join()
    # final_story = process_text("".join(raw_chunks))
    # try:
    # saved_story = Story(message=message, content=final_story).save()
    # yield f"{final_story}\n\n Story saved with ID: {saved_story.story_id}"
    # except Exception as e:
    # yield f"Failed to save story: {str(e)}"
# Gradio Interface Setup
# Example messages offered in the chat UI (one message per example row).
_EXAMPLE_PROMPTS = (
    "Can you explain briefly to me what is the Python programming language?",
    "Could you please provide an explanation about the concept of recursion?",
    "Could you explain what a URL is?",
)

# Chat UI backed by the streaming `generate` callback; the stop button is
# disabled (stop_btn=None).
chat_interface = gr.ChatInterface(
    fn=generate,
    fill_height=True,
    stop_btn=None,
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    theme="shivi/calm_seafoam",
)
# Gradio Web Interface
# Page layout: the chat interface rendered inside a Blocks app that carries
# the custom CSS and theme.
with gr.Blocks(
    css=custom_css,
    theme="shivi/calm_seafoam",
    fill_height=True,
) as demo:
    chat_interface.render()
    # gr.Markdown(LICENSE)


def main():
    """Enable request queueing (at most 20 waiting requests) and launch the app."""
    demo.queue(max_size=20)
    demo.launch(share=True)


# Main Execution
if __name__ == "__main__":
    main()