remiai3 commited on
Commit
ca7dbcc
·
verified ·
1 Parent(s): a669ee8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -38
app.py CHANGED
@@ -1,39 +1,39 @@
1
- from flask import Flask, render_template, request
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
-
5
- app = Flask(__name__)
6
-
7
- # Load fine-tuned model and tokenizer
8
- model_path = "./finetuned_codegen"
9
- tokenizer = AutoTokenizer.from_pretrained(model_path)
10
- model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
11
-
12
- # Set padding token
13
- tokenizer.pad_token = tokenizer.eos_token
14
-
15
- # Move model to CPU
16
- device = torch.device("cpu")
17
- model.to(device)
18
-
19
- @app.route("/", methods=["GET", "POST"])
20
- def index():
21
- generated_code = ""
22
- prompt = ""
23
- if request.method == "POST":
24
- prompt = request.form["prompt"]
25
- inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
26
- outputs = model.generate(
27
- **inputs,
28
- max_length=200,
29
- num_return_sequences=1,
30
- pad_token_id=tokenizer.eos_token_id,
31
- do_sample=True,
32
- temperature=0.7,
33
- top_p=0.9
34
- )
35
- generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
36
- return render_template("index.html", generated_code=generated_code, prompt=prompt)
37
-
38
- if __name__ == "__main__":
39
  app.run(debug=True)
 
1
+ from flask import Flask, render_template, request
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ app = Flask(__name__)
6
+
7
+ # Load fine-tuned model and tokenizer
8
+ model_path = "./finetuned_codegen"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
10
+ model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
11
+
12
+ # Set padding token
13
+ tokenizer.pad_token = tokenizer.eos_token
14
+
15
+ # Move model to CPU
16
+ device = torch.device("cpu")
17
+ model.to(device)
18
+
19
+ @app.route("/", methods=["GET", "POST"])
20
+ def index():
21
+ generated_code = ""
22
+ prompt = ""
23
+ if request.method == "POST":
24
+ prompt = request.form["prompt"]
25
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
26
+ outputs = model.generate(
27
+ **inputs,
28
+ max_length=200,
29
+ num_return_sequences=1,
30
+ pad_token_id=tokenizer.eos_token_id,
31
+ do_sample=True,
32
+ temperature=0.7,
33
+ top_p=0.9
34
+ )
35
+ generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
36
+ return render_template("index.html", generated_code=generated_code, prompt=prompt)
37
+
38
+ if __name__ == "__main__":
39
  app.run(debug=True)