Arnesh27 committed (verified)
Commit ec8cf8a · Parent(s): e998082

Update app.py

Files changed (1)
  1. app.py +24 -12
app.py CHANGED
@@ -2,19 +2,31 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import gradio as gr
 import torch
 
-# Load a smaller model to reduce memory usage
-model_name = "distilgpt2"  # Smaller model
+# Load a model suited for code generation
+model_name = "Salesforce/codegen-350M-mono"  # A smaller model; choose one suited for your task
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+model = AutoModelForCausalLM.from_pretrained(model_name)
 
-def generate_text(input_text):
-    # Ensure input is in the correct format
-    input_tensor = tokenizer(input_text, return_tensors="pt")  # Removed clean_up_tokenization_spaces
+# Set the device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+def generate_code(prompt):
+    # Prepare the input for the model
+    input_tensor = tokenizer(prompt, return_tensors="pt").to(device)
+
+    # Generate code based on the prompt
+    with torch.no_grad():
+        generated_ids = model.generate(
+            input_tensor["input_ids"],
+            max_length=300,  # Caps total length (prompt + generated tokens); adjust as needed
+            num_beams=5,  # Beam search width; higher values explore more candidate sequences
+            early_stopping=True
+        )
 
-    # Generate text with a limit on max_length to reduce memory usage
-    output = model.generate(**input_tensor, max_length=50)  # Adjust max_length as needed
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
-    return response
+    # Decode and return the generated code
+    generated_code = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    return generated_code
 
-iface = gr.Interface(fn=generate_text, inputs="text", outputs="text", allow_flagging="never")
-iface.launch(server_name="0.0.0.0", server_port=7860)
+iface = gr.Interface(fn=generate_code, inputs="text", outputs="text", allow_flagging="never")
+iface.launch(server_name="0.0.0.0", server_port=7860)
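A note on the new generation call: with max_length=300 the prompt tokens count toward the limit, so long prompts leave little room for generated code, and calling generate with only input_ids produces a missing-attention-mask warning. A minimal variant of the same call that sidesteps both (a sketch, not part of this commit; max_new_tokens and attention_mask are standard transformers generate arguments):

    # Sketch: same beam-search generation, but budgeting only the new tokens
    with torch.no_grad():
        generated_ids = model.generate(
            input_tensor["input_ids"],
            attention_mask=input_tensor["attention_mask"],  # avoids the missing-mask warning
            max_new_tokens=256,  # cap on generated tokens, independent of prompt length
            num_beams=5,
            early_stopping=True,
        )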
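Once the Space is up, the interface can also be smoke-tested outside the browser. A minimal sketch using gradio_client (the localhost URL and the /predict endpoint name are assumptions; /predict is the default endpoint for a single-function gr.Interface):

    from gradio_client import Client

    # Connect to the app launched on 0.0.0.0:7860 above
    client = Client("http://127.0.0.1:7860/")

    # Send a code prompt through the interface's default endpoint
    result = client.predict("def fibonacci(n):", api_name="/predict")
    print(result)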