import gradio as gr
import torch
from transformers import pipeline


# Text-generation pipeline for the fine-tuned polymer model.
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# (device=-1) so the app still starts on hosts without a GPU instead of
# failing while the pipeline is being built.
model = pipeline(
    "text-generation",
    model="rish13/polymers",
    device=0 if torch.cuda.is_available() else -1,
)


def _first_sentence_end(text, start=0):
    """Return the index of the first '.', '!' or '?' at or after *start*, or -1."""
    hits = [pos for pos in (text.find(mark, start) for mark in ".!?") if pos != -1]
    return min(hits, default=-1)


def generate_response(prompt):
    """Generate a short completion for *prompt*, trimmed to one sentence.

    Args:
        prompt: Text entered by the user.

    Returns:
        The generated text, cut immediately after the first '.', '!' or '?'
        that follows the prompt. If no such mark is produced, the full
        generated text is returned. A blank prompt yields an empty string
        without calling the model.
    """
    # Nothing to complete; avoid sending an empty input to the model.
    if not prompt or not prompt.strip():
        return ""

    response = model(
        prompt,
        # Budget for newly generated tokens only, so a long prompt does not
        # use up the whole length limit before generation starts.
        max_new_tokens=100,
        num_return_sequences=1,
        # temperature only influences the output when sampling is enabled.
        do_sample=True,
        temperature=0.7,
    )
    generated_text = response[0]["generated_text"]

    # When the output begins with the prompt, skip past it so punctuation the
    # user typed (e.g. a trailing '?') does not end the reply before any
    # generated text is included.
    search_from = len(prompt) if generated_text.startswith(prompt) else 0

    end_position = _first_sentence_end(generated_text, search_from)
    if end_position != -1:
        generated_text = generated_text[: end_position + 1]

    return generated_text


# Gradio front end: a two-line prompt box wired to generate_response, with the
# returned string shown as plain text.
interface = gr.Interface(
    title="Polymer Knowledge Model",
    description="A model fine-tuned for generating text related to polymers.",
    fn=generate_response,
    inputs=gr.Textbox(
        label="Prompt",
        placeholder="Enter your prompt here...",
        lines=2,
    ),
    outputs="text",
)

# Launch the UI as soon as the script is executed.
interface.launch()