Leo8613 commited on
Commit
86ef0cd
·
verified ·
1 Parent(s): 176a9ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -8,8 +8,9 @@ MODEL_PATH = 'ColorizeVideo_gen.pth'
8
 
9
  # Charger le modèle
10
  def load_model(model_path):
11
- model = torch.load(model_path, map_location=torch.device('cpu')) # Charger sur le CPU
12
- model.eval() # Met le modèle en mode évaluation
 
13
  return model
14
 
15
  # Prétraitement de l'image
@@ -17,15 +18,17 @@ def preprocess_frame(frame):
17
  # Redimensionner et normaliser
18
  frame = cv2.resize(frame, (224, 224)) # Ajustez la taille si nécessaire
19
  frame = frame / 255.0 # Normaliser
20
- input_tensor = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1) # Convertir en format Tensor
21
- return input_tensor.unsqueeze(0) # Ajouter une dimension de lot
22
 
23
  # Traitement de la vidéo
24
  def process_video(model, video_path):
25
  cap = cv2.VideoCapture(video_path)
 
 
26
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
27
  output_path = "output_video.mp4"
28
- out = cv2.VideoWriter(output_path, fourcc, 30.0, (int(cap.get(3)), int(cap.get(4))))
29
 
30
  while cap.isOpened():
31
  ret, frame = cap.read()
@@ -39,8 +42,9 @@ def process_video(model, video_path):
39
  with torch.no_grad():
40
  predictions = model(input_tensor)
41
 
42
- # Traiter les prédictions et convertir en image
43
  output_frame = (predictions.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
 
44
 
45
  # Écrire le cadre traité dans la sortie
46
  out.write(output_frame)
@@ -52,7 +56,7 @@ def process_video(model, video_path):
52
  # Interface Gradio
53
  def colorize_video(video):
54
  model = load_model(MODEL_PATH)
55
- output_video_path = process_video(model, video.name) # Utiliser le nom pour lire la vidéo
56
  return output_video_path
57
 
58
  # Configuration de l'interface Gradio
 
8
 
9
  # Charger le modèle
10
  def load_model(model_path):
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model = torch.load(model_path, map_location=device)
13
+ model.eval()
14
  return model
15
 
16
  # Prétraitement de l'image
 
18
  # Redimensionner et normaliser
19
  frame = cv2.resize(frame, (224, 224)) # Ajustez la taille si nécessaire
20
  frame = frame / 255.0 # Normaliser
21
+ input_tensor = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1)
22
+ return input_tensor.unsqueeze(0)
23
 
24
  # Traitement de la vidéo
25
  def process_video(model, video_path):
26
  cap = cv2.VideoCapture(video_path)
27
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
28
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
29
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
30
  output_path = "output_video.mp4"
31
+ out = cv2.VideoWriter(output_path, fourcc, 30.0, (width, height))
32
 
33
  while cap.isOpened():
34
  ret, frame = cap.read()
 
42
  with torch.no_grad():
43
  predictions = model(input_tensor)
44
 
45
+ # Convertir en image
46
  output_frame = (predictions.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
47
+ output_frame = cv2.resize(output_frame, (frame.shape[1], frame.shape[0])) # Rétablir la taille originale
48
 
49
  # Écrire le cadre traité dans la sortie
50
  out.write(output_frame)
 
56
  # Interface Gradio
57
  def colorize_video(video):
58
  model = load_model(MODEL_PATH)
59
+ output_video_path = process_video(model, video.name)
60
  return output_video_path
61
 
62
  # Configuration de l'interface Gradio