cheberle committed on
Commit
ec46241
·
1 Parent(s): 7b18dcf
Files changed (2) hide show
  1. app.py +147 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    AutoTokenizer,
    GenerationConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
)
9
+
10
# ------------------------------------------------------------------------------
# CONFIGURE MODEL & PIPELINE
# ------------------------------------------------------------------------------
BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
FINETUNED_ADAPTER = "cheberle/autotrain-llama-milch"

# Generation hyperparameters: defaults for generate_text() and the UI sliders.
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.9

# Load the tokenizer from the base model.
# AutoTokenizer instantiates whichever tokenizer class the checkpoint declares.
# NOTE(review): the slow `LlamaTokenizer` needs a SentencePiece
# `tokenizer.model` file; this checkpoint appears to ship only a fast
# `tokenizer.json`, so the slow class cannot load it -- confirm against the
# model repository.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Load the base model.
base_model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",          # let accelerate place weights (GPU if available)
    torch_dtype=torch.float16,  # half precision to save memory
)

# Load the PEFT (LoRA) adapter on top of the base model.
model = PeftModel.from_pretrained(
    base_model,
    FINETUNED_ADAPTER,
    torch_dtype=torch.float16,
)

model.eval()  # inference only: put the model in eval mode
42
+
43
+ # ------------------------------------------------------------------------------
44
+ # GENERATION FUNCTION
45
+ # ------------------------------------------------------------------------------
46
def generate_text(prompt,
                  max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                  temperature=DEFAULT_TEMPERATURE,
                  top_k=DEFAULT_TOP_K,
                  top_p=DEFAULT_TOP_P):
    """Generate a continuation of *prompt* with the finetuned model.

    Args:
        prompt: Text to continue. A missing or blank prompt yields "".
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding instead of sampling.
        top_k: Top-k sampling cutoff (only used when sampling).
        top_p: Nucleus sampling cutoff (only used when sampling).

    Returns:
        The newly generated text only (prompt removed), stripped of
        leading/trailing whitespace.
    """
    if not prompt or not prompt.strip():
        # Nothing to condition on; skip the model call entirely.
        return ""

    # Tokenize the prompt and move the tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # A temperature of 0 is not a valid sampling temperature, so treat it as
    # a request for deterministic (greedy) decoding rather than letting
    # generate() fail.
    use_sampling = temperature is not None and float(temperature) > 0.0

    # UI sliders may deliver numbers as floats; the token-count style
    # parameters must be integers.
    config_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": use_sampling,
        "repetition_penalty": 1.1,  # adjust if needed
    }
    if use_sampling:
        config_kwargs.update(
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )
    generation_config = GenerationConfig(**config_kwargs)

    # Generate without tracking gradients.
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            generation_config=generation_config,
        )

    # The output sequence starts with the prompt tokens. Drop them by position
    # instead of string-matching the decoded text, which is not guaranteed to
    # reproduce the prompt character-for-character.
    new_tokens = output_tokens[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
80
+
81
+ # ------------------------------------------------------------------------------
82
+ # GRADIO APP
83
+ # ------------------------------------------------------------------------------
84
def clear_inputs():
    """Return a pair of empty strings, used to reset the prompt and output boxes."""
    return "", ""
86
+
87
# Build the UI: prompt and generation settings on the left, model output on
# the right.
with gr.Blocks(css=".gradio-container {max-width: 800px; margin: auto;}") as demo:
    gr.Markdown("## DeepSeek R1 Distill-Llama 8B + LoRA from `cheberle/autotrain-llama-milch`")
    gr.Markdown(
        "This app uses a base **DeepSeek R1 Distill-Llama 8B** model with "
        "the **LoRA/PEFT adapter** from [`cheberle/autotrain-llama-milch`].\n\n"
        "Type in a prompt, adjust generation parameters if you wish, and click 'Generate'."
    )

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Ask me anything...",
                lines=5,
            )
            with gr.Accordion("Advanced Generation Settings", open=False):
                max_new_tokens = gr.Slider(
                    16, 1024,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    step=1,
                    label="Max New Tokens",
                )
                temperature = gr.Slider(
                    0.0, 2.0,
                    value=DEFAULT_TEMPERATURE,
                    step=0.1,
                    label="Temperature",
                )
                top_k = gr.Slider(
                    0, 100,
                    value=DEFAULT_TOP_K,
                    step=1,
                    label="Top-k",
                )
                top_p = gr.Slider(
                    0.0, 1.0,
                    value=DEFAULT_TOP_P,
                    step=0.05,
                    label="Top-p",
                )

            generate_btn = gr.Button("Generate", variant="primary")
            clear_btn = gr.Button("Clear")

        with gr.Column():
            output = gr.Textbox(
                label="Model Output",
                lines=12,
            )

    # Button actions. The input order must match generate_text's parameters.
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt, max_new_tokens, temperature, top_k, top_p],
        outputs=output,
    )

    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[prompt, output])

# Process one request at a time. Newer Gradio releases take
# `default_concurrency_limit`; older ones only know `concurrency_count`, so
# fall back to the legacy keyword when the new one is rejected.
try:
    demo.queue(default_concurrency_limit=1)
except TypeError:
    demo.queue(concurrency_count=1)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ huggingface_hub==0.25.2
2
+ transformers>=4.30.0
3
+ accelerate
4
+ sentencepiece
5
+ peft
6
+ gradio