Spaces:
Sleeping
Sleeping
File size: 4,088 Bytes
8b54513 3c55139 22ff2b2 5bbfa70 3c55139 8b54513 3c55139 288480f bb89818 3c55139 8b54513 3c55139 8b54513 3c55139 22ff2b2 3c55139 8b54513 3c55139 8b54513 3c55139 8b54513 3c55139 8b54513 3c55139 8b54513 3c55139 8b54513 3c55139 5813aac 3c55139 8b54513 3c55139 a66a8c1 3c55139 fc68709 3c55139 8b54513 3c55139 8b54513 d25ae12 3c55139 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import os
import gradio as gr
from torchvision.transforms.functional import to_tensor
from huggingface_hub import hf_hub_download, snapshot_download, login
from t2i_inference import T2IConfig, TextToImageInference
def generate_text(self, image, prompt: str) -> str:
    """Answer a text *prompt* about *image* with the shared Tar model.

    Fix: the original annotated ``image: str``, but callers pass a PIL
    image (the Gradio input is ``type="pil"`` and the first statement is
    ``image.convert('RGB')``), so the misleading annotation is removed.

    Args:
        self: a TextToImageInference instance (the function is defined at
            module level and called as ``generate_text(inference, ...)``).
        image: PIL image to be questioned/described.
        prompt: user instruction, e.g. "Describe the image shortly."

    Returns:
        Decoded model response with special tokens stripped.
    """
    # PIL -> [1, C, H, W] float tensor on the model device.
    image = image.convert('RGB')
    image = to_tensor(image).unsqueeze(0).to(self.device)
    # Discretize the image into visual codes via the TA-Tok encoder.
    image_code = self.visual_tokenizer.encoder(image.to(self.config.dtype))['bottleneck_rep']
    # Render each code as a special text token "<I{code}>" so the LLM can read it.
    image_text = "".join(f"<I{x}>" for x in image_code[0].cpu().tolist())
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"{image_text}\n{prompt}"},
    ]
    input_text = self.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)
    inputs = self.tokenizer(input_text, return_tensors="pt")
    gen_ids = self.model.generate(
        inputs.input_ids.to(self.device),
        max_new_tokens=512,
        do_sample=True)
    # Slice off the prompt tokens; decode only the newly generated tail.
    return self.tokenizer.batch_decode(
        gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
# ---- One-time setup at import time: auth + weight downloads (network I/O) ----
# Requires HF_TOKEN in the environment; os.getenv returns None if unset,
# in which case only public repos are reachable.
login(token=os.getenv('HF_TOKEN'))

config = T2IConfig()
# Full Tar-7B checkpoint directory (whole repo snapshot).
config.model = snapshot_download("csuhan/Tar-7B-v0.1")
# Autoregressive de-tokenizer weights — presumably the 1024px variant; TODO confirm.
config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth")
# Text-aligned visual tokenizer (encoder) weights.
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
# VQ image decoder checkpoint from the LlamaGen T2I release.
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
inference = TextToImageInference(config)
def generate_image(prompt, top_p, top_k, cfg_scale):
    """Generate one image for *prompt* using the shared inference engine.

    The UI sliders are pushed into the module-level ``config`` that
    ``inference`` reads, then generation is delegated to it.
    """
    config.top_p, config.top_k, config.cfg_scale = top_p, top_k, cfg_scale
    return inference.generate_image(prompt)
def clear_inputs_t2i():
    """Reset the text-to-image tab: empty prompt, no output image."""
    cleared_prompt, cleared_image = "", None
    return cleared_prompt, cleared_image
def understand_image(image, prompt):
    """Answer *prompt* about *image* by routing through the shared engine."""
    response = generate_text(inference, image, prompt)
    return response
def clear_inputs_i2t():
    """Reset the image-understanding tab.

    Fix: the Clear button wires this handler to THREE outputs
    (image_input, question_input, answer_output), but the original
    returned only two values, which makes Gradio raise at click time.
    Return one value per wired output: no image, empty instruction,
    empty response.
    """
    return None, "", ""
# ---- Gradio UI: two tabs sharing the one loaded model ----
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header banner. NOTE(review): the paper link is a placeholder
    # (xxxx.xxxxx) and the emoji appear mojibake-garbled — left byte-identical
    # here since string literals are runtime content.
    gr.Markdown(
        """
        <div align="center">
        ### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations
        [π Paper](https://arxiv.org/abs/xxxx.xxxxx) β’ [π» Code](https://github.com/csuhan/Tar) β’ [π¦ Model](https://huggingface.co/csuhan/TA-Tok)
        </div>
        """,
        elem_id="title",
    )
    # Tab 1: text -> image generation with sampling controls.
    with gr.Tab("Image Generation"):
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
                # Sampling knobs forwarded into the shared T2IConfig.
                with gr.Accordion("Advanced Settings", open=False):
                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
                    top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
                    cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
                with gr.Row():
                    generate_btn = gr.Button("Generate")
                    clear_btn = gr.Button("Clear")
            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")
        generate_btn.click(
            generate_image,
            inputs=[prompt, top_p, top_k, cfg_scale],
            outputs=output_image
        )
        clear_btn.click(
            clear_inputs_t2i,
            outputs=[prompt, output_image]
        )
    # Tab 2: image -> text (visual question answering / captioning).
    with gr.Tab("Image Understanding"):
        with gr.Row():
            with gr.Column(scale=1):
                # type="pil" means handlers receive a PIL.Image, not a path.
                image_input = gr.Image(label="Upload Image", type="pil")
                question_input = gr.Textbox(label="Instruction", value="Describe the image shortly.")
                with gr.Row():
                    qa_btn = gr.Button("Generate")
                    clear_btn_i2t = gr.Button("Clear")
            with gr.Column(scale=1):
                answer_output = gr.Textbox(label="Response", lines=4)
        qa_btn.click(
            understand_image,
            inputs=[image_input, question_input],
            outputs=answer_output
        )
        # NOTE(review): three outputs wired here, but clear_inputs_i2t as
        # written returns only two values — mismatch will error on click.
        clear_btn_i2t.click(
            clear_inputs_i2t,
            outputs=[image_input, question_input, answer_output]
        )

# share=True opens a public tunnel URL in addition to the local server.
demo.launch(share=True)
|