yoonusajwardapiit committed on
Commit ea9f47a
1 Parent(s): 49b2bf5

Upload app.py

Files changed (1)
app.py +17 -8
app.py CHANGED
@@ -101,19 +101,28 @@ def load_model():
 
 model = load_model()
 
-# Define encode and decode functions
-chars = sorted(list(set("your_training_text_here"))) # Replace with the character set used in training
+# Define a comprehensive character set based on training data
+chars = sorted(list(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?-:;'\"\n")))
 stoi = {ch: i for i, ch in enumerate(chars)}
 itos = {i: ch for i, ch in enumerate(chars)}
-encode = lambda s: [stoi[c] for c in s]
+encode = lambda s: [stoi[c] for c in s if c in stoi] # Ensures only known characters are encoded
 decode = lambda l: ''.join([itos[i] for i in l])
 
 # Function to generate text using the model
 def generate_text(prompt):
-    context = torch.tensor([encode(prompt)], dtype=torch.long)
-    with torch.no_grad():
-        generated = model.generate(context, max_new_tokens=250) # Adjust as needed
-    return decode(generated[0].tolist())
+    try:
+        print(f"Received prompt: {prompt}")
+        context = torch.tensor([encode(prompt)], dtype=torch.long)
+        print(f"Encoded prompt: {context}")
+        with torch.no_grad():
+            generated = model.generate(context, max_new_tokens=250) # Adjust as needed
+        print(f"Generated tensor: {generated}")
+        result = decode(generated[0].tolist())
+        print(f"Decoded result: {result}")
+        return result
+    except Exception as e:
+        print(f"Error during generation: {e}")
+        return f"Error: {str(e)}"
 
 # Create a Gradio interface
 interface = gr.Interface(
@@ -125,4 +134,4 @@ interface = gr.Interface(
 )
 
 # Launch the interface
-interface.launch()
+interface.launch(share=True)
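A quick standalone check (not part of the commit) of the new character-level vocabulary: the "if c in stoi" guard in the updated encode drops characters outside chars instead of raising a KeyError, so prompts containing unsupported characters still encode and decode cleanly. The sketch below reuses only the lines visible in the diff and runs on its own, no model required.

# Minimal sketch of the encode/decode round trip introduced in this diff.
chars = sorted(list(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?-:;'\"\n")))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s if c in stoi]
decode = lambda l: ''.join([itos[i] for i in l])

ids = encode("Hello, world! ☃")  # the snowman is not in the vocabulary
print(ids)                        # indices of the known characters only
print(decode(ids))                # prints "Hello, world! " with the unknown character dropped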
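The gr.Interface arguments themselves fall between the two hunks and are not visible in this diff. The following is a hypothetical minimal wiring, assuming a plain text-in/text-out interface with a placeholder standing in for the model-backed generate_text; the actual fields in app.py may differ. share=True asks Gradio to open a temporary public URL in addition to the local server.

import gradio as gr

def generate_text(prompt):
    # Placeholder standing in for the model-backed function from the diff.
    return prompt.upper()

# Hypothetical wiring, assuming simple text fields; the real arguments are not shown here.
interface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Generated text"),
)

# share=True requests a temporary public link alongside the local server.
interface.launch(share=True)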