kuyesu22 committed on
Commit
cbb3a38
·
verified ·
1 Parent(s): df02bc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -27
app.py CHANGED
@@ -1,51 +1,73 @@
1
  import torch
2
  from peft import PeftModel, PeftConfig
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
-
5
  from huggingface_hub import login
6
  import os
 
7
 
8
- access_token = os.environ["HUGGING_FACE_HUB_TOKEN"]
 
9
  login(token=access_token)
10
 
11
-
12
- peft_model_id = f"kuyesu22/sunbird-ug-lang-v1.0-bloom-7b1-lora"
13
  config = PeftConfig.from_pretrained(peft_model_id)
 
 
14
  model = AutoModelForCausalLM.from_pretrained(
15
  config.base_model_name_or_path,
16
- return_dict=True,
17
- device_map="auto",
18
- offload_folder="offload/"
19
  )
20
  tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
21
 
22
- # Load the Lora model
23
- model = PeftModel.from_pretrained(model, peft_model_id, offload_folder = "offload/")
24
 
 
 
25
 
26
- def make_inference(english):
 
 
27
  batch = tokenizer(
28
- f"### English:\n{english}: \n\n### Runyankole:",
29
  return_tensors="pt",
30
- )
 
 
31
 
32
- with torch.cuda.amp.autocast():
33
- output_tokens = model.generate(**batch, max_new_tokens=100)
 
 
 
 
 
 
 
 
 
 
34
 
35
- return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
 
 
36
 
 
 
 
 
37
 
38
- if __name__ == "__main__":
39
- # make a gradio interface
40
- import gradio as gr
41
- outputs=gr.components.Textbox(label="Runyakole")
42
- inputs=gr.components.Textbox(lines=2, label="English")
43
  gr.Interface(
44
- make_inference,
45
- [
46
- inputs,
47
- ],
48
- outputs,
49
- title="Sunbird lang Ug",
50
- description="English to Runyankole.",
51
  ).launch()
 
 
 
 
 
1
  import torch
2
  from peft import PeftModel, PeftConfig
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
  from huggingface_hub import login
5
  import os
6
+ import gradio as gr
7
 
8
# Log in to the Hugging Face Hub only when a token is configured.
# `os.environ.get` returns None if the variable is unset, and `login(token=None)`
# falls back to an interactive prompt, which cannot be answered on a headless host.
access_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if access_token:
    login(token=access_token)

# Adapter repository; its PEFT config records which base model to load.
peft_model_id = "kuyesu22/sunbird-ug-lang-v1.0-bloom-7b1-lora"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model and its tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    torch_dtype=torch.float16,  # load the weights in half precision
    device_map="auto",  # automatically allocate to available devices
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Attach the LoRA fine-tuned adapter to the base model.
model = PeftModel.from_pretrained(model, peft_model_id)

# Put the model in evaluation mode for inference.
model.eval()
29
 
30
def make_inference(english_text: str) -> str:
    """Translate English text to Runyankole with the LoRA-adapted model.

    Args:
        english_text: The English text to translate.

    Returns:
        The generated text with the prompt removed and surrounding whitespace
        stripped, or an empty string when ``english_text`` is empty or
        contains only whitespace.
    """
    # Nothing to translate: skip tokenization and the expensive generate call.
    if not english_text or not english_text.strip():
        return ""

    # Build the prompt and move the tensors to the same device as the model.
    batch = tokenizer(
        f"### English:\n{english_text}\n\n### Runyankole:",
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)

    # `torch.autocast(device_type="cuda")` replaces the deprecated
    # `torch.cuda.amp.autocast()`; gradients are not needed for inference.
    with torch.no_grad(), torch.autocast(device_type="cuda"):
        output_tokens = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=100,
            do_sample=True,  # sample instead of greedy decoding
            temperature=0.7,  # control randomness in predictions
            num_return_sequences=1,  # return only one translation
            pad_token_id=tokenizer.eos_token_id,  # pad with the EOS token id
        )

    # A causal LM returns the prompt followed by the continuation; drop the
    # prompt tokens so only the translation is shown to the user.
    prompt_length = batch["input_ids"].shape[1]
    generated_tokens = output_tokens[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
56
 
57
def launch_gradio_interface():
    """Build the English-to-Runyankole Gradio interface and launch it."""
    demo = gr.Interface(
        fn=make_inference,
        # Two-line text box for the English input.
        inputs=gr.components.Textbox(lines=2, label="English Text"),
        # Text box that shows the value returned by make_inference.
        outputs=gr.components.Textbox(label="Translated Runyankole Text"),
        title="Sunbird Lang Translator",
        description="Translate English to Runyankole using BLOOM model fine-tuned with LoRA.",
    )
    demo.launch()


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    launch_gradio_interface()