yoonusajwardapiit commited on
Commit
d2fde25
1 Parent(s): 64226ef

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
 
4
 
5
- # Define your custom model class with detailed layer structures
6
  class Head(nn.Module):
7
  def __init__(self, head_size):
8
  super().__init__()
@@ -103,7 +104,6 @@ def load_model():
103
  model = load_model()
104
 
105
  # Define a comprehensive character set based on training data
106
- # Convert all input to lowercase if the model is trained on lowercase data
107
  chars = sorted(list(set("abcdefghijklmnopqrstuvwxyz0123456789 .,!?-:;'\"\n")))
108
  stoi = {ch: i for i, ch in enumerate(chars)}
109
  itos = {i: ch for i, ch in enumerate(chars)}
@@ -113,6 +113,7 @@ decode = lambda l: ''.join([itos[i] for i in l])
113
  # Function to generate text using the model
114
  def generate_text(prompt):
115
  try:
 
116
  print(f"Received prompt: {prompt}")
117
  encoded_prompt = encode(prompt)
118
 
@@ -128,11 +129,12 @@ def generate_text(prompt):
128
  print(f"Encoded prompt: {context}")
129
 
130
  with torch.no_grad():
131
- generated = model.generate(context, max_new_tokens=250) # Adjust as needed
132
  print(f"Generated tensor: {generated}")
133
 
134
  result = decode(generated[0].tolist())
135
  print(f"Decoded result: {result}")
 
136
  return result
137
  except Exception as e:
138
  print(f"Error during generation: {e}")
@@ -144,7 +146,8 @@ interface = gr.Interface(
144
  inputs=gr.Textbox(lines=2, placeholder="Enter a location or prompt..."),
145
  outputs="text",
146
  title="Triptuner Model",
147
- description="Generate itineraries for locations in Sri Lanka's Central Province."
 
148
  )
149
 
150
  # Launch the interface
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ import time
5
 
6
+ # Define the custom model class with detailed layer structures
7
  class Head(nn.Module):
8
  def __init__(self, head_size):
9
  super().__init__()
 
104
  model = load_model()
105
 
106
  # Define a comprehensive character set based on training data
 
107
  chars = sorted(list(set("abcdefghijklmnopqrstuvwxyz0123456789 .,!?-:;'\"\n")))
108
  stoi = {ch: i for i, ch in enumerate(chars)}
109
  itos = {i: ch for i, ch in enumerate(chars)}
 
113
  # Function to generate text using the model
114
  def generate_text(prompt):
115
  try:
116
+ start_time = time.time()
117
  print(f"Received prompt: {prompt}")
118
  encoded_prompt = encode(prompt)
119
 
 
129
  print(f"Encoded prompt: {context}")
130
 
131
  with torch.no_grad():
132
+ generated = model.generate(context, max_new_tokens=20) # Reduced tokens to speed up
133
  print(f"Generated tensor: {generated}")
134
 
135
  result = decode(generated[0].tolist())
136
  print(f"Decoded result: {result}")
137
+ print(f"Processing time: {time.time() - start_time:.2f}s")
138
  return result
139
  except Exception as e:
140
  print(f"Error during generation: {e}")
 
146
  inputs=gr.Textbox(lines=2, placeholder="Enter a location or prompt..."),
147
  outputs="text",
148
  title="Triptuner Model",
149
+ description="Generate itineraries for locations in Sri Lanka's Central Province.",
150
+ theme="compact", # Add a theme for better UI appearance
151
  )
152
 
153
  # Launch the interface