Omarrran committed · verified
Commit ccca270 · Parent(s): 7686669

Update app.py

Files changed (1): app.py (+91, -50)
app.py CHANGED
@@ -1,13 +1,17 @@
 import gradio as gr
 import torch
 import json
-from transformers import GPT2Config
 from torch import nn
 import requests
 from pathlib import Path
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 class TextGenerator(nn.Module):
-    def __init__(self, vocab_size, embedding_dim, hidden_dim):
+    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=512):
         super().__init__()
         self.embedding = nn.Embedding(vocab_size, embedding_dim)
         self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
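The constructor now ships usable defaults (embedding_dim=256, hidden_dim=512), matching the default_config introduced further down. The hunk cuts the class off after the LSTM layer; for context, here is a minimal self-contained sketch of what the full module plausibly looks like after this change. The `fc` layer and the `forward` signature are assumptions inferred from the `return self.fc(lstm_out)` line visible in the next hunk, not shown in the diff.

import torch
from torch import nn

class TextGenerator(nn.Module):
    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        # Assumed: project LSTM hidden states back to vocabulary logits,
        # consistent with the `return self.fc(lstm_out)` line in the diff.
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)       # (batch, seq) -> (batch, seq, embedding_dim)
        lstm_out, _ = self.lstm(embedded)  # (batch, seq, hidden_dim)
        return self.fc(lstm_out)           # (batch, seq, vocab_size)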
@@ -19,19 +23,29 @@ class TextGenerator(nn.Module):
         return self.fc(lstm_out)
 
 def download_file(url, local_path):
-    response = requests.get(url)
-    if response.status_code == 200:
+    try:
+        response = requests.get(url)
+        response.raise_for_status()  # Raise an exception for bad status codes
         Path(local_path).parent.mkdir(parents=True, exist_ok=True)
         with open(local_path, 'wb') as f:
             f.write(response.content)
-    else:
-        raise Exception(f"Failed to download {url}")
+        logger.info(f"Successfully downloaded {url} to {local_path}")
+    except Exception as e:
+        logger.error(f"Error downloading {url}: {str(e)}")
+        raise
 
 def load_model_and_tokenizers():
     # Create a local directory for downloaded files
     cache_dir = Path("model_cache")
     cache_dir.mkdir(exist_ok=True)
 
+    # Default configuration values
+    default_config = {
+        'vocab_size': 10000,    # Default vocabulary size
+        'embedding_dim': 256,   # Default embedding dimension
+        'hidden_dim': 512       # Default hidden dimension
+    }
+
     # URLs for the files
     base_url = "https://huggingface.co/Omarrran/temp_data/raw/main"
     files = {
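The try/except plus `response.raise_for_status()` replaces the old manual status-code check; `raise_for_status()` raises `requests.exceptions.HTTPError` on any 4xx/5xx response. The new version still buffers the whole response in memory before writing, which matters for a large weights file like model.pt; a streamed variant with a timeout could look like the sketch below (not part of the commit; `download_file_streamed` is a hypothetical name):

import requests
from pathlib import Path

def download_file_streamed(url, local_path, timeout=30):
    # Hypothetical variant: stream the body to disk in chunks instead of
    # holding it all in memory, and bound how long the request may hang.
    with requests.get(url, timeout=timeout, stream=True) as response:
        response.raise_for_status()  # HTTPError on 4xx/5xx
        Path(local_path).parent.mkdir(parents=True, exist_ok=True)
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)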
@@ -45,56 +59,83 @@ def load_model_and_tokenizers():
     for filename, url in files.items():
         local_path = cache_dir / filename
         if not local_path.exists():
-            print(f"Downloading {filename}...")
+            logger.info(f"Downloading {filename}...")
             download_file(url, local_path)
 
-    # Load configuration
-    with open(cache_dir / "model_config.json", "r") as f:
-        config = json.load(f)
+    try:
+        # Load configuration
+        with open(cache_dir / "model_config.json", "r") as f:
+            config = json.load(f)
+        # Merge with default config
+        for key in default_config:
+            if key not in config:
+                logger.warning(f"Configuration parameter '{key}' not found, using default value: {default_config[key]}")
+                config[key] = default_config[key]
+    except Exception as e:
+        logger.warning(f"Error loading config file: {str(e)}. Using default configuration.")
+        config = default_config
 
-    # Load tokenizers
-    with open(cache_dir / "word_to_int.json", "r") as f:
-        word_to_int = json.load(f)
-    with open(cache_dir / "int_to_word.json", "r") as f:
-        int_to_word = json.load(f)
-
-    # Initialize model
-    model = TextGenerator(
-        vocab_size=config['vocab_size'],
-        embedding_dim=config['embedding_dim'],
-        hidden_dim=config['hidden_dim']
-    )
-
-    # Load model weights
-    model.load_state_dict(torch.load(cache_dir / "model.pt", map_location=torch.device('cpu')))
-    model.eval()
+    try:
+        # Load tokenizers
+        with open(cache_dir / "word_to_int.json", "r") as f:
+            word_to_int = json.load(f)
+        with open(cache_dir / "int_to_word.json", "r") as f:
+            int_to_word = json.load(f)
+
+        # Update vocab size based on actual vocabulary
+        config['vocab_size'] = len(word_to_int)
+
+    except Exception as e:
+        logger.error(f"Error loading tokenizer files: {str(e)}")
+        raise
 
-    return model, word_to_int, int_to_word
+    try:
+        # Initialize model
+        model = TextGenerator(
+            vocab_size=config['vocab_size'],
+            embedding_dim=config['embedding_dim'],
+            hidden_dim=config['hidden_dim']
+        )
+
+        # Load model weights
+        model.load_state_dict(torch.load(cache_dir / "model.pt", map_location=torch.device('cpu')))
+        model.eval()
+
+        return model, word_to_int, int_to_word
+
+    except Exception as e:
+        logger.error(f"Error loading model: {str(e)}")
+        raise
 
 def generate_text(prompt, max_length=100):
-    # Load model and tokenizers (will use cached files after first load)
-    model, word_to_int, int_to_word = load_model_and_tokenizers()
-
-    # Tokenize input prompt
-    input_ids = [word_to_int.get(word, word_to_int['<UNK>']) for word in prompt.split()]
-    input_tensor = torch.tensor([input_ids])
-
-    # Generate text
-    generated_ids = input_ids.copy()
-
-    with torch.no_grad():
-        for _ in range(max_length):
-            current_input = torch.tensor([generated_ids[-50:]])  # Use last 50 tokens as context
-            outputs = model(current_input)
-            next_token_id = outputs[0, -1, :].argmax().item()
-            generated_ids.append(next_token_id)
-
-            if next_token_id == word_to_int.get('<EOS>', 0):
-                break
-
-    # Convert ids back to text
-    generated_text = ' '.join([int_to_word.get(str(idx), '<UNK>') for idx in generated_ids])
-    return generated_text
+    try:
+        # Load model and tokenizers
+        model, word_to_int, int_to_word = load_model_and_tokenizers()
+
+        # Tokenize input prompt
+        input_ids = [word_to_int.get(word, word_to_int.get('<UNK>', 0)) for word in prompt.split()]
+        input_tensor = torch.tensor([input_ids])
+
+        # Generate text
+        generated_ids = input_ids.copy()
+
+        with torch.no_grad():
+            for _ in range(max_length):
+                current_input = torch.tensor([generated_ids[-50:]])  # Use last 50 tokens as context
+                outputs = model(current_input)
+                next_token_id = outputs[0, -1, :].argmax().item()
+                generated_ids.append(next_token_id)
+
+                if next_token_id == word_to_int.get('<EOS>', 0):
+                    break
+
+        # Convert ids back to text
+        generated_text = ' '.join([int_to_word.get(str(idx), '<UNK>') for idx in generated_ids])
+        return generated_text
+
+    except Exception as e:
+        logger.error(f"Error in text generation: {str(e)}")
+        return f"Error generating text: {str(e)}"
 
 # Create Gradio interface
 iface = gr.Interface(
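Two notes on the rewritten generate_text. First, `input_tensor` is built but never used; the loop re-tensorizes the last 50 entries of `generated_ids` on every step instead. Second, decoding is greedy: `outputs[0, -1, :].argmax()` always picks the single most likely next token, which is deterministic and prone to repetitive loops. A common alternative is temperature sampling; below is a sketch of a drop-in replacement for the argmax line (not part of this commit; `sample_next_token` is a hypothetical helper):

import torch

def sample_next_token(logits, temperature=0.8):
    # Scale the logits and sample from the softmax distribution instead of
    # taking the argmax; higher temperature means more randomness.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# Usage inside the generation loop, replacing the argmax line:
#   next_token_id = sample_next_token(outputs[0, -1, :])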
 
1
  import gradio as gr
2
  import torch
3
  import json
 
4
  from torch import nn
5
  import requests
6
  from pathlib import Path
7
+ import logging
8
+
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
 
13
  class TextGenerator(nn.Module):
14
+ def __init__(self, vocab_size, embedding_dim=256, hidden_dim=512):
15
  super().__init__()
16
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
17
  self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
 
23
  return self.fc(lstm_out)
24
 
25
  def download_file(url, local_path):
26
+ try:
27
+ response = requests.get(url)
28
+ response.raise_for_status() # Raise an exception for bad status codes
29
  Path(local_path).parent.mkdir(parents=True, exist_ok=True)
30
  with open(local_path, 'wb') as f:
31
  f.write(response.content)
32
+ logger.info(f"Successfully downloaded {url} to {local_path}")
33
+ except Exception as e:
34
+ logger.error(f"Error downloading {url}: {str(e)}")
35
+ raise
36
 
37
  def load_model_and_tokenizers():
38
  # Create a local directory for downloaded files
39
  cache_dir = Path("model_cache")
40
  cache_dir.mkdir(exist_ok=True)
41
 
42
+ # Default configuration values
43
+ default_config = {
44
+ 'vocab_size': 10000, # Default vocabulary size
45
+ 'embedding_dim': 256, # Default embedding dimension
46
+ 'hidden_dim': 512 # Default hidden dimension
47
+ }
48
+
49
  # URLs for the files
50
  base_url = "https://huggingface.co/Omarrran/temp_data/raw/main"
51
  files = {
 
59
  for filename, url in files.items():
60
  local_path = cache_dir / filename
61
  if not local_path.exists():
62
+ logger.info(f"Downloading {filename}...")
63
  download_file(url, local_path)
64
 
65
+ try:
66
+ # Load configuration
67
+ with open(cache_dir / "model_config.json", "r") as f:
68
+ config = json.load(f)
69
+ # Merge with default config
70
+ for key in default_config:
71
+ if key not in config:
72
+ logger.warning(f"Configuration parameter '{key}' not found, using default value: {default_config[key]}")
73
+ config[key] = default_config[key]
74
+ except Exception as e:
75
+ logger.warning(f"Error loading config file: {str(e)}. Using default configuration.")
76
+ config = default_config
77
 
78
+ try:
79
+ # Load tokenizers
80
+ with open(cache_dir / "word_to_int.json", "r") as f:
81
+ word_to_int = json.load(f)
82
+ with open(cache_dir / "int_to_word.json", "r") as f:
83
+ int_to_word = json.load(f)
84
+
85
+ # Update vocab size based on actual vocabulary
86
+ config['vocab_size'] = len(word_to_int)
87
+
88
+ except Exception as e:
89
+ logger.error(f"Error loading tokenizer files: {str(e)}")
90
+ raise
 
 
 
91
 
92
+ try:
93
+ # Initialize model
94
+ model = TextGenerator(
95
+ vocab_size=config['vocab_size'],
96
+ embedding_dim=config['embedding_dim'],
97
+ hidden_dim=config['hidden_dim']
98
+ )
99
+
100
+ # Load model weights
101
+ model.load_state_dict(torch.load(cache_dir / "model.pt", map_location=torch.device('cpu')))
102
+ model.eval()
103
+
104
+ return model, word_to_int, int_to_word
105
+
106
+ except Exception as e:
107
+ logger.error(f"Error loading model: {str(e)}")
108
+ raise
109
 
110
  def generate_text(prompt, max_length=100):
111
+ try:
112
+ # Load model and tokenizers
113
+ model, word_to_int, int_to_word = load_model_and_tokenizers()
114
+
115
+ # Tokenize input prompt
116
+ input_ids = [word_to_int.get(word, word_to_int.get('<UNK>', 0)) for word in prompt.split()]
117
+ input_tensor = torch.tensor([input_ids])
118
+
119
+ # Generate text
120
+ generated_ids = input_ids.copy()
121
+
122
+ with torch.no_grad():
123
+ for _ in range(max_length):
124
+ current_input = torch.tensor([generated_ids[-50:]]) # Use last 50 tokens as context
125
+ outputs = model(current_input)
126
+ next_token_id = outputs[0, -1, :].argmax().item()
127
+ generated_ids.append(next_token_id)
128
+
129
+ if next_token_id == word_to_int.get('<EOS>', 0):
130
+ break
131
+
132
+ # Convert ids back to text
133
+ generated_text = ' '.join([int_to_word.get(str(idx), '<UNK>') for idx in generated_ids])
134
+ return generated_text
135
+
136
+ except Exception as e:
137
+ logger.error(f"Error in text generation: {str(e)}")
138
+ return f"Error generating text: {str(e)}"
139
 
140
  # Create Gradio interface
141
  iface = gr.Interface(
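The page truncates the diff at the opening of the `gr.Interface(` call. A minimal completion consistent with the functions above might read as follows; the component choices, labels, title, and the `launch()` call are all assumptions, since the commit page does not show them:

import gradio as gr

# Assumed completion: wire generate_text into a simple text-in/text-out UI.
iface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Generated text"),
    title="LSTM Text Generator",
)

if __name__ == "__main__":
    iface.launch()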