Spaces:
Running
Running
File size: 1,873 Bytes
8273d5f 2a438ba 8273d5f 9ce3948 580cc25 695df9a 35c5f77 580cc25 9ce3948 a41d9f9 580cc25 9ce3948 580cc25 2a438ba 9ce3948 8273d5f 9ce3948 8273d5f a41d9f9 8273d5f 580cc25 8273d5f 2a438ba d652f80 8273d5f a41d9f9 8273d5f d652f80 8273d5f 35c5f77 d652f80 8273d5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import os
import subprocess

import torch
import gradio as gr
from transformers import AutoModel, pipeline, AutoTokenizer
import spaces

# from issue: https://discuss.huggingface.co/t/how-to-install-flash-attention-on-hf-gradio-space/70698/2
# InternVL2 needs flash_attn; install it at startup since Spaces images don't ship it.
# NOTE: `env=` REPLACES the whole environment of the child process. The original code
# passed only the skip-build flag, which wipes PATH/HOME and breaks the pip call —
# merge the flag into the current environment instead.
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
try:
    # InternVL2 ships custom modeling code, so trust_remote_code is required.
    # model: <class 'transformers_modules.OpenGVLab.InternVL2-8B.0e6d592d957d9739b6df0f4b90be4cb0826756b9.modeling_internvl_chat.InternVLChatModel'>
    model_name = "OpenGVLab/InternVL2-8B"
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        # low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    # Move to GPU and freeze into inference mode.
    model = model.cuda().eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # pipeline: <class 'transformers.pipelines.visual_question_answering.VisualQuestionAnsweringPipeline'>
    inference = pipeline(
        task="visual-question-answering", model=model, tokenizer=tokenizer
    )
except Exception as error:
    # Surface any load failure (download, OOM, remote-code issue) in the Gradio UI.
    raise gr.Error(f"👌{error}", duration=30)
@spaces.GPU
def predict(input_img, questions):
    """Answer a free-text question about an image via the VQA pipeline.

    Any failure (including the debug Info toasts) is re-raised as a
    gr.Error so it shows up in the Gradio UI instead of the server log.
    """
    try:
        # Debug toasts: confirm which pipeline/model classes are live.
        gr.Info(f"pipeline: {type(inference)}")
        gr.Info(f"model: {type(model)}")
        answer = inference(question=questions, image=input_img)
    except Exception as exc:
        raise gr.Error(f"❌{exc}", duration=25)
    return str(answer)
# UI: image (upload or webcam, delivered as PIL) + question text in, answer text out.
gradio_app = gr.Interface(
    predict,
    inputs=[
        # Fixed label grammar: "Select A Image" -> "Select an Image".
        gr.Image(label="Select an Image", sources=["upload", "webcam"], type="pil"),
        "text",
    ],
    outputs="text",
    title="ask me anything",
)

if __name__ == "__main__":
    # show_error surfaces tracebacks in the browser; debug keeps the process in the foreground.
    gradio_app.launch(show_error=True, debug=True)
|