File size: 8,138 Bytes
e4f5c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import os
import json
import logging
from enum import Enum
from pydantic import BaseModel, Field
import pandas as pd
from huggingface_hub import InferenceClient
from tenacity import retry, stop_after_attempt, wait_exponential

# Configure logging: module-level logger writing INFO+ to both the console
# and a local "hf_api.log" file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create handlers
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# NOTE(review): the log file is opened at import time even when the handler
# ends up not being attached below (handlers already present).
file_handler = logging.FileHandler("hf_api.log")
file_handler.setLevel(logging.INFO)

# Create formatters and add to handlers
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

# Add handlers to the logger
# Guard prevents duplicate handlers (and duplicated log lines) if this
# module-level code is executed more than once, e.g. via importlib.reload.
if not logger.handlers:
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

# Validate and retrieve the Hugging Face API token.
# Fail fast at import time: the InferenceClient instances created below are
# useless without a token, so a missing HF_TOKEN aborts the whole module.
HF_TOKEN = os.environ.get('HF_TOKEN')
if not HF_TOKEN:
    logger.error("Hugging Face API token not found. Set the HF_TOKEN environment variable.")
    raise EnvironmentError("HF_TOKEN environment variable is not set.")

# Initialize one InferenceClient per model used by this module.
def _build_client(model_name: str) -> InferenceClient:
    """Instantiate an InferenceClient for *model_name*, logging the outcome.

    Raises:
        Exception: re-raised after logging if instantiation fails.
    """
    try:
        new_client = InferenceClient(model=model_name, token=HF_TOKEN)
        logger.info(f"InferenceClient for model '{model_name}' instantiated successfully.")
    except Exception as e:
        logger.error(f"Failed to instantiate InferenceClient for model '{model_name}': {e}")
        raise
    return new_client

MODEL_NAME1 = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_NAME2 = "Qwen/Qwen2.5-72B-Instruct"

client1 = _build_client(MODEL_NAME1)
client2 = _build_client(MODEL_NAME2)

# Define Pydantic schemas
class EvaluationSchema(BaseModel):
    """Structured relevance judgment produced by the evaluation model."""
    reasoning: str  # model's free-text justification for the score
    relevance_score: int = Field(ge=0, le=10)  # 0-10 inclusive, enforced by pydantic

# Closed set of newsletter topics a paper can be filed under.
# NOTE(review): deliberately no class docstring here — TopicEnum.__doc__ is
# interpolated into a prompt elsewhere in this file, so adding one would
# change what gets sent to the model.
class TopicEnum(Enum):
    Rheumatoid_Arthritis = "Rheumatoid Arthritis"
    Systemic_Lupus_Erythematosus = "Systemic Lupus Erythematosus"
    Scleroderma = "Scleroderma"
    Sjogren_s_Disease = "Sjogren's Disease"
    Ankylosing_Spondylitis = "Ankylosing Spondylitis"
    Psoriatic_Arthritis = "Psoriatic Arthritis"
    Gout = "Gout"
    Vasculitis = "Vasculitis"
    Osteoarthritis = "Osteoarthritis"
    Infectious_Diseases = "Infectious Diseases"
    Immunology = "Immunology"
    Genetics = "Genetics"
    Biologics = "Biologics"
    Biosimilars = "Biosimilars"
    Small_Molecules = "Small Molecules"
    Clinical_Trials = "Clinical Trials"
    Health_Policy = "Health Policy"
    Patient_Education = "Patient Education"
    # Catch-all bucket, also the default in SummarySchema
    Other_Rheumatic_Diseases = "Other Rheumatic Diseases"

class SummarySchema(BaseModel):
    """One-sentence summary of an abstract plus its assigned topic."""
    summary: str  # single-sentence summary of the abstract
    # Enum for topic; defaults to the catch-all category when unspecified
    topic: TopicEnum = TopicEnum.Other_Rheumatic_Diseases

class PaperSchema(BaseModel):
    """Bibliographic record for a single paper."""
    title: str
    authors: str
    journal: str
    pmid: str  # PubMed identifier, kept as a string

class TopicSummarySchema(BaseModel):
    """Per-topic newsletter section returned by the summarization model."""
    planning: str  # scratch space for the model to organize its thoughts; not rendered
    summary: str   # the prose that is actually included in the newsletter

def evaluate_relevance(title: str, abstract: str) -> EvaluationSchema:
    """Score how relevant a paper is for an audience of rheumatologists.

    Sends the title and abstract to the Llama model (client1) with a JSON
    grammar constraining the reply to EvaluationSchema.

    Args:
        title: Paper title.
        abstract: Paper abstract text.

    Returns:
        EvaluationSchema: validated reasoning + relevance_score. (Fix: the
        previous version returned the raw parsed dict, contradicting its
        own return annotation and skipping schema validation.)

    Raises:
        Exception: re-raised after logging if the API call, JSON parsing,
            or schema validation fails.
    """
    prompt = f"""
    Title: {title}
    Abstract: {abstract}
    Instructions: Evaluate the relevance of this medical abstract for an audience of rheumatologists on a scale of 0 to 10 with 10 being reserved only for large clinical trials in rheumatology.
    Be very discerning and only give a score above 8 for papers that are highly clinically relevant to rheumatologists.
    Respond in JSON format using the following schema:
    {json.dumps(EvaluationSchema.model_json_schema())}
    """

    try:
        response = client1.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.2,
            grammar={"type": "json", "value": EvaluationSchema.model_json_schema()}
        )
        # Validate the reply so the declared return type is actually honored
        # and out-of-range scores fail loudly here rather than downstream.
        return EvaluationSchema(**json.loads(response))
    except Exception as e:
        logger.error(f"Error in evaluate_relevance: {e}")
        raise

def summarize_abstract(abstract: str) -> SummarySchema:
    """Summarize an abstract in one sentence and classify it by topic.

    Sends the abstract to the Llama model (client1) with a JSON grammar
    constraining the reply to SummarySchema.

    Args:
        abstract: Paper abstract text.

    Returns:
        SummarySchema: validated summary + topic. (Fix: the previous
        version returned the raw parsed dict despite its annotation.)

    Raises:
        Exception: re-raised after logging if the API call, JSON parsing,
            or schema validation fails.

    Bug fixed: the prompt used to interpolate ``TopicEnum.__doc__`` — which
    is not a list of topics — so the model never saw the allowed topic
    values. We now list the actual enum values.
    """
    topic_options = ", ".join(t.value for t in TopicEnum)
    prompt = f"""
    Abstract: {abstract}
    Instructions: Summarize this medical abstract in 1 sentence and select the most relevant topic from the following enum:
    {topic_options}
    Respond in JSON format using the following schema:
    {json.dumps(SummarySchema.model_json_schema())}
    """

    try:
        response = client1.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.2,
            grammar={"type": "json", "value": SummarySchema.model_json_schema()}
        )
        # Validate so malformed replies are rejected here, not downstream.
        return SummarySchema(**json.loads(response))
    except Exception as e:
        logger.error(f"Error in summarize_abstract: {e}")
        raise

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def _make_api_call(client, prompt, max_tokens=4096, temp=0.2, schema=None):
    """Run one text-generation request and decode its JSON reply.

    Retried up to 3 times with exponential backoff via tenacity.

    Args:
        client: InferenceClient to send the request through.
        prompt: Full prompt text.
        max_tokens: Cap on newly generated tokens.
        temp: Sampling temperature.
        schema: Optional JSON schema; when provided, generation is
            constrained by a JSON grammar built from it.

    Returns:
        The model reply parsed via json.loads.

    Raises:
        Exception: re-raised after logging when the request or the JSON
            parse fails (tenacity then decides whether to retry).
    """
    grammar_spec = {"type": "json", "value": schema} if schema else None
    try:
        raw_reply = client.text_generation(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temp,
            grammar=grammar_spec
        )
        return json.loads(raw_reply)
    except Exception as err:
        logger.error(f"API call failed: {err}")
        raise

def compose_newsletter(papers: pd.DataFrame) -> str:
    """Render a Markdown newsletter from a DataFrame of papers.

    Expects columns: Topic, Title, Authors, Journal, PMID, Summary. For each
    distinct topic, asks the Qwen model (client2) for a structured summary
    and appends it together with a PubMed-linked references list.

    Args:
        papers: Papers to include, one row per paper.

    Returns:
        The newsletter as a Markdown string, or "" when *papers* is empty.
        Topics whose generation fails are logged and skipped, not fatal.
    """
    if papers.empty:
        logger.info("No papers provided to compose the newsletter.")
        return ""

    content = ["# This Week in Rheumatology\n"]

    for topic in papers['Topic'].unique():
        # Fix: pre-bind `result` so the except block below can always log it.
        # Previously, a failure before the API call raised NameError inside
        # the handler, masking the original exception.
        result = None
        try:
            relevant_papers = papers[papers['Topic'] == topic]
            # Convert to dict with lowercase keys to match the expected schema
            papers_dict = relevant_papers.rename(columns={
                'Title': 'title',
                'Authors': 'authors',
                'Journal': 'journal',
                'PMID': 'pmid',
                'Summary': 'summary'
            }).to_dict('records')

            prompt = f"""
            Instructions: Generate a brief summary of the latest research on {topic} using the following papers.
            Papers: {json.dumps(papers_dict)}
            Respond in JSON format using the following schema:
            {json.dumps(TopicSummarySchema.model_json_schema())}
            You have the option of using the planning field first to organize your thoughts before writing the summary.
            The summary should be concise, but because you are summarizing several papers, it should be detailed enough to give the reader a good idea of the latest research in the field. 
            The papers may be somewhat disjointed, so you will need to think carefully about how you can transition between them with clever wording.
            You can use anywhere from 1 to 3 paragraphs for the summary.
            """

            result = _make_api_call(
                client2,
                prompt,
                max_tokens=4096,
                temp=0.2,
                schema=TopicSummarySchema.model_json_schema()
            )

            # Log the raw response for debugging
            logger.debug(f"Raw response from Hugging Face: {result}")

            # Validate the JSON response against the schema
            summary = TopicSummarySchema(**result)

            # Convert the structured summary to Markdown
            topic_content = f"## {topic}\n\n"
            topic_content += f"{summary.summary}\n\n"

            # Add a references section, reusing the already-filtered frame
            # (the original filtered the DataFrame a second time here).
            topic_content += "### References\n\n"
            for _, paper in relevant_papers.iterrows():
                topic_content += (f"- {paper['Title']} by {paper['Authors']}. {paper['Journal']}. "
                               f"[PMID: {paper['PMID']}](https://pubmed.ncbi.nlm.nih.gov/{paper['PMID']}/)\n")

            content.append(topic_content)

        except Exception as e:
            logger.error(f"Error processing topic {topic}: {e}")
            logger.error(f"Raw response: {result}")
            continue

    return "\n".join(content)