Ihssane123 committed
Commit 16d8457 · 1 Parent(s): d43420e

pushing files

Files changed (35)
  1. .gitattributes +3 -0
  2. Emotions/disgust/dis_1.png +3 -0
  3. Emotions/disgust/dis_2.png +3 -0
  4. Emotions/joy_1.png +3 -0
  5. Emotions/sad/sad_1.png +3 -0
  6. Emotions/sad/sad_2.png +3 -0
  7. Emotions/sad/sad_3.png +3 -0
  8. Emotions/stylized_output.jpg +0 -0
  9. Models_Class/LSTMModel.py +24 -0
  10. Models_Class/NST_class.py +31 -0
  11. Models_Class/__pycache__/LSTMModel.cpython-311.pyc +0 -0
  12. Models_Class/__pycache__/LSTMModel.cpython-312.pyc +0 -0
  13. Models_Class/__pycache__/NST_class.cpython-311.pyc +0 -0
  14. Painters/Pablo Picasso/Dora Maar with Cat (1941).png +3 -0
  15. Painters/Pablo Picasso/The Weeping Woman (1937).png +3 -0
  16. Painters/Pablo Picasso/Three Musicians (1921).png +3 -0
  17. Painters/Salvador Dalí/Sleep (1937).png +3 -0
  18. Painters/Salvador Dalí/Swans Reflecting Elephants (1937).png +3 -0
  19. Painters/Salvador Dalí/The Persistence of Memory (1931).png +3 -0
  20. Painters/Vincent van Gogh/Sunflowers (1888).png +3 -0
  21. Painters/Vincent van Gogh/The Potato Eaters (1885).png +3 -0
  22. Painters/Vincent van Gogh/The Starry Night (1889).png +3 -0
  23. Src/Inference.py +8 -0
  24. Src/Processing.py +17 -0
  25. Src/Processing_img.py +110 -0
  26. Src/__init__.py +0 -0
  27. Src/__pycache__/Inference.cpython-311.pyc +0 -0
  28. Src/__pycache__/Inference.cpython-312.pyc +0 -0
  29. Src/__pycache__/Processing.cpython-311.pyc +0 -0
  30. Src/__pycache__/Processing.cpython-312.pyc +0 -0
  31. Src/__pycache__/Processing_img.cpython-311.pyc +0 -0
  32. Src/__pycache__/__init__.cpython-311.pyc +0 -0
  33. Src/__pycache__/__init__.cpython-312.pyc +0 -0
  34. main.py +519 -0
  35. stylized_output.jpg +0 -0
.gitattributes CHANGED
@@ -33,4 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models filter=lfs diff=lfs merge=lfs -text
+models/lstm_emotion_model_state.pth filter=lfs diff=lfs merge=lfs -text
+Datasets filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
Emotions/disgust/dis_1.png ADDED

Git LFS Details

  • SHA256: 5ac7ab9a76de1bd20f2e97d1222abdd7acca7ec126d31dfe61830611441aeadb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.47 MB
Emotions/disgust/dis_2.png ADDED

Git LFS Details

  • SHA256: ea0cb48344fb0cd177e31b98df3b770f74ed58c93e586fa83c957d6bde6d08d4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.5 MB
Emotions/joy_1.png ADDED

Git LFS Details

  • SHA256: bdde85b1578148b43e157aae6e943cf332a79bf965e0614baaf9db2489e60fb1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.07 MB
Emotions/sad/sad_1.png ADDED

Git LFS Details

  • SHA256: 2f8d8afd794483426c28bfdc6cbac0d503b42d4486e9230770c632fead5b6eff
  • Pointer size: 132 Bytes
  • Size of remote file: 4.72 MB
Emotions/sad/sad_2.png ADDED

Git LFS Details

  • SHA256: e1bc9975cfff5e5c6530dcd4f7c5e836a535081c7060dc4c5c3cbc79bb7fddbb
  • Pointer size: 132 Bytes
  • Size of remote file: 4.02 MB
Emotions/sad/sad_3.png ADDED

Git LFS Details

  • SHA256: 86f57852a4cf0c67e9bb608fe7984aa0f818a10dec245a28dc2d110f705df51f
  • Pointer size: 130 Bytes
  • Size of remote file: 19.5 kB
Emotions/stylized_output.jpg ADDED
Models_Class/LSTMModel.py ADDED
@@ -0,0 +1,24 @@
+import torch.nn as nn
+import torch
+
+class LSTMModel(nn.Module):
+    ## constructor
+    def __init__(self, input_size, hidden_size, output_size, num_layers):
+        super(LSTMModel, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+        self.num_layers = num_layers
+        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
+        self.fc = nn.Linear(self.hidden_size, self.output_size)
+
+    def forward(self, x, h0=None, c0=None):
+        # hidden and cell states h0/c0, created on x's device if not provided
+        if h0 is None or c0 is None:
+            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
+            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
+
+        out, (hn, cn) = self.lstm(x, (h0, c0))
+        out = self.fc(out)
+        return out, (hn, cn)
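A minimal usage sketch for this model; the sizes match what main.py later passes to load_model, and the random input is purely illustrative:

import torch
from Models_Class.LSTMModel import LSTMModel

# batch of one 630-step sequence of 320-dim PSD feature vectors
model = LSTMModel(input_size=320, hidden_size=50, output_size=7, num_layers=1)
x = torch.randn(1, 630, 320)
out, (hn, cn) = model(x)
print(out.shape)  # torch.Size([1, 630, 7]): one 7-class logit vector per time step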
Models_Class/NST_class.py ADDED
@@ -0,0 +1,31 @@
+import torch.nn as nn
+import torch
+
+def gram_matrix(input):
+    a, b, c, d = input.size()
+    features = input.view(a * b, c * d)
+    G = torch.mm(features, features.t())
+    return G.div(a * b * c * d)
+
+class ContentLoss(nn.Module):
+    def __init__(self, target):
+        super().__init__()
+        self.target = target.detach()
+
+    def forward(self, input):
+        # record the loss as a side effect and pass the input through unchanged
+        self.loss = nn.functional.mse_loss(input, self.target)
+        return input
+
+class StyleLoss(nn.Module):
+    def __init__(self, target_feature):
+        super().__init__()
+        self.target = gram_matrix(target_feature).detach()
+
+    def forward(self, input):
+        G = gram_matrix(input)
+        self.loss = nn.functional.mse_loss(G, self.target)
+        return input
+
+class Normalization(nn.Module):
+    def __init__(self, mean, std):
+        super().__init__()
+        self.mean = mean.view(-1, 1, 1)
+        self.std = std.view(-1, 1, 1)
+
+    def forward(self, img):
+        return (img - self.mean) / self.std
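A quick sanity check of gram_matrix and StyleLoss on a dummy feature map (the shapes here are illustrative, not taken from the repo):

import torch
from Models_Class.NST_class import gram_matrix, StyleLoss

feat = torch.randn(1, 64, 32, 32)   # (batch, channels, height, width)
G = gram_matrix(feat)
print(G.shape)                      # torch.Size([64, 64]): channel-by-channel correlations

style_loss = StyleLoss(feat)
style_loss(torch.randn(1, 64, 32, 32))
print(style_loss.loss.item())       # MSE between the two normalized Gram matrices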
Models_Class/__pycache__/LSTMModel.cpython-311.pyc ADDED
Binary file (1.93 kB).
 
Models_Class/__pycache__/LSTMModel.cpython-312.pyc ADDED
Binary file (1.81 kB).
 
Models_Class/__pycache__/NST_class.cpython-311.pyc ADDED
Binary file (3.49 kB).
 
Painters/Pablo Picasso/Dora Maar with Cat (1941).png ADDED

Git LFS Details

  • SHA256: e3d9a1c358f10e2d5a078fd0b8c6e7360f4de921aaace30d6162524ac1189602
  • Pointer size: 130 Bytes
  • Size of remote file: 25.3 kB
Painters/Pablo Picasso/The Weeping Woman (1937).png ADDED

Git LFS Details

  • SHA256: b10f3b39125ef7c096dfd165b77faa1079774201abe1ec45b563744a4d4b8827
  • Pointer size: 130 Bytes
  • Size of remote file: 32.7 kB
Painters/Pablo Picasso/Three Musicians (1921).png ADDED

Git LFS Details

  • SHA256: f042345d98f8128fbfd8d84b3c84660fa7326418b63599efa7cfe429fd3bb16a
  • Pointer size: 131 Bytes
  • Size of remote file: 521 kB
Painters/Salvador Dalí/Sleep (1937).png ADDED

Git LFS Details

  • SHA256: 987d628b6c0ac7367b5c1b24ede65664ba32baba4f7a94264dc42cbd221c7139
  • Pointer size: 130 Bytes
  • Size of remote file: 60.5 kB
Painters/Salvador Dalí/Swans Reflecting Elephants (1937).png ADDED

Git LFS Details

  • SHA256: f71df5184f15adcdcaed509336b02edc37dd0b2daeb7049142dd61075e126148
  • Pointer size: 130 Bytes
  • Size of remote file: 27.2 kB
Painters/Salvador Dalí/The Persistence of Memory (1931).png ADDED

Git LFS Details

  • SHA256: f306a83ccc3d3f0b2e7894b92ceef114f6d5361750cb982d793acf613a0d107e
  • Pointer size: 130 Bytes
  • Size of remote file: 62 kB
Painters/Vincent van Gogh/Sunflowers (1888).png ADDED

Git LFS Details

  • SHA256: c8cd0e84094378a927dab0432c416e8174204bb98af394ceb757f369d8ef4a2c
  • Pointer size: 130 Bytes
  • Size of remote file: 94 kB
Painters/Vincent van Gogh/The Potato Eaters (1885).png ADDED

Git LFS Details

  • SHA256: d92cb55929ba1f20a552f828625e19277278161cff287f7f2f0fc6448c7eb2b5
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
Painters/Vincent van Gogh/The Starry Night (1889).png ADDED

Git LFS Details

  • SHA256: 5f74a2599ced6f8437e2b29a786bdde2a451ff8f01d87418a14c6bbe2ce0ad6a
  • Pointer size: 131 Bytes
  • Size of remote file: 541 kB
Src/Inference.py ADDED
@@ -0,0 +1,8 @@
+## Start here with the inference procedure
+import torch
+from Models_Class.LSTMModel import LSTMModel
+
+def load_model(model_path, input_size, hidden_size, output_size, num_layers):
+    loaded_model = LSTMModel(input_size, hidden_size, output_size, num_layers)
+    # map_location lets a GPU-trained checkpoint load on CPU-only hosts
+    loaded_model.load_state_dict(torch.load(model_path, map_location="cpu"))
+    return loaded_model
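A usage sketch, assuming the LFS-tracked checkpoint added in this commit and the hyperparameters from main.py:

from Src.Inference import load_model

model = load_model("models/lstm_emotion_model_state.pth",
                   input_size=320, hidden_size=50, output_size=7, num_layers=1)
model.eval()  # switch to evaluation mode before inference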
Src/Processing.py ADDED
@@ -0,0 +1,17 @@
+import numpy as np
+
+emotion_list = ["0", "1", "2", "3", "4", "5", "6"]
+
+
+def load_data(psd_file_pth):
+    np_data = np.load(psd_file_pth, allow_pickle=True).item()["psd"]
+    return np_data
+
+
+def process_data(np_data):
+    # swap the first two axes, then flatten each timestep's
+    # channel-by-band values into a single 320-dim feature vector
+    swapped_data = np.swapaxes(np_data, 0, 1)
+    reshape_data = swapped_data.reshape(630, 320)
+    return reshape_data
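A shape sketch for process_data; the (64, 630, 5) input layout, i.e. (channels, timesteps, frequency bands), is an assumption inferred from the reshape in main.py:

import numpy as np
from Src.Processing import process_data

dummy_psd = np.random.rand(64, 630, 5)  # assumed (channels, timesteps, bands) layout
features = process_data(dummy_psd)
print(features.shape)  # (630, 320): one 64 x 5 feature vector per timestep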
Src/Processing_img.py ADDED
@@ -0,0 +1,110 @@
+from PIL import Image
+import torch
+import torch.optim as optim
+from torchvision import transforms
+import torch.nn as nn
+from Models_Class.NST_class import (
+    ContentLoss,
+    Normalization,
+    StyleLoss,
+)
+
+import copy
+
+style_weight = 1e8
+content_weight = 1e1
+
+
+def image_loader(image_path, loader, device):
+    image = Image.open(image_path).convert('RGB')
+    image = loader(image).unsqueeze(0)  # add a batch dimension
+    return image.to(device, torch.float)
+
+
+def save_image(tensor, path="output.png"):
+    image = tensor.cpu().clone()
+    image = image.squeeze(0)
+    image = transforms.ToPILImage()(image)
+    image.save(path)
+
+
+def gram_matrix(input):
+    a, b, c, d = input.size()
+    features = input.view(a * b, c * d)
+    G = torch.mm(features, features.t())
+    return G.div(a * b * c * d)
+
+
+def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
+                               style_img, content_img, content_layers, style_layers, device):
+    cnn = copy.deepcopy(cnn)
+    normalization = Normalization(normalization_mean, normalization_std).to(device)
+    content_losses = []
+    style_losses = []
+    model = nn.Sequential(normalization)
+
+    i = 0
+    for layer in cnn.children():
+        if isinstance(layer, nn.Conv2d):
+            i += 1
+            name = f'conv_{i}'
+        elif isinstance(layer, nn.ReLU):
+            name = f'relu_{i}'
+            # out-of-place ReLU so the inserted loss modules see unmodified activations
+            layer = nn.ReLU(inplace=False)
+        elif isinstance(layer, nn.MaxPool2d):
+            name = f'pool_{i}'
+        elif isinstance(layer, nn.BatchNorm2d):
+            name = f'bn_{i}'
+        else:
+            raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')
+
+        model.add_module(name, layer)
+
+        if name in content_layers:
+            target = model(content_img).detach()
+            content_loss = ContentLoss(target)
+            model.add_module(f"content_loss_{i}", content_loss)
+            content_losses.append(content_loss)
+
+        if name in style_layers:
+            target_feature = model(style_img).detach()
+            style_loss = StyleLoss(target_feature)
+            model.add_module(f"style_loss_{i}", style_loss)
+            style_losses.append(style_loss)
+
+    # trim everything after the last loss module
+    for i in range(len(model) - 1, -1, -1):
+        if isinstance(model[i], (ContentLoss, StyleLoss)):
+            break
+    model = model[:i + 1]
+    return model, style_losses, content_losses
+
+
+def run_style_transfer(cnn, normalization_mean, normalization_std,
+                       content_img, style_img, input_img, content_layers, style_layers, device, num_steps=300):
+    print("Building the style transfer model..")
+    model, style_losses, content_losses = get_style_model_and_losses(
+        cnn, normalization_mean, normalization_std,
+        style_img, content_img, content_layers, style_layers, device)
+    optimizer = optim.LBFGS([input_img.requires_grad_()])
+
+    print("Optimizing..")
+    run = [0]
+    while run[0] <= num_steps:
+        def closure():
+            input_img.data.clamp_(0, 1)
+            optimizer.zero_grad()
+            model(input_img)
+            style_score = sum(sl.loss for sl in style_losses)
+            content_score = sum(cl.loss for cl in content_losses)
+            loss = style_weight * style_score + content_weight * content_score
+            loss.backward()
+
+            if run[0] % 50 == 0:
+                print(f"Step {run[0]}:")
+                print(f"  Style Loss: {style_score.item():.4f}")
+                print(f"  Content Loss: {content_score.item():.4f}")
+                print(f"  Total Loss: {loss.item():.4f}\n")
+
+            run[0] += 1
+            return loss
+
+        optimizer.step(closure)
+
+    input_img.data.clamp_(0, 1)
+    return input_img
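A minimal end-to-end driver sketch using images added in this commit; main.py wires the same pipeline into the Gradio app, and num_steps is kept small here so a CPU run finishes quickly:

import torch
import torchvision.models as models
from torchvision import transforms
from Src.Processing_img import image_loader, run_style_transfer, save_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

cnn = models.vgg19(pretrained=True).features.to(device).eval()
mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
std = torch.tensor([0.229, 0.224, 0.225]).to(device)

style_img = image_loader("Painters/Vincent van Gogh/The Starry Night (1889).png", loader, device)
content_img = image_loader("Emotions/joy_1.png", loader, device)

output = run_style_transfer(cnn, mean, std, content_img, style_img, content_img.clone(),
                            ['conv_4'], ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'],
                            device, num_steps=50)
save_image(output, "stylized_output.jpg")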
Src/__init__.py ADDED
File without changes
Src/__pycache__/Inference.cpython-311.pyc ADDED
Binary file (672 Bytes).
 
Src/__pycache__/Inference.cpython-312.pyc ADDED
Binary file (605 Bytes).
 
Src/__pycache__/Processing.cpython-311.pyc ADDED
Binary file (901 Bytes).
 
Src/__pycache__/Processing.cpython-312.pyc ADDED
Binary file (825 Bytes).
 
Src/__pycache__/Processing_img.cpython-311.pyc ADDED
Binary file (7.34 kB).
 
Src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (181 Bytes).
 
Src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (169 Bytes).
 
main.py ADDED
@@ -0,0 +1,519 @@
+import random
+import gradio as gr
+import pandas as pd
+import numpy as np
+from Src.Processing import load_data
+from Src.Processing import process_data
+from Src.Inference import load_model
+import torchvision.models as models
+from torchvision import transforms
+from Src.Processing_img import (
+    get_style_model_and_losses,
+    image_loader,
+    run_style_transfer,
+    save_image,
+    gram_matrix
+)
+import torch
+import tempfile
+import time  # For simulating delays
+from PIL import Image, ImageDraw, ImageFont  # Ensure ImageDraw and ImageFont are imported
+import os  # For file operations
+import mne
+import matplotlib.pyplot as plt
+import io
+import pyvista as pv
+import matplotlib.cm as cm
+
+pv.set_plot_theme("document")  # A simple theme
+pv.set_jupyter_backend('html')
+
+# --- Data for demonstration ---
+# Dummy data for the emotion distribution bar chart;
+# in the real flow this comes from the PSD analysis below.
+dummy_emotion_data = pd.DataFrame({
+    'Emotion': ['sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins'],
+    'Value': [0.8, 0.6, 0.1, 0.4, 0.7, 0.2, 0.3]
+})
+
+int_to_emotion = {
+    0: 'sad',
+    1: 'dis',
+    2: 'fear',
+    3: 'neu',
+    4: 'joy',
+    5: 'ten',
+    6: 'ins'
+}
+
+abr_to_emotion = {
+    'sad': "sadness",
+    'dis': "disgust",
+    'fear': "fear",
+    'neu': "neutral",
+    'joy': "joy",
+    'ten': "tenderness",
+    'ins': "inspiration"
+}
+
+# --- Local Image Paths Setup for Dynamic Loading ---
+# Base directories for the painters' images and the emotion content images.
+# In Hugging Face Spaces these are the 'Painters/' and 'Emotions/' folders of the repository.
+PAINTERS_BASE_DIR = "Painters"
+EMOTION_BASE_DIR = "Emotions"
+model_path = "models/lstm_emotion_model_state.pth"  # forward slashes work on all platforms
+input_size = 320
+hidden_size = 50
+output_size = 7
+num_layers = 1
+
+# Painters available in the museum and the base directory for the PSD files
+painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]
+Base_Dir = "Datasets"
+
+# This dictionary defines what placeholder files to create and their captions.
+# The actual gallery content will be read from the file system.
+PAINTER_PLACEHOLDER_DATA = {
+    "Pablo Picasso": [
+        ("Dora Maar with Cat (1941).png", "Dora Maar with Cat (1941)"),
+        ("The Weeping Woman (1937).png", "The Weeping Woman (1937)"),
+        ("Three Musicians (1921).png", "Three Musicians (1921)"),
+    ],
+    "Vincent van Gogh": [
+        ("Sunflowers (1888).png", "Sunflowers (1888)"),
+        ("The Starry Night (1889).png", "The Starry Night (1889)"),
+        ("The Potato Eaters (1885).png", "The Potato Eaters (1885)"),
+    ],
+    "Salvador Dalí": [
+        ("Persistence of Memory (1931).png", "Persistence of Memory (1931)"),
+        ("Swans Reflecting Elephants (1937).png", "Swans Reflecting Elephants (1937)"),
+        ("Sleep (1937).png", "Sleep (1937)"),
+    ],
+}
+
+# --- Define the specific PSD files to choose from ---
+predefined_psd_files = ["task-emotion_psd_1.npy", "task-emotion_psd_2.npy", "task-emotion_psd_3.npy"]
+
+# --- Core Functions ---
+
+def upload_psd_file(selected_file_name):
+    """
+    Processes a selected PSD file, performs inference, and prepares emotion distribution data.
+    """
+    if selected_file_name is None:
+        # No file selected: return a hidden dummy plot, an empty state DataFrame,
+        # and a hidden textbox so all three outputs of the click handler are updated.
+        return (gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value",
+                           label="Emotion Distribution", visible=False),
+                pd.DataFrame(), gr.Textbox(visible=False))
+
+    # --- Load and Process PSD Data ---
+    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')
+
+    try:
+        global np_data  # also read by generate_topomap below
+        np_data = load_data(psd_file_path)
+        print(f"np data orig {np_data.shape}")
+    except FileNotFoundError:
+        print(f"Error: PSD file not found at {psd_file_path}")
+        return (gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value",
+                           label="Emotion Distribution (Error: File not found)", visible=False),
+                pd.DataFrame(), gr.Textbox(visible=False))
+
+    final_data = process_data(np_data)
+    # Shape the data for the LSTM: (batch, sequence_length, input_size).
+    # final_data is (sequence_length, input_size), so add a batch dimension.
+    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
+
+    print(f"Processed data shape for model: {torch_data.shape}")
+
+    # --- Inference ---
+    # Path is relative to where main.py runs; 'models' sits at the project root.
+    absolute_model_path = os.path.join("models", "lstm_emotion_model_state.pth")
+
+    loaded_model = load_model(absolute_model_path, input_size, hidden_size, output_size, num_layers)
+    loaded_model.eval()  # Set model to evaluation mode
+
+    with torch.no_grad():  # Disable gradient calculation for inference
+        predicted_logits, _ = loaded_model(torch_data)  # LSTM returns (output, (h_n, c_n))
+
+    # Most probable emotion index per time step: (batch_size, sequence_length)
+    final_output_indices = torch.argmax(predicted_logits, dim=2)
+
+    # Flatten the sequence to count overall emotion frequencies
+    all_predicted_indices = final_output_indices.view(-1)
+
+    print(f"All predicted indices (flattened): {all_predicted_indices}")
+
+    # Count occurrences of each predicted index; minlength ensures all 7 indices appear
+    values_count = torch.bincount(all_predicted_indices, minlength=output_size)
+    print(f"Raw bincount: {values_count}")
+
+    # --- Create Emotion Distribution DataFrame ---
+    # Initialize all emotions with frequency 0, then fill in the observed counts
+    emotions_count = {int_to_emotion[i]: 0 for i in range(output_size)}
+    for idx, count in enumerate(values_count):
+        if idx < output_size:
+            emotions_count[int_to_emotion[idx]] = count.item()
+
+    dom_emotion = max(emotions_count, key=emotions_count.get)
+    # Column names must match what gr.BarPlot expects below: "Emotion" and "Frequency"
+    emotion_data = pd.DataFrame({
+        "Emotion": list(emotions_count.keys()),
+        "Frequency": list(emotions_count.values())
+    })
+
+    # Sort for consistent plotting
+    emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
+
+    print(f"Final emotion_data DataFrame:\n{emotion_data}")
+
+    # Return the plot, the DataFrame for the state, and the dominant-emotion textbox
+    return gr.BarPlot(
+        emotion_data,
+        x="Emotion",
+        y="Frequency",
+        label="Emotion Distribution",
+        visible=True,
+        y_title="Frequency"
+    ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)
+
+
+def update_paintings(painter_name):
+    """
+    Updates the gallery with paintings specific to the selected painter by
+    dynamically listing files in the painter's directory.
+    """
+    painter_dir = os.path.join(PAINTERS_BASE_DIR, painter_name).replace(os.sep, '/')
+    print(painter_dir)
+
+    artist_paintings_for_gallery = []
+    if os.path.isdir(painter_dir):
+        for filename in sorted(os.listdir(painter_dir)):  # Sort for consistent order
+            if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
+                file_path = os.path.join(painter_dir, filename).replace(os.sep, '/')
+                print(file_path)
+                # Use the filename without its extension as the title
+                title = os.path.splitext(filename)[0]
+                artist_paintings_for_gallery.append((file_path, title))
+    print(f"Loaded paintings for {painter_name}: {artist_paintings_for_gallery}")
+    return artist_paintings_for_gallery  # Return the list directly for the gallery
+
+
+def generate_my_art(painter, chosen_painting, dom_emotion):
+    """
+    Generates the artwork: the chosen painting (the file name captured from
+    gr.Gallery.select()) is the style image, and a random image from the
+    dominant emotion's folder is the content image.
+    """
+    print("generating started")
+    print(f"painter: {painter}")
+    print(f"chosen painting: {chosen_painting}")
+    if not painter or not chosen_painting:
+        # This function is a generator, so yield (not return) the default outputs
+        yield "Please select a painter and a painting.", None, None
+        return
+
+    ## style image path
+    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
+    print(f"img_style_path: {img_style_pth}")
+
+    time.sleep(3)  # brief pause so the status message is visible before optimization starts
+    # (an earlier version simulated the output with a random solid-color image;
+    # the real neural style transfer runs below)
+
+    ## content image: a random image from the dominant emotion's folder
+    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
+    image_name = random.choice(os.listdir(emotion_pth))
+    original_image_pth = os.path.join(emotion_pth, image_name)
+    print(f"original img path: {original_image_pth}")
+    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"
+
+    ## Neural Style Transfer
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    imsize = 512 if torch.cuda.is_available() else 256
+    loader = transforms.Compose([
+        transforms.Resize((imsize, imsize)),
+        transforms.ToTensor()
+    ])
+    content_layers = ['conv_4']
+    style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
+    cnn = models.vgg19(pretrained=True).features.to(device).eval()
+    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
+    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
+    style_img = image_loader(img_style_pth, loader, device)
+    content_img = image_loader(original_image_pth, loader, device)
+    input_img = content_img.clone()
+    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
+                                content_img, style_img, input_img, content_layers, style_layers, device)
+    save_image(output, "stylized_output.jpg")
+    print("Stylized image saved as 'stylized_output.jpg'")
+
+    stylized_img_path = 'stylized_output.jpg'
+    # Yield the final results
+    yield gr.Textbox(final_message), original_image_pth, stylized_img_path
+
+
+def generate_topomap(n_channels, n_time):
+    # np_data is set globally by upload_psd_file, so run the PSD analysis first
+    if n_channels is None or n_time is None:
+        print("they are None")
+        n_channels = 4
+        n_time = 500
+
+    # Load the standard 10-20 montage and keep 64 EEG electrode positions
+    montage = mne.channels.make_standard_montage('standard_1020')
+    standard_64_chs = [
+        'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
+        'FC5', 'FC1', 'FC2', 'FC6', 'T7', 'C3', 'Cz', 'C4', 'T8',
+        'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8',
+        'POz', 'O1', 'Oz', 'O2', 'Fpz', 'AF7', 'AF3', 'AF4', 'AF8',
+        'F5', 'F1', 'F2', 'F6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2',
+        'C6', 'CP3', 'CPz', 'CP4', 'P5', 'P1', 'P2', 'P6', 'PO3', 'PO4',
+        'PO7', 'PO8', 'PO9', 'PO10', 'O1', 'O2', 'FT7', 'FT8', 'TP7', 'TP8'
+    ]
+    ch_pos_dict = montage.get_positions()['ch_pos']
+    ch_pos_dict_filtered = {ch: ch_pos_dict[ch] for ch in standard_64_chs}
+    # 'O1' and 'O2' appear twice in the list above, so the dict has only 62 unique
+    # keys; use the 64-entry list itself for names so lengths match the data
+    channel_names = standard_64_chs
+    ch_pos_array = np.array([ch_pos_dict_filtered[ch] for ch in standard_64_chs])  # 64x3
+    ch_pos_2d = ch_pos_array[:, :2]  # For the 2D topomap
+
+    # Pick one frequency band (n_channels) and one time index (n_time)
+    new_data = np_data.reshape(64, 630, 5)
+    print(f"shape: {new_data.shape}")
+    print(f"n channels: {n_channels}")
+    psd_snapshot = new_data[:, n_time - 1, n_channels - 1]
+
+    # Normalize PSD for coloring
+    psd_norm = (psd_snapshot - psd_snapshot.min()) / (psd_snapshot.max() - psd_snapshot.min())
+
+    # Plot the 2D topomap with MNE
+    print(f"shape psd: {psd_snapshot.shape}")
+    fig, ax = plt.subplots()
+    mne.viz.plot_topomap(
+        psd_snapshot,
+        ch_pos_2d,
+        names=channel_names,
+        show=False,
+        axes=ax
+    )
+
+    # Save the generated topomap to a temp file
+    tmpfile = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+    fig.savefig(tmpfile.name, dpi=150, bbox_inches="tight")
+    plt.close(fig)
+    return tmpfile.name
+
+
+# --- Gradio Interface Definition ---
+
+with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
+    # State holding the emotion distribution DataFrame produced by the analysis
+    current_emotion_df_state = gr.State(value=pd.DataFrame())
+
+    # Header Section
+    gr.Markdown(
+        """
+        <h1 style="text-align: center; font-size: 5em; padding: 20px; font-weight: bold;">Brain Emotion Decoder 🧠🎨</h1>
+        <p style="text-align: center; font-size: 1.5em; color: #555; font-weight: bold;">
+        Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity,
+        generating a personalized artwork that comes to life within an interactive 3D brain model. Discover the art of your inner self.
+        </p>
+        """
+    )
+
+    with gr.Row():
+        # Left Column: Input and Emotion Distribution
+        with gr.Column(scale=1):
+            gr.Markdown('<h2 style="font-size: 2em;">1. Choose a PSD file</h2>')
+            # Radio buttons to select from predefined files
+            psd_file_selection = gr.Radio(
+                choices=predefined_psd_files,
+                label="Select a PSD file for analysis",
+                value=predefined_psd_files[0],  # Default selection
+                interactive=True
+            )
+
+            # Button to trigger PSD analysis
+            analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")
+
+            gr.Markdown('<h2 style="font-size: 2em;">2. Emotion Distribution</h2>')
+
+            # Bar plot for emotion distribution
+            emotion_distribution_plot = gr.BarPlot(
+                dummy_emotion_data,
+                x="Emotion",
+                y="Value",
+                label="Emotion Distribution",
+                height=300,
+                x_title="Emotion Type",
+                y_title="Frequency",
+                visible=False  # Hidden until analysis is triggered
+            )
+
+            dom_emotion = gr.Textbox(label="dominant emotion", visible=False)
+
+        # Right Column: Art Museum and Generation
+        with gr.Column(scale=1):
+            gr.Markdown("<h3>Your Art Museum</h3>")
+
+            gr.Markdown("<h3>3. Choose your favourite painter</h3>")
+            painter_dropdown = gr.Dropdown(
+                choices=painters,
+                value="Pablo Picasso",  # Default selection
+                label="Select a Painter"
+            )
+
+            gr.Markdown("<h3>4. Choose your favourite painting</h3>")
+            # Gallery to display paintings for selection
+            painting_gallery = gr.Gallery(
+                value=update_paintings("Pablo Picasso"),  # Initial load of Picasso's paintings
+                label="Select a Painting",
+                height=300,
+                columns=3,
+                rows=1,
+                object_fit="contain",
+                preview=True,      # Allows clicking to see a larger image
+                interactive=True,  # Make it selectable
+                elem_id="painting_gallery",
+                visible=True,
+            )
+
+            # Button to trigger art generation
+            selected_painting_name = gr.Textbox(visible=False)
+            generate_button = gr.Button("Generate My Art", variant="primary", scale=0)
+            # Status message for image generation
+            status_message = gr.Textbox(
+                value="Click 'Generate My Art' to begin.",
+                label="Generation Status",
+                interactive=False,
+                show_label=False,
+                lines=1
+            )
+
+    # Output section, revealed below the inputs
+    gr.Markdown(
+        """
+        <h1 style="text-align: center;">Your Generated Artwork</h1>
+        <p style="text-align: center; color: #555;">
+        Once your brain's emotional data is processed, we pinpoint the <b>dominant emotion</b>. This single feeling inspires a <b>personalized artwork</b>, generated using <b>diffusion techniques</b> and blended with <b>my AI painting style</b>. You can then <b>download</b> this unique visual representation of your inner self.
+        </p>
+        """
+    )
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("<h3>Generated Image</h3>")
+            generated_image_output = gr.Image(label="Generated Image", show_label=False, height=300)
+            gr.Markdown("<h3>Blended Style Image</h3>")
+            blended_image_output = gr.Image(label="Blended Style Image", show_label=False, height=300)
+
+        with gr.Column(scale=1):
+            gr.Markdown("<h3>Brain Topomap</h3>")
+            channels_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Channels", interactive=True)
+            timestamp_slider = gr.Slider(minimum=1, maximum=630, value=1, step=1, label="Timestamp", interactive=True)
+            mne_2d_img = gr.Image(visible=True)
+            # generate_topomap takes (n_channels, n_time), so pass the sliders as inputs
+            generate_button.click(generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
+
+    # --- Event Listeners ---
+    analyze_psd_button.click(
+        upload_psd_file,
+        inputs=[psd_file_selection],  # The selected radio button value (file name)
+        outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion]
+    )
+
+    # When the painter dropdown changes, update the gallery content
+    painter_dropdown.change(
+        update_paintings,
+        inputs=[painter_dropdown],
+        outputs=[painting_gallery]
+    )
+
+    # Use gr.Gallery's .select() event to capture the clicked item; the event
+    # data carries the original file name of the selected image
+    def on_select(evt: gr.SelectData):
+        print("this function started")
+        print(f"Image index: {evt.index}\nImage value: {evt.value['image']['orig_name']}")
+        return evt.value['image']['orig_name']
+
+    painting_gallery.select(
+        on_select,
+        outputs=[selected_painting_name]
+    )
+
+    # The generate button uses the painter and the selected painting name
+    generate_button.click(
+        generate_my_art,
+        inputs=[painter_dropdown, selected_painting_name, dom_emotion],
+        outputs=[status_message, generated_image_output, blended_image_output]
+    )
+
+    ## slider event listeners
+    channels_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
+    timestamp_slider.change(fn=generate_topomap, inputs=[channels_slider, timestamp_slider], outputs=mne_2d_img)
+
+# Launch the demo
+if __name__ == "__main__":
+    demo.launch()
stylized_output.jpg ADDED