Quentin GALLOUÉDEC commited on
Commit
04e15ae
·
1 Parent(s): 3c2fcf4
Files changed (1) hide show
  1. app.py +37 -4
app.py CHANGED
@@ -1,7 +1,40 @@
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, pipeline
3
 
4
+ # Define the list of model names
5
+ models = ["gia-project/gia2-small-untrained", "gpt2"] # Add more model names as needed
6
 
7
+ # Dictionary to store loaded models and their pipelines
8
+ model_pipelines = {}
9
+
10
+ # Load a default model initially
11
+ default_model_name = "gia-project/gia2-small-untrained"
12
+ default_model = AutoModelForCausalLM.from_pretrained(default_model_name, trust_remote_code=True)
13
+ default_generator = pipeline("text-generation", model=default_model, tokenizer="gpt2", trust_remote_code=True)
14
+ model_pipelines[default_model_name] = default_generator
15
+
16
+ def generate_text(model_name, input_text):
17
+ # Check if the selected model is already loaded
18
+ if model_name not in model_pipelines:
19
+ # Load the model and create a pipeline if it's not already loaded
20
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
21
+ generator = pipeline("text-generation", model=model, tokenizer="gpt2", trust_remote_code=True)
22
+ model_pipelines[model_name] = generator
23
+
24
+ # Get the pipeline for the selected model and generate text
25
+ generator = model_pipelines[model_name]
26
+ generated_text = generator(input_text)[0]['generated_text']
27
+ return generated_text
28
+
29
+ # Define the Gradio interface
30
+ iface = gr.Interface(
31
+ fn=generate_text, # Function to be called on user input
32
+ inputs=[
33
+ gr.inputs.Dropdown(choices=models, label="Select Model"), # Dropdown to select model
34
+ gr.inputs.Textbox(lines=5, label="Input Text") # Textbox for entering text
35
+ ],
36
+ outputs=gr.outputs.Textbox(label="Generated Text"), # Textbox to display the generated text
37
+ )
38
+
39
+ # Launch the Gradio interface
40
+ iface.launch()