yoonusajwardapiit committed on
Commit d2fde25
1 Parent(s): 64226ef
Upload app.py
app.py CHANGED
@@ -1,8 +1,9 @@
 import gradio as gr
 import torch
 import torch.nn as nn
+import time
 
-# Define
+# Define the custom model class with detailed layer structures
 class Head(nn.Module):
     def __init__(self, head_size):
         super().__init__()
@@ -103,7 +104,6 @@ def load_model():
 model = load_model()
 
 # Define a comprehensive character set based on training data
-# Convert all input to lowercase if the model is trained on lowercase data
 chars = sorted(list(set("abcdefghijklmnopqrstuvwxyz0123456789 .,!?-:;'\"\n")))
 stoi = {ch: i for i, ch in enumerate(chars)}
 itos = {i: ch for i, ch in enumerate(chars)}
@@ -113,6 +113,7 @@ decode = lambda l: ''.join([itos[i] for i in l])
 # Function to generate text using the model
 def generate_text(prompt):
     try:
+        start_time = time.time()
         print(f"Received prompt: {prompt}")
         encoded_prompt = encode(prompt)
 
@@ -128,11 +129,12 @@ def generate_text(prompt):
         print(f"Encoded prompt: {context}")
 
         with torch.no_grad():
-            generated = model.generate(context, max_new_tokens=
+            generated = model.generate(context, max_new_tokens=20)  # Reduced tokens to speed up
         print(f"Generated tensor: {generated}")
 
         result = decode(generated[0].tolist())
         print(f"Decoded result: {result}")
+        print(f"Processing time: {time.time() - start_time:.2f}s")
         return result
     except Exception as e:
         print(f"Error during generation: {e}")
@@ -144,7 +146,8 @@ interface = gr.Interface(
     inputs=gr.Textbox(lines=2, placeholder="Enter a location or prompt..."),
     outputs="text",
     title="Triptuner Model",
-    description="Generate itineraries for locations in Sri Lanka's Central Province."
+    description="Generate itineraries for locations in Sri Lanka's Central Province.",
+    theme="compact",  # Add a theme for better UI appearance
 )
 
 # Launch the interface
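Note on the unchanged character-set block above: a minimal sketch of the round trip it implies. The decode lambda is quoted from the diff context; the encode shown here is an assumption that simply inverts it, and the sample prompt is hypothetical. Characters outside the set (for example uppercase letters) have no entry in stoi, which is what the dropped lowercasing comment was warning about.

# Sketch only: encode is assumed to mirror the decode lambda shown in the diff context.
chars = sorted(list(set("abcdefghijklmnopqrstuvwxyz0123456789 .,!?-:;'\"\n")))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]           # assumed inverse; raises KeyError on unseen characters
decode = lambda l: ''.join([itos[i] for i in l])  # as in the diff context above

ids = encode("kandy to sigiriya")                 # hypothetical lowercase prompt
assert decode(ids) == "kandy to sigiriya"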
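The functional changes in this commit are the timing instrumentation and the smaller max_new_tokens. A hedged sketch of that pattern in isolation follows; model and context stand in for the objects built in app.py and are not defined here.

import time
import torch

def timed_generate(model, context, max_new_tokens=20):
    # Mirrors the commit: time the call and cap generation length to keep latency low.
    start_time = time.time()
    with torch.no_grad():  # inference only, no gradient tracking
        generated = model.generate(context, max_new_tokens=max_new_tokens)
    print(f"Processing time: {time.time() - start_time:.2f}s")
    return generated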