Update app.py
app.py CHANGED
@@ -1,12 +1,37 @@
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, AutoModel
 from huggingface_hub import hf_hub_download
 import os
+import torch.nn as nn
+
+# ----- Model Definition -----
+class CustomDialoGPT(nn.Module):
+    def __init__(self, vocab_size, n_embd=768, n_head=12, n_layer=12): # <---- FORCE n_embd, n_head, n_layer to match DialoGPT-medium
+        super().__init__()
+
+        config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium",
+                                            vocab_size=vocab_size,
+                                            n_embd=n_embd,
+                                            n_head=n_head,
+                                            n_layer=n_layer,
+                                            bos_token_id=50256,
+                                            eos_token_id=50256,
+                                            pad_token_id=50256
+                                            )
+        self.transformer = AutoModelForCausalLM.from_config(config) # Use AutoModelForCausalLM here
+        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) # Keep lm_head
+
+    def forward(self, input_ids):
+        transformer_outputs = self.transformer(input_ids=input_ids, output_hidden_states=True)
+        hidden_states = transformer_outputs.hidden_states[-1] # get last hidden state
+        logits = self.lm_head(hidden_states)
+        return logits
+

 # Model and tokenizer details
 model_repo = "elapt1c/ElapticAI-1a"
-model_filename = "model.pth" #
+model_filename = "model.pth" # <--- CHECK FILENAME ON HF HUB, UPDATE IF NEEDED!
 tokenizer_name = "microsoft/DialoGPT-medium"

 # Device configuration
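Note on the class added above: CustomDialoGPT wraps a complete AutoModelForCausalLM (which already contains its own language-model head) and then applies a second, separate lm_head to the last hidden state, presumably so the module's parameter names line up with the keys stored in model.pth. A minimal sanity check of the forward pass, hypothetical and not part of app.py, assuming the class exactly as defined in this diff:

    import torch

    # 50257 is the GPT-2/DialoGPT vocabulary size
    model = CustomDialoGPT(vocab_size=50257)
    dummy_ids = torch.randint(0, 50257, (1, 8))  # batch of 1, sequence of 8 tokens
    logits = model(dummy_ids)
    print(logits.shape)  # expected: torch.Size([1, 8, 50257])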
@@ -14,11 +39,14 @@ device = "cuda" if torch.cuda.is_available() else "cpu"

 # Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+vocab_size = len(tokenizer)
+
+# Initialize model with fixed parameters to match checkpoint
+n_embd = 768 # <---- FORCE n_embd to 768
+n_head = 12 # <---- FORCE n_head to 12
+n_layer = 12 # <---- FORCE n_layer to 12
+model = CustomDialoGPT(vocab_size, n_embd, n_head, n_layer)

-# Load model configuration
-config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium")
-# Initialize model from config (important to use the same architecture)
-model = AutoModelForCausalLM.from_config(config)

 # Download and load model weights
 try:
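The body of this try block (old lines 24-37) is unchanged, so the diff collapses it. A plausible sketch of the elided download-and-load step, relying on the hf_hub_download and torch imports already at the top of app.py, and assuming model.pth is a raw state_dict saved with torch.save (a wrapped checkpoint such as {"model_state_dict": ...} would need unwrapping first):

    # hypothetical reconstruction of the collapsed lines inside the try block
    checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
    state_dict = torch.load(checkpoint_path, map_location=device)  # load onto the configured device
    model.load_state_dict(state_dict)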
@@ -38,7 +66,7 @@ try:
     print(f"Successfully loaded model weights from {model_repo}/{model_filename}")
 except Exception as e:
     print(f"Error loading model: {e}")
-    print("Please ensure the model repository and filename are correct.")
+    print("Please ensure the model repository and filename are correct and that the model architecture in app.py matches the checkpoint.")
     raise e # It's better to raise the error in a Space, so it's visible.

 model.to(device)
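A small idiomatic note on the raise e line kept above: inside an except block, a bare raise re-raises the current exception as-is and is the conventional alternative to raise e, for example:

    except Exception as e:
        print(f"Error loading model: {e}")
        raise  # bare raise keeps the original traceback visible in the Space logs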
@@ -52,8 +80,8 @@ def chat_with_model(user_input, history=[]):
     input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

     with torch.no_grad():
-        output = model.generate(
-            input_ids,
+        output = model.transformer.generate( # Use model.transformer.generate here
+            inputs=input_ids, # Use inputs instead of input_ids
             max_length=1000, # Adjust as needed
             pad_token_id=tokenizer.eos_token_id,
             temperature=0.7,
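Two caveats on this generation call, hedged because only part of chat_with_model is visible in the diff: in transformers, temperature (like top_k and top_p) only takes effect when do_sample=True is also passed, so as written generate may fall back to greedy decoding and warn that the setting is ignored; and because generation goes through the inner model.transformer, the custom lm_head is bypassed at inference time and the stock GPT-2 head produces the tokens. A hypothetical decode step that would follow the call, assuming output holds the prompt followed by the reply, as is usual for GPT-2-style models:

    response = tokenizer.decode(
        output[0][input_ids.shape[-1]:],  # drop the echoed prompt tokens
        skip_special_tokens=True,
    )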