-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprompting.py
More file actions
109 lines (100 loc) · 4.23 KB
/
prompting.py
File metadata and controls
109 lines (100 loc) · 4.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import json
import time
import os
from jinja2 import Environment, FileSystemLoader
def enumerate_filter(iterable, start=0):
    """Jinja2 filter exposing the builtin ``enumerate``.

    Generalized with a backward-compatible ``start`` offset so templates
    can number items from any base (e.g. 1-based lists), matching the
    builtin's signature.
    """
    return enumerate(iterable, start)
class OpenAImodel:
    """Wrapper around an OpenAI chat-completions client with a JSONL prompt cache.

    Responses are cached in memory keyed by prompt; newly generated responses
    accumulate in ``new_cache`` until ``save_cache()`` appends them to the
    cache file on disk.
    """

    def __init__(self, client, model_name="gpt-4o-mini", cache_dir=None):
        """
        Args:
            client: an OpenAI-style client exposing ``chat.completions.create``.
            model_name: model identifier passed to the API.
            cache_dir: path of the JSONL cache file; defaults to
                ``<model_name>_cache.jsonl``.
        """
        self.client = client
        self.model_name = model_name
        if cache_dir is None:
            cache_dir = model_name + "_cache.jsonl"
        # BUG FIX: cache_dir was never stored on the instance, so
        # save_cache() raised AttributeError unless a subclass happened to
        # assign self.cache_dir before calling super().__init__().
        self.cache_dir = cache_dir
        self.cache = self.load_cache(cache_dir)
        # Responses generated this session, not yet flushed to disk.
        self.new_cache = {}

    def load_cache(self, cache_dir):
        """Load prompt->response pairs from a JSONL cache file.

        Returns an empty dict when the file does not exist or no path is given.
        """
        if cache_dir and os.path.exists(cache_dir):
            cache = {}
            with open(cache_dir, "r", encoding="utf-8") as f:
                for line in f:
                    data = json.loads(line)
                    cache[data["prompt"]] = data["response"]
            return cache
        else:
            return {}

    def save_cache(self):
        """Append the session's new responses to the cache file and clear them."""
        with open(self.cache_dir, "a", encoding="utf-8") as f:
            for prompt, response in self.new_cache.items():
                data = {"prompt": prompt, "response": response}
                f.write(json.dumps(data) + "\n")
        self.new_cache = {}

    def generate_text(self, prompt,
                      system_prompt="You are a helpful assistant.",
                      max_tokens=512,
                      temperature=0,
                      stop=None,
                      cool_down=0,
                      regenerate=False):
        """Return the model's completion for ``prompt``, using the cache when possible.

        Args:
            prompt: user message sent to the model.
            system_prompt: system message for the chat request.
            max_tokens: completion length limit.
            temperature: sampling temperature (0 = deterministic).
            stop: optional stop sequence(s) forwarded to the API.
            cool_down: seconds to sleep after an API call (rate limiting).
            regenerate: if True, bypass the cache and call the API again.

        Returns:
            The stripped completion text.
        """
        # Cache hits return immediately — no API call, no cool-down sleep.
        if prompt in self.cache and not regenerate:
            return self.cache[prompt]

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            n=1,
            stop=stop)
        result = response.choices[0].message.content.strip()
        # Record in both the lookup cache and the pending-write cache;
        # save_cache() must be called explicitly to persist.
        self.cache[prompt] = result
        self.new_cache[prompt] = result
        # Throttle successive API calls.
        time.sleep(cool_down)
        return result
class OpenAIHotpotQA(OpenAImodel):
    """OpenAImodel specialized for HotpotQA, prompting via a Jinja2 template."""

    def __init__(self, api_key,
                 model_name="gpt-4o-mini",
                 cache_dir=None,
                 template_root="templates",
                 template_name="HotpotQA_CoT.jinja2"
                 ):
        # Default cache file is namespaced per model and task.
        self.cache_dir = (model_name + "_HotpotQA_cache.jsonl"
                          if cache_dir is None else cache_dir)
        super().__init__(api_key, model_name, self.cache_dir)
        # Template environment with the custom `enumerate` filter available.
        template_env = Environment(loader=FileSystemLoader(template_root))
        template_env.filters['enumerate'] = enumerate_filter
        self.template = template_env.get_template(template_name)

    def predict(self, question, context, regenerate=False, cool_down=0, max_tokens=512):
        """Render the HotpotQA prompt for (question, context) and return the completion."""
        rendered = self.template.render(question=question, context=context)
        return self.generate_text(rendered,
                                  regenerate=regenerate,
                                  cool_down=cool_down,
                                  max_tokens=max_tokens)
class OpenAIWoW(OpenAImodel):
    """OpenAImodel specialized for Wizard-of-Wikipedia, prompting via a Jinja2 template."""

    def __init__(self, api_key,
                 model_name="gpt-4o-mini",
                 cache_dir=None,
                 template_root="templates",
                 template_name="WoW.jinja2"
                 ):
        # Default cache file is namespaced per model and task.
        self.cache_dir = (model_name + "_WoW_cache.jsonl"
                          if cache_dir is None else cache_dir)
        super().__init__(api_key, model_name, self.cache_dir)
        # Template environment with the custom `enumerate` filter available.
        template_env = Environment(loader=FileSystemLoader(template_root))
        template_env.filters['enumerate'] = enumerate_filter
        self.template = template_env.get_template(template_name)

    def predict(self, persona, history, context, regenerate=False, cool_down=0, max_tokens=512):
        """Render the WoW prompt for (persona, history, context) and return the completion."""
        rendered = self.template.render(persona=persona, history=history, context=context)
        return self.generate_text(rendered,
                                  regenerate=regenerate,
                                  cool_down=cool_down,
                                  max_tokens=max_tokens)