PattananKKK commited on
Commit
5a8ef11
·
1 Parent(s): d9cf0bb
Files changed (1) hide show
  1. app.py +35 -4
app.py CHANGED
@@ -1,7 +1,38 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
  import gradio as gr
4
+ import torch
5
+
6
+ processor = AutoProcessor.from_pretrained('microsoft/git-base')
7
+ model = AutoModelForCausalLM.from_pretrained('./instagram_caption_generating_model')
8
+
9
+ def predict(image):
10
+ try:
11
+ inputs = processor(images=image, return_tensors="pt")
12
+
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ inputs = {key: value.to(device) for key, value in inputs.items()}
15
+ model.to(device)
16
+
17
+ outputs = model.generate(**inputs)
18
+
19
+ caption = processor.batch_decode(outputs, skip_special_tokens=True)[0]
20
+
21
+ return caption
22
+
23
+ except Exception as e:
24
+ print("Error during prediction:", str(e))
25
+ return "Error: " + str(e)
26
+
27
+ with gr.Blocks() as demo:
28
+ image = gr.Image(type="pil")
29
+ predict_btn = gr.Button("Predict", variant="primary")
30
+ output = gr.Textbox(label="Generated Caption")
31
 
32
+ inputs = [image]
33
+ outputs = [output]
34
+
35
+ predict_btn.click(predict, inputs=inputs, outputs=outputs)
36
 
37
+ if __name__ == "__main__":
38
+ demo.launch()