Omarrran committed on
Commit
6e8b0e8
·
verified ·
1 Parent(s): ccca270

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -18
app.py CHANGED
@@ -5,6 +5,7 @@ from torch import nn
5
  import requests
6
  from pathlib import Path
7
  import logging
 
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO)
@@ -22,13 +23,28 @@ class TextGenerator(nn.Module):
22
  lstm_out, _ = self.lstm(x)
23
  return self.fc(lstm_out)
24
 
 
 
 
 
 
 
 
 
25
  def download_file(url, local_path):
26
  try:
27
- response = requests.get(url)
28
- response.raise_for_status() # Raise an exception for bad status codes
 
 
 
 
29
  Path(local_path).parent.mkdir(parents=True, exist_ok=True)
 
30
  with open(local_path, 'wb') as f:
31
- f.write(response.content)
 
 
32
  logger.info(f"Successfully downloaded {url} to {local_path}")
33
  except Exception as e:
34
  logger.error(f"Error downloading {url}: {str(e)}")
@@ -41,9 +57,9 @@ def load_model_and_tokenizers():
41
 
42
  # Default configuration values
43
  default_config = {
44
- 'vocab_size': 10000, # Default vocabulary size
45
- 'embedding_dim': 256, # Default embedding dimension
46
- 'hidden_dim': 512 # Default hidden dimension
47
  }
48
 
49
  # URLs for the files
@@ -64,13 +80,19 @@ def load_model_and_tokenizers():
64
 
65
  try:
66
  # Load configuration
67
- with open(cache_dir / "model_config.json", "r") as f:
68
- config = json.load(f)
69
- # Merge with default config
70
- for key in default_config:
71
- if key not in config:
72
- logger.warning(f"Configuration parameter '{key}' not found, using default value: {default_config[key]}")
73
- config[key] = default_config[key]
 
 
 
 
 
 
74
  except Exception as e:
75
  logger.warning(f"Error loading config file: {str(e)}. Using default configuration.")
76
  config = default_config
@@ -97,10 +119,34 @@ def load_model_and_tokenizers():
97
  hidden_dim=config['hidden_dim']
98
  )
99
 
100
- # Load model weights
101
- model.load_state_dict(torch.load(cache_dir / "model.pt", map_location=torch.device('cpu')))
102
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
 
104
  return model, word_to_int, int_to_word
105
 
106
  except Exception as e:
@@ -121,7 +167,7 @@ def generate_text(prompt, max_length=100):
121
 
122
  with torch.no_grad():
123
  for _ in range(max_length):
124
- current_input = torch.tensor([generated_ids[-50:]]) # Use last 50 tokens as context
125
  outputs = model(current_input)
126
  next_token_id = outputs[0, -1, :].argmax().item()
127
  generated_ids.append(next_token_id)
@@ -135,7 +181,7 @@ def generate_text(prompt, max_length=100):
135
 
136
  except Exception as e:
137
  logger.error(f"Error in text generation: {str(e)}")
138
- return f"Error generating text: {str(e)}"
139
 
140
  # Create Gradio interface
141
  iface = gr.Interface(
 
5
  import requests
6
  from pathlib import Path
7
  import logging
8
+ import os
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO)
 
23
  lstm_out, _ = self.lstm(x)
24
  return self.fc(lstm_out)
25
 
26
def fix_state_dict(state_dict):
    """Return a copy of ``state_dict`` with a leading 'module.' stripped from keys.

    Only a single leading ``'module.'`` prefix is removed from each key.
    Occurrences elsewhere in a key are left untouched, so names such as
    ``'submodule.weight'`` or ``'encoder.module.bias'`` are preserved
    instead of being rewritten to ``'subweight'`` / ``'encoder.bias'``.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        A new dict with the same values and normalized keys; the input
        mapping is not modified.
    """
    prefix = 'module.'
    new_state_dict = {}
    for key, value in state_dict.items():
        # Slice instead of str.replace(): replace() would also rewrite
        # 'module.' when it appears in the middle of a parameter name.
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict
33
+
34
  def download_file(url, local_path):
35
  try:
36
+ response = requests.get(url, stream=True)
37
+ response.raise_for_status()
38
+
39
+ total_size = int(response.headers.get('content-length', 0))
40
+ block_size = 8192
41
+
42
  Path(local_path).parent.mkdir(parents=True, exist_ok=True)
43
+
44
  with open(local_path, 'wb') as f:
45
+ for data in response.iter_content(block_size):
46
+ f.write(data)
47
+
48
  logger.info(f"Successfully downloaded {url} to {local_path}")
49
  except Exception as e:
50
  logger.error(f"Error downloading {url}: {str(e)}")
 
57
 
58
  # Default configuration values
59
  default_config = {
60
+ 'vocab_size': 10000,
61
+ 'embedding_dim': 256,
62
+ 'hidden_dim': 512
63
  }
64
 
65
  # URLs for the files
 
80
 
81
  try:
82
  # Load configuration
83
+ config_path = cache_dir / "model_config.json"
84
+ if config_path.exists():
85
+ with open(config_path, "r") as f:
86
+ config = json.load(f)
87
+ else:
88
+ logger.warning("Config file not found, using default configuration.")
89
+ config = default_config
90
+
91
+ # Merge with default config
92
+ for key in default_config:
93
+ if key not in config:
94
+ logger.warning(f"Configuration parameter '{key}' not found, using default value: {default_config[key]}")
95
+ config[key] = default_config[key]
96
  except Exception as e:
97
  logger.warning(f"Error loading config file: {str(e)}. Using default configuration.")
98
  config = default_config
 
119
  hidden_dim=config['hidden_dim']
120
  )
121
 
122
+ # Load model weights with proper error handling
123
+ model_path = cache_dir / "model.pt"
124
+ if not model_path.exists():
125
+ raise FileNotFoundError(f"Model file not found at {model_path}")
126
+
127
+ # Try different loading approaches
128
+ try:
129
+ # Try loading as a complete model
130
+ loaded_model = torch.load(model_path, map_location=torch.device('cpu'))
131
+ if isinstance(loaded_model, dict):
132
+ # If it's a state dict
133
+ state_dict = fix_state_dict(loaded_model)
134
+ model.load_state_dict(state_dict)
135
+ else:
136
+ # If it's a complete model
137
+ model = loaded_model
138
+ except Exception as e:
139
+ logger.warning(f"First loading attempt failed: {str(e)}")
140
+ try:
141
+ # Try loading as a state dict directly
142
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
143
+ state_dict = fix_state_dict(state_dict)
144
+ model.load_state_dict(state_dict)
145
+ except Exception as e2:
146
+ logger.error(f"Both loading attempts failed. Last error: {str(e2)}")
147
+ raise
148
 
149
+ model.eval()
150
  return model, word_to_int, int_to_word
151
 
152
  except Exception as e:
 
167
 
168
  with torch.no_grad():
169
  for _ in range(max_length):
170
+ current_input = torch.tensor([generated_ids[-50:]])
171
  outputs = model(current_input)
172
  next_token_id = outputs[0, -1, :].argmax().item()
173
  generated_ids.append(next_token_id)
 
181
 
182
  except Exception as e:
183
  logger.error(f"Error in text generation: {str(e)}")
184
+ return f"Error generating text: {str(e)}\nPlease check the logs for more details."
185
 
186
  # Create Gradio interface
187
  iface = gr.Interface(