Update app.py
app.py CHANGED

--- app.py (removed version)

import torch
from peft import PeftModel
import transformers
import gradio as gr
from fastapi import FastAPI
import random


app= FastAPI()

assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
# … (several removed lines are not legible in the diff view)
device = "cpu"

# … (removed lines not legible in the diff view)
    device = "mps"
except:
    pass

# … (removed lines not legible in the diff view)
    model = PeftModel.from_pretrained(
        model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
    )
elif device == "mps":
    model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map={"": device},
        torch_dtype=torch.float16,
    )
    model = PeftModel.from_pretrained(
        model,
        LORA_WEIGHTS,
        device_map={"": device},
        torch_dtype=torch.float16,
    )
else:
    model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
    )
# … (removed lines not legible in the diff view)
)

def generate_prompt(input:str):
    instruction= '''You are a dating bio writer for single boy with the keywords provided. the dating bio should be within 30 words and should be catchy. the dating bio should be different in every run.'''
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

# … (removed line not legible in the diff view)

# … (removed lines not legible in the diff view)
model.eval()
if torch.__version__ >= "2":
    model = torch.compile(model)


async def evaluate(
    input:str,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=128,
    seed=None,
    do_sample=True,
    # **kwargs,
):
    prompt = generate_prompt(input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    temperature= [0.2, 0.5, 0.7, 0.9, 1.0]
    generation_config = GenerationConfig(
        temperature=random.choice(temperature),
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            seed=None,
            do_sample= do_sample
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    return output.split("### Response:")[1].strip()

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
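Two details of the removed `evaluate` stand out: the signature comments out `**kwargs` yet still expands `**kwargs` inside `GenerationConfig(...)`, which raises a `NameError` on every call, and `seed` is not a standard generation parameter (seeding would normally go through `torch.manual_seed`). A minimal corrected sketch of that block, reusing the removed code's own parameters and objects:

    generation_config = GenerationConfig(
        temperature=random.choice([0.2, 0.5, 0.7, 0.9, 1.0]),  # same temperature pool as the removed code
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        do_sample=do_sample,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )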

+++ app.py (new version)

import random
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel

from peft import PeftModel
from transformers import LLaMATokenizer, LLaMAForCausalLM, GenerationConfig

app = FastAPI()

tokenizer = LLaMATokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = LLaMAForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, "tloen/alpaca-lora-7b")
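Note that `LLaMATokenizer` and `LLaMAForCausalLM` are the class names from the pre-merge LLaMA branch of transformers; released versions spell them `LlamaTokenizer` and `LlamaForCausalLM`, so these imports are a likely source of an import-time failure unless transformers is pinned to that branch. `load_in_8bit=True` additionally needs `bitsandbytes` and `accelerate` installed. A sketch of an import that tolerates either spelling (assumption: exactly one of the two namings exists in the installed transformers):

    try:
        from transformers import LLaMATokenizer, LLaMAForCausalLM  # pre-merge naming
    except ImportError:
        from transformers import LlamaTokenizer as LLaMATokenizer      # released naming
        from transformers import LlamaForCausalLM as LLaMAForCausalLM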
class InputPrompt(BaseModel):
    instruction: str
    input: Optional[str] = None


class OutputResponse(BaseModel):
    response: str
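These two pydantic models define the request and response schema for the endpoint below: a JSON body with an `instruction` and an optional `input`, answered with a single `response` string. A hypothetical client call (host, port, and payload values are assumptions for illustration, not part of the commit):

    import requests

    resp = requests.post(
        "http://localhost:7860/evaluate",  # assumed host/port for a local run
        json={"instruction": "Write a catchy dating bio", "input": "hiking, espresso, golden retrievers"},
    )
    print(resp.json()["response"])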

@app.post("/evaluate")
def evaluate(input_prompt: InputPrompt):
    temperature = random.uniform(0.1, 1.0)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=0.75,
        num_beams=4,
    )
    prompt = generate_prompt(input_prompt.instruction, input_prompt.input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].cuda()
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=256
    )
    for s in generation_output.sequences:
        output = tokenizer.decode(s)
    return OutputResponse(response=output.split("### Response:")[1].strip())
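The handler moves the input ids to CUDA unconditionally and decodes every returned sequence in a loop while keeping only the last one; on CPU-only hardware the `.cuda()` call fails before generation even starts. A defensive sketch of that step (assumptions: `torch` is importable here, and taking `sequences[0]` is equivalent since only one sequence is returned with the default `num_return_sequences`):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_ids = inputs["input_ids"].to(device)
    with torch.no_grad():  # gradients are not needed for generation
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )
    output = tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)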

def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""
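Unlike the removed version, the new file never starts a server itself: the `if __name__ == "__main__":` block with `uvicorn.run` was dropped along with the Gradio import. If the Space's runner does not launch uvicorn on its own, re-adding the guard from the removed code is one way to serve the app:

    if __name__ == "__main__":
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)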