Spaces:
Runtime error
Runtime error
File size: 2,103 Bytes
0867e96 311ca85 a02f343 6054c89 0867e96 6054c89 a02f343 0867e96 a02f343 0867e96 a02f343 0867e96 a02f343 0867e96 a02f343 0867e96 a02f343 0867e96 a02f343 0867e96 a02f343 0867e96 a02f343 0867e96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import base64
import os
import subprocess
import sys
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
# Install flash-attn at start-up, skipping the CUDA build step.
#
# The override is merged into the current environment rather than replacing
# it: passing a bare dict as ``env`` would drop PATH, HOME, virtualenv, proxy
# and pip settings for the child process.  Running pip as a module of the
# current interpreter guarantees the package lands in the environment this
# application actually imports from, and an argv list avoids the shell.
_flash_attn_install = subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
if _flash_attn_install.returncode != 0:
    # Best-effort install: keep starting up, but make the failure visible.
    print(
        f"flash-attn installation failed with exit code {_flash_attn_install.returncode}",
        file=sys.stderr,
    )
# Application object that the route decorators in this file attach to.
app = FastAPI()

# Registry of loaded models, keyed by Hugging Face repository id.
# Every entry is loaded once at import time, moved to the GPU with ``.cuda()``
# and switched to evaluation mode with ``.eval()``.
models = {
    repo_id: AutoModelForCausalLM.from_pretrained(
        repo_id,
        trust_remote_code=True,
        torch_dtype="auto",
        attn_implementation="flash_attention_2",
    )
    .cuda()
    .eval()
    for repo_id in ("microsoft/Phi-3.5-vision-instruct",)
}
# Registry of input processors, keyed by the same repository ids that are
# used to look up models.  Each processor is loaded once at import time.
processors = {
    repo_id: AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    for repo_id in ("microsoft/Phi-3.5-vision-instruct",)
}
class InputData(BaseModel):
    """Request body accepted by the ``/run_example`` endpoint."""

    # Base64-encoded bytes of the image the prompt refers to.
    image: str

    # User text that is placed after the image placeholder in the prompt.
    text_input: str

    # Repository id used to look up the model and processor to run.
    model_id: str = "microsoft/Phi-3.5-vision-instruct"
@app.post("/run_example")
async def run_example(input_data: InputData):
    """Generate a text answer to a prompt about one base64-encoded image.

    Args:
        input_data: Request body holding the base64 image, the user's text
            prompt and the id of the model to run.

    Returns:
        A dict of the form ``{"response": <generated text>}``.

    Raises:
        HTTPException: With status 400 when ``model_id`` does not name a
            loaded model or the image payload cannot be decoded, and with
            status 500 when prompt processing or generation fails.
    """
    # Client-side mistakes are rejected up front, outside the generic
    # handler further down, so they are not reported as server errors.
    model = models.get(input_data.model_id)
    processor = processors.get(input_data.model_id)
    if model is None or processor is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown model_id: {input_data.model_id!r}",
        )

    # Decode base64 image
    try:
        image_data = base64.b64decode(input_data.image)
        image = Image.open(BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError) as e:
        # binascii.Error is a ValueError; unreadable images raise OSError.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    # NOTE(review): nothing below is awaited, so the event loop is blocked
    # for the whole duration of generation.
    try:
        user_prompt = '<|user|>\n'
        assistant_prompt = '<|assistant|>\n'
        prompt_suffix = "<|end|>\n"
        prompt = f"{user_prompt}<|image_1|>\n{input_data.text_input}{prompt_suffix}{assistant_prompt}"

        # Put the inputs on whichever device the model lives on instead of
        # assuming a fixed "cuda:0".
        inputs = processor(prompt, image, return_tensors="pt").to(model.device)
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
        # Drop the prompt tokens so that only newly generated text is decoded.
        generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
        response = processor.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
    except Exception as e:
        # Top-level boundary: surface any inference failure as a 500.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"response": response}