Commit 5c21a89c authored by Sharat Patil

update sampling

parent 7d7ab4b3
Merge requests: !4 Saturn adam, !3 Saturn, !1 Init
-n: 5
+n: 3
 model: "api_lora"
 stream: False
 max_tokens: 2048
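The n value here is the per-prompt sample count that the serving script reads back through load_config (see the hunk further down). The helper itself is not shown in this commit; a minimal sketch, assuming it is a thin PyYAML wrapper:

import yaml

def load_config(path):
    # Parse the sampling-parameter YAML into a plain dict
    # (keys used later: n, temperature, max_tokens, top_k, seed, stop).
    with open(path, "r") as f:
        return yaml.safe_load(f)

sampling_config = load_config("api_writer/sampling_params_2.yaml")
print(sampling_config["n"], sampling_config["max_tokens"])  # e.g. 3, 2048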
@@ -186,7 +186,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         for i in range(n):
             with gr.Column():
-                output_dataframe.append(gr.Dataframe(interactive=False, max_height=500, wrap=True,column_widths=["10%", "30%", "20%", "20%", "20%" ], datatype=['html','str','str','str','str']))
+                output_dataframe.append(gr.Dataframe(interactive=False, max_height=500, wrap=True,column_widths=["5%", "35%", "20%", "20%", "20%" ], datatype=['html','str','str','str','str']))
     with gr.Row():
         end_button = gr.Button("Submit")
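For context on this layout: the loop builds one gr.Dataframe per sample, and a single callback can fill all of them by returning one frame per widget. A minimal, self-contained sketch of that wiring; the column data and the handler name are illustrative, not taken from this commit:

import gradio as gr
import pandas as pd

n = 3  # matches the updated sampling config

def fill_tables(prompt):
    # Return one DataFrame per output widget; Gradio assigns them in order
    # to the components listed in `outputs=`.
    frame = pd.DataFrame(
        {"rank": ["<b>1</b>"], "response": [prompt], "a": [""], "b": [""], "c": [""]}
    )
    return [frame] * n

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    output_dataframe = []
    with gr.Row():
        for i in range(n):
            with gr.Column():
                output_dataframe.append(
                    gr.Dataframe(interactive=False, wrap=True,
                                 datatype=['html', 'str', 'str', 'str', 'str'])
                )
    end_button = gr.Button("Submit")
    end_button.click(fill_tables, inputs=prompt_box, outputs=output_dataframe)

demo.launch()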
@@ -15,21 +15,46 @@ print(f"{start_gpu_memory} GB of memory reserved.")
 sampling_config=load_config('api_writer/sampling_params_2.yaml')
 system_prompt, chat_template, eval_chat_template, full_context, system_context = load_templates("api_lora")
-sampling_params=SamplingParams(
+sampling_params = [SamplingParams(
+    n=1,
+    temperature=0.0,
+    max_tokens=2048,
+    top_k=10,
+    seed=42,
+    stop=["<|end_of_text|>"],
+),SamplingParams(
+    n=1,
+    temperature=1.0,
+    max_tokens=2048,
+    top_k=10,
+    seed=42,
+    stop=["<|end_of_text|>"],
+)]
+# sampling_params=[SamplingParams(
+#     n=sampling_config['n'],
+#     temperature=sampling_config['temperature'],
+#     max_tokens=sampling_config['max_tokens'],
+#     top_k=sampling_config['top_k'],
+#     seed=sampling_config['seed'],
+#     stop=sampling_config['stop'],
+# )]
+sampling_params.append(SamplingParams(
     n=sampling_config['n'],
     temperature=sampling_config['temperature'],
     max_tokens=sampling_config['max_tokens'],
     top_k=sampling_config['top_k'],
     seed=sampling_config['seed'],
     stop=sampling_config['stop'],
-)
+))
 lora_request=LoRARequest("api_lora", 1, "api_lora")
 def http_bot(prompt, llm=llm, sampling_params=sampling_params,lora_request=lora_request):
     # time the generation
     start_time = time.time()
-    prompt = eval_chat_template.format(SYSTEM=system_context, INPUT=prompt)
+    prompt = [eval_chat_template.format(SYSTEM=system_context, INPUT=prompt)]*len(sampling_params)
     outputs = llm.generate(
         prompt,
         sampling_params,
...
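On the shape of the new generate call: vLLM's LLM.generate accepts either a single SamplingParams applied to every prompt or a list with one entry per prompt, which is why the prompt is now replicated len(sampling_params) times. A minimal sketch of unpacking the returned RequestOutput objects, one completion per sampling configuration (variable names are illustrative, not from this commit):

# Each element of `outputs` is a vllm RequestOutput for one prompt/params pair.
completions = []
for sp, request_output in zip(sampling_params, outputs):
    # With n=1 in every SamplingParams there is exactly one CompletionOutput
    # per request, so indexing [0] is safe.
    text = request_output.outputs[0].text
    completions.append({"temperature": sp.temperature, "text": text})

for c in completions:
    print(f"temp={c['temperature']}: {c['text'][:80]}")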