AmitIsraeli committed
Commit 78a8ca9 · 1 Parent(s): 1ad43f4

Add application file

app.py ADDED
@@ -0,0 +1,19 @@
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+
+ from help_function import help_function
+
+ # Load the TorchScript helper once at startup.
+ model_helper = help_function()
+
+ def greet(numpy_image, text, float_value):
+     # Gradio passes the uploaded image as a numpy array; convert it to an RGB PIL image.
+     pil_image = Image.fromarray(np.uint8(numpy_image)).convert('RGB')
+     # Apply the text-guided edit with the slider value as edit strength.
+     image_edit = model_helper.image_from_text(text, pil_image, float_value)
+     return image_edit
+
+ # gr.Slider replaces the deprecated gr.inputs.Slider API.
+ iface = gr.Interface(fn=greet, inputs=["image", "text", gr.Slider(0.0, 1.0)], outputs="image")
+ iface.launch()
help_function.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ import open_clip
+ from torchvision import transforms
+ from torchvision.transforms import ToPILImage
+
+ class help_function:
+     def __init__(self):
+         # TorchScript modules exported ahead of time; everything is loaded on CPU.
+         self.clip_text_model = torch.jit.load('jit_models/clip_text_jit.pt', map_location=torch.device('cpu'))
+         self.decoder = torch.jit.load('jit_models/decoder_16w.pt', map_location=torch.device('cpu'))
+         self.mapper_clip = torch.jit.load('jit_models/mapper_clip_jit.pt', map_location=torch.device('cpu'))
+         self.encoder = torch.jit.load('jit_models/combined_encoder.pt', map_location=torch.device('cpu'))
+         # Precomputed means: mean_clip is subtracted from the text features, mean_person is added to the inversion.
+         self.mean_clip = torch.load('jit_models/mean_clip.pt', map_location='cpu')
+         self.mean_person = torch.load('jit_models/mean_person.pt', map_location='cpu')
+         self.tokenizer = open_clip.get_tokenizer('ViT-B-32')
+         self.transform = transforms.Compose([
+             # Fixed 224x224 resize; Resize(224) alone keeps the aspect ratio and breaks the reshape below for non-square inputs.
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+         ])
+
+     def get_text_embedding(self, text):
+         # Tokenize the prompt and encode it with the CLIP text encoder.
+         return self.clip_text_model(self.tokenizer(text))
+
+     def get_image_inversion(self, image):
+         # Encode the image into a (1, 16, 512) latent and offset it by the mean person latent.
+         image = self.transform(image)
+         w_inversion = self.encoder(image.reshape(1, 3, 224, 224)).reshape(1, 16, 512)
+         return w_inversion + self.mean_person
+
+     def get_text_delta(self, text_features):
+         # Map the mean-centered text embedding to an edit direction in latent space.
+         w_delta = self.mapper_clip(text_features - self.mean_clip)
+         return w_delta
+
+     def image_from_text(self, text, image, power=1.0):
+         # Pure inference, so no gradients are needed.
+         with torch.no_grad():
+             w_inversion = self.get_image_inversion(image)
+             text_embedding = self.get_text_embedding(text)
+             w_delta = self.get_text_delta(text_embedding)
+
+             # Move the inverted latent along the text direction, scaled by the edit strength.
+             w_edit = w_inversion + w_delta * power
+             image_edit = self.decoder(w_edit)
+             # Undo Normalize(mean=0.5, std=0.5): x * 0.5 + 0.5 maps [-1, 1] back to [0, 1] before converting to PIL.
+             return ToPILImage()((image_edit[0] * 0.5 + 0.5).clamp(0, 1))
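
For a quick sanity check outside the Gradio UI, the helper can also be driven directly. A minimal sketch, assuming the jit_models/*.pt files above are present and using a placeholder image path and prompt:

    from PIL import Image
    from help_function import help_function

    helper = help_function()
    source = Image.open('face.jpg').convert('RGB')   # 'face.jpg' is a placeholder for any local portrait
    edited = helper.image_from_text('a smiling person', source, power=0.7)
    edited.save('face_edited.jpg')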
jit_models/.DS_Store ADDED
Binary file (6.15 kB)
 
jit_models/mean_clip.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ee2ab527e32a6168ecb7cd2288d2374fc4ffd57d2287c5f1cf35e4ac677ac6a
+ size 2801
jit_models/mean_person.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:06f4b6b9913609b2a1ed3c63d198ce510f02dc9a0d70bdb037c073e76768df46
+ size 33527
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ open-clip-torch==2.16.0
+ orjson==3.8.10
+ packaging==23.0
+ pandas==2.0.0
+ Pillow==9.5.0
+ torch==2.0.0
+ torchvision==0.15.1
+ tqdm==4.65.0
+
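
Note that gradio itself is not pinned above even though app.py imports it; on Hugging Face Spaces the Gradio version is supplied by the Space configuration, so a local run needs gradio installed separately alongside these requirements.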