pkiage commited on
Commit
732b57d
0 Parent(s):

First commit

Browse files
.gitignore ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Virtual environment
2
+ venv
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Cache and similar
12
+ catboost_info
13
+ learning_rate_0.05
14
+ learning_rate_0.5
15
+ default
16
+ catboost_model.bin
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ env/
21
+ build/
22
+ develop-eggs/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+
63
+ # Sphinx documentation
64
+ docs/_build/
65
+
66
+ # PyBuilder
67
+ target/
68
+
69
+ # DotEnv configuration
70
+ .env
71
+
72
+ # Database
73
+ *.db
74
+ *.rdb
75
+
76
+ # Pycharm
77
+ .idea
78
+
79
+ # VS Code
80
+ .vscode/
81
+
82
+ # Spyder
83
+ .spyproject/
84
+
85
+ # Jupyter NB Checkpoints
86
+ .ipynb_checkpoints/
87
+
88
+ # exclude data from source control by default
89
+ /data/
90
+
91
+ # Mac OS-specific storage files
92
+ .DS_Store
93
+
94
+ # vim
95
+ *.swp
96
+ *.swo
97
+
98
+ # Mypy cache
99
+ .mypy_cache/
100
+
.slugignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ docs
2
+ references
3
+ README.md
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
+ web: sh setup.sh && streamlit run app.py
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Setup and Installation
2
+ ## Install Requirements
3
+ ```shell
4
+ pip install -r requirements.txt
5
+ ```
6
+ ### Local Package Install
7
+ ```shell
8
+ pip install -e .
9
+ ```
10
+ ### Run app locally
11
+ ```shell
12
+ streamlit run app.py
13
+ ```
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IMPORT LIBRARIES & FUNCTIONS
2
+ # External Libraries
3
+ import streamlit as st
4
+
5
+ # Project Functions
6
+ # - Model
7
+ from src.model.model import load_model, stylize_content_image
8
+ from src.model.utils import use_low_resource_settings, suppress_warnings
9
+
10
+ # - Data
11
+ from src.data.data import upload_image
12
+ from src.data.utils import remove_source_images
13
+
14
+ use_low_resource_settings("Yes")
15
+
16
+ suppress_warnings("Yes")
17
+
18
+ st.write("# 🖼️Neural🎨Style🖌️Transfer🖼️")
19
+
20
+ # LOAD MODEL
21
+ with st.spinner("🖌️Loading Model"):
22
+ model = load_model()
23
+ st.success("🖌️Model Loaded")
24
+
25
+ content_image_def, style_image_def = st.columns(2)
26
+ with content_image_def:
27
+ st.write("🖼️ Content Image: Image to style")
28
+ with style_image_def:
29
+ st.write("🎨 Style Image: Style to transfer to content image")
30
+
31
+ content_image, style_image = st.columns(2)
32
+
33
+ # UPLOAD IMAGES
34
+ with content_image:
35
+ ContentColumnTitle = "## 🖼️ Content Image 🖼️"
36
+ ContentImageSelectionPrompt = "Pick a Content image"
37
+ content_image_file = upload_image(
38
+ ContentColumnTitle, ContentImageSelectionPrompt, "Content", "content"
39
+ )
40
+
41
+ with style_image:
42
+ StyleColumnTitle = "## 🎨 Style Image 🎨"
43
+ StyleImageSelectionPrompt = "Pick a Style image"
44
+ style_image_file = upload_image(
45
+ StyleColumnTitle, StyleImageSelectionPrompt, "Style", "style"
46
+ )
47
+
48
+ # STYLIZE CONTENT IMAGE
49
+ stylize_image = st.button("🖼️🖌️🎨 Start Neural Style Transfer 🖼️🖌️🎨")
50
+
51
+ if stylize_image:
52
+ final_image = stylize_content_image(model, content_image_file, style_image_file)
53
+ st.write("# Styled Image:")
54
+ st.image(final_image)
55
+ try:
56
+ remove_source_images()
57
+ except:
58
+ pass
references/References.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # References & Inspiration
2
+ ## Project Structure
3
+ [Cookiecutter Data Science](https://drivendata.github.io/cookiecutter-data-science/)
4
+
5
+ ## Model
6
+ [Fast arbitrary image style transfer.](https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2)
7
+
8
+ ## Repositories
9
+ [kairavkkp/Neural-Style-Transfer-Streamlit](https://github.com/kairavkkp/Neural-Style-Transfer-Streamlit)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit>=1.7.0
2
+ tensorflow-hub==0.12.0
3
+ tensorflow-cpu ==2.5.0
setup.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import find_packages, setup
2
+
3
+ setup(
4
+ name='src',
5
+ packages=find_packages(),
6
+ version='0.1.0',
7
+ description='Neutral Style Transfer Tool',
8
+ author='Author',
9
+ license='MIT',
10
+ )
setup.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ mkdir -p ~/.streamlit/
4
+
5
+ echo "\
6
+ [server]\n\
7
+ headless = true\n\
8
+ port = $PORT\n\
9
+ enableCORS = false\n\
10
+ \n\
11
+ " > ~/.streamlit/config.toml
src/__init__.py ADDED
File without changes
src/data/__init__.py ADDED
File without changes
src/data/data.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import tensorflow as tf
4
+
5
+ from ..image_utils import load_img, imshow, transform_img
6
+
7
+
8
+ def upload_image_file(ImageSelectionPrompt, ImageType, image_upload_method):
9
+ st.write(f"{ImageSelectionPrompt}: {image_upload_method}")
10
+ image_file = st.file_uploader(
11
+ f"Upload {ImageType} Image File (png or jpg)", type=("png", "jpg")
12
+ )
13
+ try:
14
+ image_file = image_file.read()
15
+ return transform_img(image_file)
16
+ except:
17
+ pass
18
+
19
+
20
+ def upload_image_url(ImageSelectionPrompt, ImageType, image_upload_method):
21
+ st.write(f"{ImageSelectionPrompt}: {image_upload_method}")
22
+ url = st.text_input(f"{ImageType} Image URL")
23
+ try:
24
+ image_path = tf.keras.utils.get_file(
25
+ os.path.join(os.getcwd(), f"{ImageType.lower()}.jpg"), url
26
+ )
27
+ except:
28
+ pass
29
+ try:
30
+ return load_img(image_path)
31
+ except:
32
+ pass
33
+
34
+
35
+ def upload_image(ColumnTitle, ImageSelectionPrompt, ImageType, KeyString):
36
+ st.write(ColumnTitle)
37
+ image_upload_method = st.radio(
38
+ label="", options=["📁 File Upload", "🔗 URL"], key=KeyString
39
+ )
40
+ if image_upload_method == "📁 File Upload":
41
+ image_file = upload_image_file(
42
+ ImageSelectionPrompt, ImageType, image_upload_method
43
+ )
44
+ if image_upload_method == "🔗 URL":
45
+ image_file = upload_image_url(
46
+ ImageSelectionPrompt, ImageType, image_upload_method
47
+ )
48
+ try:
49
+ st.write(f"{ImageType} Image")
50
+ st.image(imshow(image_file))
51
+ return image_file
52
+ except:
53
+ pass
src/data/utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ def remove_source_images():
5
+ os.remove('style.jpg')
6
+ os.remove('content.jpg')
src/image_utils.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ def load_img(path_to_img):
7
+ max_dim = 512
8
+ img = tf.io.read_file(path_to_img)
9
+ img = tf.image.decode_image(img)
10
+ img = tf.image.convert_image_dtype(img, tf.float32)
11
+
12
+ shape = tf.cast(tf.shape(img)[:-1], tf.float32)
13
+ long_dim = max(shape)
14
+ scale = max_dim / long_dim
15
+
16
+ new_shape = tf.cast(shape * scale, tf.int32)
17
+
18
+ img = tf.image.resize(img, new_shape)
19
+ img = img[tf.newaxis, :]
20
+ return img
21
+
22
+
23
+ def transform_img(img):
24
+ max_dim = 512
25
+ img = tf.image.decode_image(img)
26
+ img = tf.image.convert_image_dtype(img, tf.float32)
27
+
28
+ shape = tf.cast(tf.shape(img)[:-1], tf.float32)
29
+ long_dim = max(shape)
30
+ scale = max_dim / long_dim
31
+
32
+ new_shape = tf.cast(shape * scale, tf.int32)
33
+
34
+ img = tf.image.resize(img, new_shape)
35
+ img = img[tf.newaxis, :]
36
+ return img
37
+
38
+
39
+ def imshow(image, title=None):
40
+ if len(image.shape) > 3:
41
+ image = tf.squeeze(image, axis=0)
42
+
43
+ image = np.squeeze(image)
44
+ return image
45
+
46
+
47
+ def tensor_to_image(tensor):
48
+ tensor = tensor * 255
49
+ tensor = np.array(tensor, np.uint8)
50
+ if np.ndim(tensor) > 3:
51
+ assert tensor.shape[0] == 1
52
+ tensor = tensor[0]
53
+ return Image.fromarray(tensor)
src/model/__init__.py ADDED
File without changes
src/model/model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import tensorflow_hub as hub
4
+
5
+ from ..image_utils import tensor_to_image
6
+
7
+
8
+ @st.cache
9
+ def load_model():
10
+ hub_handle = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
11
+ # return hub_model
12
+ return hub.load(hub_handle)
13
+
14
+
15
+
16
+ def stylize_content_image(model, content_image_file, style_image_file):
17
+ try:
18
+ stylized_image = model(tf.constant(
19
+ content_image_file), tf.constant(style_image_file))[0]
20
+ return tensor_to_image(stylized_image) # stylized image
21
+ except:
22
+ stylized_image = model(tf.constant(
23
+ tf.convert_to_tensor(content_image_file[:, :, :, :3])
24
+ ),
25
+ tf.constant(
26
+ tf.convert_to_tensor(style_image_file[:, :, :, :3])
27
+ )
28
+ )[0]
29
+ return tensor_to_image(stylized_image) # stylized image
src/model/utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ def use_low_resource_settings(option):
5
+ if option == 'Yes':
6
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
7
+ os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
8
+
9
+
10
+ def suppress_warnings(option):
11
+ if option == 'Yes':
12
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'