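# Gradio Space: a storytelling chat demo. A Llama-2-7b-chat base model with the
# "ranamhamoud/storytell" PEFT adapter streams stories token-by-token, and the
# finished story is illustrated with a DALL-E 3 image.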
import os
import re
import torch
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
from openai import OpenAI

# OPENAI_KEY must be set in the environment for story illustration to work.
client = OpenAI(api_key=os.environ.get("OPENAI_KEY"))
def generate_image(text):
    """Generate a DALL-E 3 illustration for the story and return its URL."""
    try:
        response = client.images.generate(
            model="dall-e-3",
            # DALL-E 3 prompts are capped at 4,000 characters, so truncate long stories.
            prompt="Create an illustration that accurately depicts the character and the setting of a story: " + text[:3500],
            n=1,
            size="1024x1024",
        )
    except Exception as error:
        print(str(error))
        raise gr.Error("An error occurred while generating the image. Please check your API key and try again.")
    return response.data[0].url
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
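# MAX_INPUT_TOKEN_LENGTH can be overridden via the environment; longer inputs
# are trimmed from the left before generation.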
LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
# GPU check: warn when running on CPU, where this demo does not work.
DESCRIPTION = ""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    # Model and tokenizer configuration: load the Llama-2 base in 4-bit NF4
    # quantization, then attach the storytelling LoRA adapter.
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
    model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token  # Llama-2 ships without a pad token
def make_prompt(entry):
    # Wrap the user's message in the "### Human / ### Assistant" format the fine-tune expects.
    return f"### Human: When asked to explain, use a story. Don't repeat the assessments; limit the story to 500 words. However, keep context in mind if edits to the content are required. {entry} ### Assistant:"
def process_text(text):
    # Surface "[answer:]" tags as "Answer:", then strip any remaining bracketed markup.
    text = re.sub(r'\[answer:\]\s*', 'Answer: ', text)
    text = re.sub(r'\[.*?\]', '', text)
    return text
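# Custom CSS: enlarge the base typography while keeping chat messages, buttons,
# and inputs at a compact 14px.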
custom_css = """
body, input, button, textarea, label {
font-family: Arial, sans-serif;
font-size: 24px;
}
.gr-chat-interface .gr-chat-message-container {
font-size: 14px;
}
.gr-button {
font-size: 14px;
padding: 12px 24px;
}
.gr-input {
font-size: 14px;
}
"""
# Gradio Function
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.7,
    top_k: int = 20,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": make_prompt(message)})
    # Note: only the current prompt is tokenized; the prior turns collected
    # above are not currently fed to the model.
    enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids.to(model.device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    # Stream tokens as they are generated; skip_special_tokens keeps </s> out of the output.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so this generator can consume the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(process_text(text))
        yield "".join(outputs)

    # Once the story is complete, illustrate it and append the image as Markdown
    # so the chat interface renders it inline.
    final_story = "".join(outputs)
    image_url = generate_image(final_story)
    yield final_story + f"\n\n![Story illustration]({image_url})"
chat_interface = gr.ChatInterface(
    fn=generate,
    fill_height=True,
    stop_btn=None,
    examples=[
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Could you please provide an explanation about the concept of recursion?"],
        ["Could you explain what a URL is?"],
    ],
    theme='shivi/calm_seafoam',
    autofocus=True,
)
# Gradio Web Interface
with gr.Blocks(css=custom_css, theme='shivi/calm_seafoam', fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    # gr.Markdown(LICENSE)
# Main Execution
if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True)