Spaces:
Paused
Paused
File size: 6,056 Bytes
b75125a 297485e 841e4af b75125a 841e4af b75125a e3f86b5 2999a82 e3f86b5 c86c2f3 297485e b75125a 4d5d8af b75125a 4522cd0 b75125a 841e4af 4d5d8af b75125a 841e4af 4d5d8af 89044b5 b75125a 4d5d8af c2b7b7c 4d5d8af b75125a 4d5d8af b75125a 6caf62f 4d5d8af 2a727f4 4d5d8af b75125a 4d5d8af b75125a 4d5d8af b75125a 9905ae2 4d5d8af 469c0f9 4d5d8af 841e4af 4d5d8af b75125a 4d5d8af b75125a 841e4af f6ff388 b75125a 4d5d8af e3f86b5 4d5d8af e3f86b5 e6dd388 2a727f4 b75125a 2a727f4 b75125a f317c15 b75125a 4d5d8af 2a727f4 4d5d8af f317c15 2a727f4 b75125a 4d5d8af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import os
import re
import torch
from threading import Thread
from typing import Iterator
from mongoengine import connect, Document, StringField, SequenceField
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
from openai import OpenAI
openai_key = os.environ.get("OPENAI_KEY")

# The Images API rejects dall-e-3 prompts longer than 4000 characters.
DALLE3_MAX_PROMPT_CHARS = 4000


def generate_image(text, max_prompt_chars=DALLE3_MAX_PROMPT_CHARS):
    """Ask DALL-E 3 for one 1024x1024 illustration of the story ``text``.

    Args:
        text: Story text to illustrate; it is appended to a fixed instruction.
        max_prompt_chars: Maximum length of the full prompt. The prompt
            (instruction + story) is cut to this many characters before it is
            sent, so a long story cannot make the request fail.

    Returns:
        The URL of the generated image, as returned by the API.

    Raises:
        gr.Error: If creating the client or the generation request fails for
            any reason; the underlying error is printed first.
    """
    prompt = (
        "Create an illustration that accurately depicts the character and "
        "the setting of this story: " + text
    )[:max_prompt_chars]
    try:
        client = OpenAI(api_key=openai_key)
        response = client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            n=1,
            size="1024x1024",
        )
    except Exception as error:
        print(str(error))
        raise gr.Error(
            "An error occurred while generating the image. "
            "Please check your API key and try again."
        ) from error
    return response.data[0].url
# Generation limits. The input-token cap may be overridden through the
# MAX_INPUT_TOKEN_LENGTH environment variable (read once, at import time).
MAX_MAX_NEW_TOKENS = 2_048
DEFAULT_MAX_NEW_TOKENS = 1_024
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown license notice for this Llama-2 derivative demo.
LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
# GPU check: the model is only loaded when CUDA is available; on CPU a warning
# is appended to DESCRIPTION instead. DESCRIPTION is initialised here so the
# CPU branch has a string to append to (it is not defined anywhere else).
# NOTE(review): on CPU `model` and `tokenizer` stay undefined, so generate()
# cannot run there.
DESCRIPTION = ""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
else:
    # Model and tokenizer configuration: Llama-2-7b-chat loaded in 4-bit NF4
    # with bfloat16 compute, then the "ranamhamoud/storytell" PEFT adapter.
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
    model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as the pad token (presumably because the Llama-2 tokenizer
    # does not define one).
    tokenizer.pad_token = tokenizer.eos_token
# # MongoDB Connection
# PASSWORD = os.environ.get("MONGO_PASS")
# connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")
# # MongoDB Document
# class Story(Document):
# message = StringField()
# content = StringField()
# story_id = SequenceField(primary_key=True)
# Utility function for prompts
def make_prompt(entry):
    """Wrap a user message in the "### Human: ... ### Assistant:" template.

    A fixed instruction (answer with a story, do not repeat the assessments,
    stay within 500 words) is placed before ``entry``. The result ends with the
    "### Assistant:" cue that the model continues from.
    """
    instruction = (
        "When asked to explain use a story."
        "Don't repeat the assesments, limit to 500 words."
        "However keep context in mind if edits to the content is required."
    )
    return f"### Human: {instruction} {entry} ### Assistant:"
# f"TELL A STORY, RELATE TO COMPUTER SCIENCE, INCLUDE ASSESMENTS. MAKE IT REALISTIC AND AROUND 800 WORDS, END THE STORY WITH "THE END.": {entry}"
# Compiled once at import time: process_text() runs on every streamed update.
_ANSWER_TAG_RE = re.compile(r'\[answer:\]\s*')
_BRACKET_TAG_RE = re.compile(r'\[.*?\]')


def process_text(text):
    """Clean bracketed tags out of model output.

    "[answer:]" (and any whitespace right after it) is replaced by the label
    "Answer: ". Every other "[...]" span that does not cross a line break is
    then removed.

    Args:
        text: Raw text produced by the model.

    Returns:
        The cleaned text.
    """
    text = _ANSWER_TAG_RE.sub('Answer: ', text)
    # No lookbehind guard is needed to protect the inserted "Answer: " label:
    # it contains no brackets, so the tag-removal pattern can never touch it.
    return _BRACKET_TAG_RE.sub('', text)
# Font and spacing overrides for the Gradio UI, built one CSS line at a time.
# The empty first and last entries reproduce the leading and trailing newline.
custom_css = "\n".join([
    "",
    "body, input, button, textarea, label {",
    "font-family: Arial, sans-serif;",
    "font-size: 24px;",
    "}",
    ".gr-chat-interface .gr-chat-message-container {",
    "font-size: 14px;",
    "}",
    ".gr-button {",
    "font-size: 14px;",
    "padding: 12px 24px;",
    "}",
    ".gr-input {",
    "font-size: 14px;",
    "}",
    "",
])
# Gradio Function
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.7,
    top_k: int = 20,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream a story for ``message``, then append a DALL-E illustration.

    The message is wrapped by ``make_prompt`` and sampled from ``model`` on a
    background thread. Decoded text arrives through a ``TextIteratorStreamer``.

    Args:
        message: The user's latest chat message.
        chat_history: Earlier turns supplied by Gradio. NOTE(review): these
            are not sent to the model; only ``message`` is prompted, so
            previous turns do not influence the reply. The history may also
            hold non-text entries (e.g. an image tuple), so it would need
            filtering before it could be used.
        max_new_tokens: Maximum number of tokens to generate.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded unchanged to ``model.generate``.

    Yields:
        The cleaned story accumulated so far, after each streamed chunk. If
        the story is non-empty and the illustration succeeds, one final value
        follows with the image appended as Markdown.
    """
    enc = tokenizer(make_prompt(message), return_tensors="pt")
    input_ids = enc.input_ids
    attention_mask = enc.attention_mask
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Trim from the left so the trailing "### Assistant:" cue survives.
        # Tokenizer-side truncation cuts from the right by default and would
        # drop that cue.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # skip_special_tokens=True keeps markers such as EOS out of the chat text
    # and out of the prompt later sent to the image model.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids.to(model.device),
        attention_mask=attention_mask.to(model.device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        # The tokenizer's pad token is EOS; pass it explicitly to generate().
        pad_token_id=tokenizer.eos_token_id,
    )
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    raw_chunks = []
    story = ""
    for chunk in streamer:
        raw_chunks.append(chunk)
        # Clean the whole accumulated text, not each chunk on its own: a
        # "[...]" tag split across two chunks would otherwise slip through.
        story = process_text("".join(raw_chunks))
        yield story
    worker.join()

    if not story.strip():
        return  # Nothing to illustrate.
    try:
        image_url = generate_image(story)
    except gr.Error:
        # generate_image prints the underlying error; keep the story visible
        # instead of failing the whole reply.
        gr.Warning("The story was generated, but its illustration could not be created.")
        return
    # NOTE(review): OpenAI describes these image URLs as short-lived; persist
    # the image if it must stay visible in the chat.
    yield f"{story}\n\n![Story illustration]({image_url})"
# try:
# saved_story = Story(message=message, content=final_story).save()
# yield f"{final_story}\n\n Story saved with ID: {saved_story.story_id}"
# except Exception as e:
# yield f"Failed to save story: {str(e)}"
# Shared theme for the chat widget and the page that hosts it.
_THEME = 'shivi/calm_seafoam'
# Image pre-loaded as a bot-side chat message (the user side is None).
_WELCOME_IMAGE_URL = "https://img.freepik.com/free-photo/painting-mountain-lake-with-mountain-background_188544-9126.jpg"
# Clickable example prompts shown under the chat box.
_EXAMPLE_PROMPTS = [
    ["Can you explain briefly to me what is the Python programming language?"],
    ["Could you please provide an explanation about the concept of recursion?"],
    ["Could you explain what a URL is?"],
]

chatbot = gr.Chatbot([(None, (_WELCOME_IMAGE_URL,))])

# Chat UI backed by the generate() generator; stop_btn=None means no stop button.
chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=chatbot,
    examples=_EXAMPLE_PROMPTS,
    stop_btn=None,
    autofocus=True,
    fill_height=True,
    theme=_THEME,
)

# Page layout: custom CSS plus the chat interface, filling the window height.
with gr.Blocks(css=custom_css, theme=_THEME, fill_height=True) as demo:
    chat_interface.render()
# gr.Markdown(LICENSE)
# Main Execution
if __name__ == "__main__":
    # Queue requests (at most 20 waiting), then start the server with
    # share=True (a public share link is requested).
    demo.queue(max_size=20).launch(share=True)