File size: 3,121 Bytes
6df6b48
a776bf7
be02d5b
 
6df6b48
 
 
 
be02d5b
a776bf7
 
 
 
 
be02d5b
6df6b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a776bf7
6df6b48
a776bf7
 
 
6df6b48
 
 
 
 
 
 
 
 
 
 
be02d5b
6df6b48
 
 
 
 
 
 
 
 
 
 
 
 
 
be02d5b
 
6df6b48
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForTokenClassification, AutoTokenizer, pipeline
import gradio as gr

# --- Model setup (runs once at import time; downloads weights on first run) ---

# Load GPT-2 model and tokenizer used for text generation.
MODEL_NAME = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)

# Load BioBERT and wrap it in a token-classification (NER) pipeline.
# NOTE(review): "dmis-lab/biobert-base-cased-v1.1" is a base encoder with no
# fine-tuned token-classification head — AutoModelForTokenClassification will
# randomly initialize the classifier, so NER labels are untrained. Confirm
# whether a NER-fine-tuned BioBERT checkpoint was intended here.
BIOBERT_MODEL = "dmis-lab/biobert-base-cased-v1.1"
biobert_tokenizer = AutoTokenizer.from_pretrained(BIOBERT_MODEL)
biobert_model = AutoModelForTokenClassification.from_pretrained(BIOBERT_MODEL)
nlp_pipeline = pipeline("ner", model=biobert_model, tokenizer=biobert_tokenizer)

# Conversation memory: module-level list of alternating prompts and first
# responses, shared across all calls to generate_text (not per-user/session).
conversation_history = []

# Function to generate text with medical NLP support
def generate_text(prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.9, clean_output=False, stopwords="", num_responses=1, extract_entities=False):
    """Generate one or more GPT-2 continuations of *prompt*, in conversation context.

    Args:
        prompt: User query; appended to the shared conversation history.
        max_length: Total sequence length cap (prompt + generation) passed to
            ``model.generate``. NOTE: because history grows the prompt, long
            conversations can exceed this cap — consider max_new_tokens.
        temperature, top_k, top_p: Sampling hyperparameters.
        clean_output: If True, flatten newlines to spaces and strip.
        stopwords: Comma-separated substrings removed from each response.
        num_responses: Number of independent samples to generate.
        extract_entities: If True, append BioBERT NER entities to each response.

    Returns:
        All responses joined by blank lines. Side effect: appends the prompt
        and the first response to the module-level conversation_history.
    """
    global conversation_history

    # Combine conversation history with the new prompt so replies stay in context.
    full_prompt = "\n".join(conversation_history + [prompt])
    inputs = tokenizer(full_prompt, return_tensors="pt")

    # Pre-filter stopwords once: ``"".split(",")`` yields [""], and blank
    # entries would otherwise trigger pointless replace("") calls per response.
    stopword_list = [word.strip() for word in stopwords.split(",") if word.strip()]

    responses = []
    for _ in range(num_responses):
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=True,
                # GPT-2 defines no pad token; without this, generate() warns
                # and may pad with an unintended id.
                pad_token_id=tokenizer.eos_token_id,
            )
        text = tokenizer.decode(output[0], skip_special_tokens=True)

        # Optional whitespace cleanup.
        if clean_output:
            text = text.replace("\n", " ").strip()

        # Substring-level stopword removal (matches inside words too, as before).
        for word in stopword_list:
            text = text.replace(word, "")

        # Extract medical entities using BioBERT; sort for deterministic output
        # (set iteration order is otherwise unstable across runs).
        if extract_entities:
            entities = nlp_pipeline(text)
            extracted_entities = {entity["word"] for entity in entities}
            text += "\n\nExtracted Medical Entities:\n" + "\n".join(sorted(extracted_entities))

        responses.append(text)

    # Update conversation history
    conversation_history.append(prompt)
    conversation_history.append(responses[0])  # Store only the first response

    return "\n\n".join(responses)

# Gradio Interface.
# Each Slider carries an explicit value= so the UI defaults match the
# generate_text signature; without value=, Gradio defaults a slider to its
# minimum (e.g. Top-P 0.0 and Temperature 0.1, which cripple sampling).
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Medical Query"),
        gr.Slider(50, 500, value=100, step=10, label="Max Length"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0, 100, value=50, step=5, label="Top-K"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.1, label="Top-P"),
        gr.Checkbox(label="Clean Output"),
        gr.Textbox(label="Stopwords (comma-separated)"),
        gr.Slider(1, 5, value=1, step=1, label="Number of Responses"),
        gr.Checkbox(label="Extract Medical Entities")
    ],
    outputs=gr.Textbox(label="Generated Medical Text"),
    title="Medical AI Assistant",
    description="Enter a medical-related prompt and adjust parameters to generate AI-assisted text. Supports entity recognition for medical terms.",
)

# Launch the app (blocks; serves the web UI).
demo.launch()